Skip to content

Commit

Permalink
add serialization (aimclub#73)
Browse files Browse the repository at this point in the history
* fix

* fix#2

* minor

* add serialization

* add tests

* minor

* docstrings

* minor

* minors

* remove try-except blocks

* minor

* fix pep8

* fix pep8
  • Loading branch information
maypink authored Apr 10, 2023
1 parent b5b9878 commit cfeb735
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 142 deletions.
17 changes: 11 additions & 6 deletions examples/structural_analysis/opt_graph_optimization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import random
from functools import partial

Expand All @@ -9,7 +10,8 @@
from golem.core.optimisers.graph import OptGraph, OptNode
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.opt_node_factory import DefaultOptNodeFactory
from golem.metrics.graph_metrics import degree_dist, size_diff
from golem.core.paths import project_root
from golem.metrics.graph_metrics import degree_distance, size_diff
from golem.structural_analysis.graph_sa.graph_structural_analysis import GraphStructuralAnalysis
from golem.structural_analysis.graph_sa.sa_requirements import StructuralAnalysisRequirements

Expand All @@ -27,13 +29,13 @@ def get_opt_graph() -> OptGraph:

def quality_custom_metric_1(_: OptGraph) -> float:
""" Get toy metric for demonstration. """
metric = -1*random.randint(80, 100)/100
metric = -1 * random.randint(80, 100) / 100
return metric


def quality_custom_metric_2(_: OptGraph) -> float:
""" Get one more toy metric for demonstration. """
metric = -1*random.randint(70, 110)/100
metric = -1 * random.randint(70, 110) / 100
return metric


Expand All @@ -55,7 +57,7 @@ def complexity_metric(graph: OptGraph, adapter: BaseNetworkxAdapter, metric: Cal
},
complexity_metrics={
'degree': partial(complexity_metric,
adapter=adapter, metric=partial(degree_dist, adapter.restore(opt_graph))),
adapter=adapter, metric=partial(degree_distance, adapter.restore(opt_graph))),
'graph_size': partial(complexity_metric,
adapter=adapter, metric=partial(size_diff, adapter.restore(opt_graph))),
},
Expand All @@ -72,5 +74,8 @@ def complexity_metric(graph: OptGraph, adapter: BaseNetworkxAdapter, metric: Cal
sa = GraphStructuralAnalysis(objective=objective, node_factory=node_factory,
requirements=requirements)

optimized_graph = sa.optimize(graph=opt_graph, n_jobs=1, max_iter=5)
optimized_graph.show()
graph, results = sa.optimize(graph=opt_graph, n_jobs=1, max_iter=2)
graph.show()

path_to_save = os.path.join(project_root(), 'sa_results.json')
results.save(path=path_to_save, datetime_in_path=False)
46 changes: 24 additions & 22 deletions golem/structural_analysis/graph_sa/edge_sa_approaches.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import random
from abc import ABC, abstractmethod
from abc import ABC
from copy import deepcopy
from os import makedirs
from os.path import exists, join
from typing import List, Optional, Type, Union, Tuple, Dict, Callable, Sequence
from typing import List, Optional, Type, Union, Dict, Callable

from golem.core.dag.graph_verifier import GraphVerifier
from golem.core.log import default_log
from golem.core.dag.graph import Graph, GraphNode
from golem.core.dag.graph import Graph
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.timer import OptimisationTimer
from golem.core.paths import default_data_dir
Expand Down Expand Up @@ -58,16 +57,18 @@ def analyze(self, graph: Graph, edge: Edge,
:return: dict with Edge analysis result per approach
"""

results = ObjectSAResult(entity=edge)
results = ObjectSAResult(entity_idx=
f'{graph.nodes.index(edge.parent_node)}_{graph.nodes.index(edge.child_node)}',
entity_type='edge')

for approach in self.approaches:
if timer is not None and timer.is_time_limit_reached():
break

results.add_result(approach(graph=graph,
objective=objective,
requirements=self.approaches_requirements,
path_to_save=self.path_to_save).analyze(edge=edge))
objective=objective,
requirements=self.approaches_requirements,
path_to_save=self.path_to_save).analyze(edge=edge))

return results

Expand Down Expand Up @@ -114,7 +115,7 @@ def analyze(self, edge: Edge, **kwargs) -> DeletionSAApproachResult:
results = DeletionSAApproachResult()
if edge.child_node is self._graph.root_node and len(self._graph.root_node.nodes_from) == 1:
self.log.warning('if remove this edge then get a graph of length one')
results.add_results(metrics_values=[-1.0]*len(self._objective.metrics))
results.add_results(metrics_values=[-1.0] * len(self._objective.metrics))
return results
else:
shortened_graph = self.sample(edge)
Expand All @@ -124,7 +125,7 @@ def analyze(self, edge: Edge, **kwargs) -> DeletionSAApproachResult:
del shortened_graph
else:
self.log.warning('if remove this edge then get an invalid graph')
losses = [-1.0]*len(self._objective.metrics)
losses = [-1.0] * len(self._objective.metrics)

results.add_results(metrics_values=losses)
return results
Expand Down Expand Up @@ -207,10 +208,8 @@ def analyze(self, edge: Edge, **kwargs) -> ReplaceSAApproachResult:
parent_node_idx += char
else:
continue
result.add_results(entity_to_replace_to=Edge(child_node=self._graph.nodes[int(child_node_idx)],
parent_node=self._graph.nodes[int(parent_node_idx)]),
result.add_results(entity_to_replace_to=f'{parent_node_idx}_{child_node_idx}',
metrics_values=loss_per_sample)

return result

def sample(self, edge: Edge,
Expand Down Expand Up @@ -245,17 +244,18 @@ def sample(self, edge: Edge,
previous_parent_node = sample_graph.nodes[previous_parent_node_index]
previous_child_node = sample_graph.nodes[previous_child_node_index]

sample_graph.disconnect_nodes(node_parent=previous_parent_node,
node_child=previous_child_node,
clean_up_leftovers=False)
# connect nodes
next_parent_node = sample_graph.nodes[replacing_nodes_idx['parent_node_idx']]
next_child_node = sample_graph.nodes[replacing_nodes_idx['child_node_idx']]

if next_parent_node in sample_graph.nodes and \
next_child_node in sample_graph.nodes:
next_child_node in sample_graph.nodes:
sample_graph.connect_nodes(next_parent_node, next_child_node)

sample_graph.disconnect_nodes(node_parent=previous_parent_node,
node_child=previous_child_node,
clean_up_leftovers=False)

verifier = self._requirements.graph_verifier
if not verifier.verify(sample_graph):
self.log.warning('Can not connect these nodes')
Expand All @@ -264,14 +264,14 @@ def sample(self, edge: Edge,
self.log.message(f'replace edge child: {next_child_node}')
samples.append(sample_graph)
edges_nodes_idx_to_replace_to.append({'parent_node_id':
replacing_nodes_idx['parent_node_idx'],
replacing_nodes_idx['parent_node_idx'],
'child_node_id':
replacing_nodes_idx['child_node_idx']})
replacing_nodes_idx['child_node_idx']})

if not edges_nodes_idx_to_replace_to:
res = {'samples': [self._graph], 'edges_nodes_idx_to_replace_to':
[{'parent_node_id': self._graph.nodes.index(edge.parent_node),
'child_node_id': self._graph.nodes.index(edge.child_node)}]}
[{'parent_node_id': self._graph.nodes.index(edge.parent_node),
'child_node_id': self._graph.nodes.index(edge.child_node)}]}
return res

return {'samples': samples, 'edges_nodes_idx_to_replace_to': edges_nodes_idx_to_replace_to}
Expand Down Expand Up @@ -312,9 +312,11 @@ def _edge_generation(self, edge: Edge, number_of_operations: int = 1) -> List[Di
continue
if [parent_node, child_node] in edges_in_graph or [child_node, parent_node] in edges_in_graph:
continue
if cur_graph.nodes.index(parent_node) == child_node_index and \
cur_graph.nodes.index(child_node) == parent_node_index:
continue
available_edges_idx.append({'parent_node_idx': cur_graph.nodes.index(parent_node),
'child_node_idx': cur_graph.nodes.index(child_node)})

# random.seed(self._requirements.seed + len(self._graph))
edges_for_replacement = random.sample(available_edges_idx, min(number_of_operations, len(available_edges_idx)))
return edges_for_replacement
7 changes: 3 additions & 4 deletions golem/structural_analysis/graph_sa/edges_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os.path import join
from typing import Optional, List, Type, Callable
from typing import Optional, List, Type

import multiprocessing

Expand Down Expand Up @@ -55,7 +55,7 @@ def analyze(self, graph: Graph, results: Optional[SAAnalysisResults] = None,
"""

if not results:
results = SAAnalysisResults(graph=graph)
results = SAAnalysisResults()

if n_jobs == -1:
n_jobs = multiprocessing.cpu_count()
Expand All @@ -72,7 +72,6 @@ def analyze(self, graph: Graph, results: Optional[SAAnalysisResults] = None,
cur_edges_result = pool.starmap(edge_analysis.analyze,
[[graph, edge, self.objective, timer]
for edge in edges_to_analyze])
for res in cur_edges_result:
results.add_edge_result(res)
results.add_results(cur_edges_result)

return results
55 changes: 35 additions & 20 deletions golem/structural_analysis/graph_sa/graph_structural_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Tuple
import multiprocessing

from golem.core.log import default_log
Expand Down Expand Up @@ -74,20 +74,23 @@ def __init__(self, objective: Objective,
self.path_to_save = path_to_save

def analyze(self, graph: Graph,
result: SAAnalysisResults = None,
nodes_to_analyze: List[GraphNode] = None, edges_to_analyze: List[Edge] = None,
n_jobs: int = 1, timer: OptimisationTimer = None) -> SAAnalysisResults:
"""
Applies defined structural analysis approaches
:param graph: graph object to analyze
:param result: analysis result
:param nodes_to_analyze: nodes to analyze. Default: all nodes
:param edges_to_analyze: edges to analyze. Default: all edges
:param n_jobs: num of ``n_jobs`` for parallelization (``-1`` for use all cpu's).
Tip: if the specified graph isn't huge (as an NN is, for example), then set n_jobs to the default value.
:param timer: timer with timeout left for optimization
"""

result = SAAnalysisResults(graph=graph)
if not result:
result = SAAnalysisResults()

if n_jobs == -1:
n_jobs = multiprocessing.cpu_count()
Expand All @@ -110,37 +113,48 @@ def analyze(self, graph: Graph,
return result

def optimize(self, graph: Graph,
analysis_result: Optional[SAAnalysisResults] = None,
n_jobs: int = 1, timer: OptimisationTimer = None,
max_iter: int = 10) -> Graph:
max_iter: int = 10) -> Tuple[Graph, SAAnalysisResults]:
""" Optimizes graph by applying 'analyze' method and deleting/replacing parts
of graph iteratively
:param graph: graph object to analyze
:param graph: graph object to analyze.
:param analysis_result: if the graph was already analyzed, then the existing analysis results can be applied.
:param n_jobs: num of ``n_jobs`` for parallelization (``-1`` for use all cpu's).
Tip: if the specified graph isn't huge (as an NN is, for example), then set n_jobs to the default value.
:param timer: timer with timeout left for optimization.
:param max_iter: max number of iterations of analysis. """

if analysis_result:
optimized_graph = self.apply_results(graph=graph, analysis_result=analysis_result)
return optimized_graph, analysis_result

approaches_repo = StructuralAnalysisApproachesRepository()
approaches = self._nodes_analyze.approaches + self._edges_analyze.approaches
approaches_names = [approach.__name__ for approach in approaches]

# what actions were applied on the graph and how many
actions_applied = dict.fromkeys(approaches_names, 0)

result = SAAnalysisResults()

graph_before_sa = deepcopy(graph)
analysis_results = self.analyze(graph=graph, timer=timer, n_jobs=n_jobs)
analysis_result = self.analyze(graph=graph, result=result, timer=timer, n_jobs=n_jobs)
if self.path_to_save:
_save_iteration_results(graph_before_sa=graph_before_sa,
save_path=self.path_to_save)
converged = False
iter = 0

if analysis_results.is_empty:
self._log.message(f'0 actions were taken during SA')
return graph
if analysis_result.is_empty:
self._log.message('0 actions were taken during SA')
return graph, analysis_result

while not converged:
iter += 1
worst_result = analysis_results.get_info_about_worst_result(
worst_result = analysis_result.get_info_about_worst_result(
metric_idx_to_optimize_by=self.main_metric_idx)
if worst_result['value'] > 1.0:
if worst_result['value'] > 1.2:
# apply the worst approach
postproc_method = approaches_repo.postproc_method_by_name(worst_result['approach_name'])
graph = postproc_method(graph=graph, worst_result=worst_result)
Expand All @@ -152,25 +166,27 @@ def optimize(self, graph: Graph,
if max_iter and iter >= max_iter:
break

analysis_results = self.analyze(graph=graph, n_jobs=n_jobs,
timer=timer)
analysis_result = self.analyze(graph=graph,
result=result,
n_jobs=n_jobs,
timer=timer)
if self.path_to_save:
_save_iteration_results(graph_before_sa=graph_before_sa,
save_path=self.path_to_save)
else:
converged = True

if self.path_to_save:
_save_iteration_results(graph_before_sa=graph_before_sa, save_path=self.path_to_save)

self._log.message(f'{iter} iterations passed during SA')
self._log.message(f'The following actions were applied during SA: {actions_applied}')

if isinstance(graph, Graph):
return graph
return graph, analysis_result
else:
return graph_before_sa
return graph_before_sa, analysis_result

@staticmethod
def apply_results(graph: Graph, analysis_result: Optional[dict] = None) -> Graph:
""" Optimizes graph by applying actions specified in analysis_result """
def apply_results(graph: Graph, analysis_result: SAAnalysisResults) -> Graph:
""" Optimizes graph by applying actions specified in analysis_result. """
pass

@staticmethod
Expand All @@ -194,7 +210,6 @@ def graph_preprocessing(graph: Graph):

def _save_iteration_results(graph_before_sa: Graph, save_path: str = None):
""" Save visualizations for SA per iteration """
json_path = os.path.join(save_path, 'results_per_iteration.json')
graph_save_path = os.path.join(save_path, 'result_graphs')
graph_before_sa.save(graph_save_path)
if not os.path.exists(graph_save_path):
Expand Down
Loading

0 comments on commit cfeb735

Please sign in to comment.