Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions ge/models/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,10 @@ def _gen_sampling_table(self):
self.node_accept, self.node_alias = create_alias_table(norm_prob)

# create sampling table for edge
numEdges = self.graph.number_of_edges()
total_sum = sum([self.graph[edge[0]][edge[1]].get('weight', 1.0)
for edge in self.graph.edges()])
norm_prob = [self.graph[edge[0]][edge[1]].get('weight', 1.0) *
numEdges / total_sum for edge in self.graph.edges()]
norm_prob = [self.graph[edge[0]][edge[1]].get('weight', 1.0) /
total_sum for edge in self.graph.edges()]

self.edge_accept, self.edge_alias = create_alias_table(norm_prob)

Expand Down
32 changes: 32 additions & 0 deletions tests/line_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,42 @@

pytest.importorskip("tensorflow")
from ge import LINE
from ge.models import line as line_module
from ge.utils import preprocess_nxgraph

TEST_GRAPH_PATH = Path(__file__).resolve().parent / "Wiki_edgelist.txt"


def test_LINE_sampling_tables_use_normalized_probabilities(monkeypatch):
graph = nx.DiGraph()
graph.add_edge("a", "b", weight=2)
graph.add_edge("a", "c", weight=6)

model = LINE.__new__(LINE)
model.graph = graph
model.idx2node, model.node2idx = preprocess_nxgraph(graph)
model.node_size = graph.number_of_nodes()

sampling_tables = []

def record_sampling_table(area_ratio):
sampling_tables.append(list(area_ratio))
return [1] * len(area_ratio), [0] * len(area_ratio)

monkeypatch.setattr(line_module, "create_alias_table", record_sampling_table)

LINE._gen_sampling_table(model)

node_probs, edge_probs = sampling_tables
expected_edge_probs = [
graph[edge[0]][edge[1]]["weight"] / 8 for edge in graph.edges()
]

assert sum(node_probs) == pytest.approx(1)
assert sum(edge_probs) == pytest.approx(1)
assert edge_probs == pytest.approx(expected_edge_probs)


def test_LINE():
graph = nx.read_edgelist(
str(TEST_GRAPH_PATH),
Expand Down
Loading