Skip to content

Commit

Permalink
fix imports (aimclub#119)
Browse files Browse the repository at this point in the history
* fix importa

* replace deprecated & fix legenda

* minor

* fix imports

* minor

* minor
  • Loading branch information
maypink authored Jun 14, 2023
1 parent 072f027 commit 1a52230
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 14 deletions.
1 change: 0 additions & 1 deletion .github/workflows/unit-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ jobs:
pip install .[docs]
pip install .[profilers]
pip install pytest-cov
pip install -r requirements_adaptive.txt
- name: Test with pytest
run: |
pytest --cov=golem test/unit
Expand Down
2 changes: 1 addition & 1 deletion examples/molecule_search/guacamol_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def generate_optimized_molecules(self, scoring_function: ScoringFunction, number
# Take only the first graph's appearance in history
individuals \
= list({hash(self.graph_gen_params.adapter.restore(ind.graph)): ind
for gen in history.individuals
for gen in history.generations
for ind in reversed(list(gen))}.values())

top_individuals = sorted(individuals,
Expand Down
10 changes: 5 additions & 5 deletions experiments/mab/mab_synthetic_experiment_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def show_action_probabilities(self, bandit_type: MutationAgentTypeEnum, stats_ac
else:
centers = sorted(self.cluster.cluster_centers_)
for i in range(self.cluster.n_clusters):
titles = [title + f' for cluster with center {int(centers[i])}' for title in titles]
titles_centers = [title + f' for cluster with center {int(centers[i])}' for title in titles]
plot_action_values(stats=stats_action_value_log[i], action_tags=actions,
titles=titles)
titles=titles_centers)
plt.show()

def show_average_action_probabilities(self, show_action_probabilities: dict, actions):
Expand Down Expand Up @@ -193,7 +193,7 @@ def initial_population_func(graph_size: List[int] = None, pop_size: int = None,
launch_num = 1
target_size = 50

bandits_to_compare = [MutationAgentTypeEnum.contextual_bandit]
bandits_to_compare = [MutationAgentTypeEnum.contextual_bandit, MutationAgentTypeEnum.bandit]
setup_parameters_func = partial(setup_parameters, target_size=target_size, trial_timeout=timeout)
initial_population_func = partial(initial_population_func,
graph_size=[random.randint(5, 10) for _ in range(10)] +
Expand All @@ -204,5 +204,5 @@ def initial_population_func(graph_size: List[int] = None, pop_size: int = None,
n_clusters=2, is_visualize=True)
helper.compare_bandits(initial_population_func=initial_population_func,
setup_parameters=setup_parameters_func)
# helper.show_boxplots()
# helper.show_fitness_lines()
helper.show_boxplots()
helper.show_fitness_lines()
8 changes: 7 additions & 1 deletion golem/core/optimisers/adaptive/neural_mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
import math
from typing import List, Any, Union, Dict

import torch
from golem.utilities.requirements_notificator import warn_requirement

try:
import torch
except ModuleNotFoundError:
warn_requirement('torch', 'other_requirements/requirements_adaptive.txt')

import numpy as np
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
from mabwiser.utils import Arm, Constants, Num
Expand Down
4 changes: 2 additions & 2 deletions golem/visualisation/opt_history/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def plot_diversity_dynamic_gif(history: 'OptHistory',
metric_names = history.objective.metric_names
# dtype=float removes None, puts np.nan
# indexed by [population, metric, individual] after transpose (.T)
pops = history.individuals[1:-1] # ignore initial pop and final choices
pops = history.generations[1:-1] # ignore initial pop and final choices
fitness_distrib = [np.array([ind.fitness.values for ind in pop], dtype=float).T
for pop in pops]

Expand Down Expand Up @@ -95,7 +95,7 @@ def update_axes(iframe: int):
def plot_diversity_dynamic(history: 'OptHistory',
show: bool = True, save_path: Optional[str] = None, dpi: int = 100):
labels = history.objective.metric_names
h = history.individuals[:-1] # don't consider final choices
h = history.generations[:-1] # don't consider final choices
xs = np.arange(len(h))

# Compute diversity by metrics
Expand Down
2 changes: 1 addition & 1 deletion golem/visualisation/opt_history/fitness_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def plot_average_fitness_line_per_generations(axis: plt.Axes, histories, label:

fitness_value_per_generation = []
for history in histories:
generations = history.individuals
generations = history.generations
for gen_num, gen in enumerate(generations):
for ind in gen:
if ind.native_generation != gen_num:
Expand Down
4 changes: 2 additions & 2 deletions golem/visualisation/opt_viz_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pareto_gif_create(self,
objectives_names: Tuple[str] = ('ROC-AUC', 'Complexity')):
files = []
pareto_fronts = self.history.archive_history
individuals = self.history.individuals
individuals = self.history.generations
array_for_analysis = individuals if individuals else pareto_fronts
all_objectives = extract_objectives(array_for_analysis, objectives_numbers)
min_x, max_x = min(all_objectives[0]) - 0.01, max(all_objectives[0]) + 0.01
Expand Down Expand Up @@ -173,7 +173,7 @@ def _create_boxplot(self, individuals: List[Any], generation_num: int = None,
plt.savefig(path, bbox_inches='tight')

def boxplots_gif_create(self, objectives_names: Tuple[str] = ('ROC-AUC', 'Complexity')):
individuals = self.history.individuals
individuals = self.history.generations
objectives = extract_objectives(individuals)
objectives = list(itertools.chain(*objectives))
min_y, max_y = min(objectives), max(objectives)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion test/unit/optimizers/test_composing_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _test_individuals_in_history(history: OptHistory):
uids = set()
ids = set()
for ind in itertools.chain(*history.generations):
# All individuals in `history.individuals` must have a native generation.
# All individuals in `history.generations` must have a native generation.
assert ind.has_native_generation
assert ind.fitness
if ind.native_generation == 0:
Expand Down

0 comments on commit 1a52230

Please sign in to comment.