Skip to content

Commit

Permalink
Fix integation test (v4) (aimclub#143)
Browse files Browse the repository at this point in the history
* Fix integration test error (invalid node types for search)

* Fix test: wrong metric was optimized
  • Loading branch information
gkirgizov authored Jul 14, 2023
1 parent f4e6d2f commit f64661f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
2 changes: 2 additions & 0 deletions golem/metrics/edit_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def _label_dist(label1: str, label2: str) -> int:


def tree_edit_dist(target_graph: nx.DiGraph, graph: nx.DiGraph) -> float:
"""Compares nodes by their `name` (if present) or `uid` attribute.
Nodes with the same name/id are considered the same."""
target_tree_root = _nx_to_zss_tree(target_graph)
cmp_tree_root = _nx_to_zss_tree(graph)
dist = zss.simple_distance(target_tree_root, cmp_tree_root, label_dist=_label_dist)
Expand Down
30 changes: 19 additions & 11 deletions test/integration/test_structure_search.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,49 @@
from datetime import timedelta
from functools import partial
from math import ceil

import numpy as np
import pytest
from typing import Tuple, Callable
from typing import Tuple, Callable, Sequence

from examples.synthetic_graph_evolution.experiment_setup import run_trial
from examples.synthetic_graph_evolution.generators import generate_labeled_graph
from examples.synthetic_graph_evolution.tree_search import tree_search_setup
from golem.core.adapter.nx_adapter import BaseNetworkxAdapter
from golem.core.dag.graph import Graph
from golem.core.optimisers.objective import Objective
from golem.metrics.edit_distance import tree_edit_dist, graph_size


def run_search(size: int, distance_function: Callable, timeout_min: int = 1) -> Tuple[float, Graph]:
target_graph = generate_labeled_graph('tree', size, node_labels=['x'])
# defining task
node_types = ['a', 'b']
target_graph = generate_labeled_graph('tree', size, node_labels=node_types)
objective = Objective(partial(distance_function, target_graph))

# running the example
found_graph, history = run_trial(target_graph=target_graph,
optimizer_setup=tree_search_setup,
timeout=timedelta(minutes=timeout_min))
optimizer, objective = tree_search_setup(objective=objective,
timeout=timedelta(minutes=timeout_min),
node_types=node_types)
found_graphs = optimizer.optimise(objective)
found_graph = found_graphs[0] if isinstance(found_graphs, Sequence) else found_graphs

found_nx_graph = BaseNetworkxAdapter().restore(found_graph)
distance = distance_function(target_graph, found_nx_graph)
# compute final distance. it accepts nx graphs, so first adapt it to accept our graphs
adapted_dist = BaseNetworkxAdapter().adapt_func(distance_function)
distance = adapted_dist(target_graph, found_graph)

return distance, found_graph


@pytest.mark.parametrize('target_sizes, distance_function, indulgence',
[([10, 24], tree_edit_dist, 0.5),
([10, 24], graph_size, 0.2)])
([30], graph_size, 0.1)])
def test_simple_targets_are_found(target_sizes, distance_function, indulgence):
""" Checks if simple targets can be found within specified time. """
for target_size in target_sizes:
num_trials = 5
num_trials = 3
distances = []
for i in range(num_trials):
distance, target_graph = run_search(target_size, distance_function=distance_function, timeout_min=2)
distance, target_graph = run_search(target_size, distance_function=distance_function, timeout_min=1)
distances.append(distance)

assert target_graph is not None
Expand Down

0 comments on commit f64661f

Please sign in to comment.