diff --git a/ge/models/line.py b/ge/models/line.py index 6bae314..ac423f0 100644 --- a/ge/models/line.py +++ b/ge/models/line.py @@ -131,11 +131,10 @@ def _gen_sampling_table(self): self.node_accept, self.node_alias = create_alias_table(norm_prob) # create sampling table for edge - numEdges = self.graph.number_of_edges() total_sum = sum([self.graph[edge[0]][edge[1]].get('weight', 1.0) for edge in self.graph.edges()]) - norm_prob = [self.graph[edge[0]][edge[1]].get('weight', 1.0) * - numEdges / total_sum for edge in self.graph.edges()] + norm_prob = [self.graph[edge[0]][edge[1]].get('weight', 1.0) / + total_sum for edge in self.graph.edges()] self.edge_accept, self.edge_alias = create_alias_table(norm_prob) diff --git a/tests/line_test.py b/tests/line_test.py index 832320d..f5af775 100644 --- a/tests/line_test.py +++ b/tests/line_test.py @@ -5,10 +5,42 @@ pytest.importorskip("tensorflow") from ge import LINE +from ge.models import line as line_module +from ge.utils import preprocess_nxgraph TEST_GRAPH_PATH = Path(__file__).resolve().parent / "Wiki_edgelist.txt" +def test_LINE_sampling_tables_use_normalized_probabilities(monkeypatch): + graph = nx.DiGraph() + graph.add_edge("a", "b", weight=2) + graph.add_edge("a", "c", weight=6) + + model = LINE.__new__(LINE) + model.graph = graph + model.idx2node, model.node2idx = preprocess_nxgraph(graph) + model.node_size = graph.number_of_nodes() + + sampling_tables = [] + + def record_sampling_table(area_ratio): + sampling_tables.append(list(area_ratio)) + return [1] * len(area_ratio), [0] * len(area_ratio) + + monkeypatch.setattr(line_module, "create_alias_table", record_sampling_table) + + LINE._gen_sampling_table(model) + + node_probs, edge_probs = sampling_tables + expected_edge_probs = [ + graph[edge[0]][edge[1]]["weight"] / 8 for edge in graph.edges() + ] + + assert sum(node_probs) == pytest.approx(1) + assert sum(edge_probs) == pytest.approx(1) + assert edge_probs == pytest.approx(expected_edge_probs) + + def test_LINE(): graph = nx.read_edgelist( str(TEST_GRAPH_PATH),