Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save and Load MABs #185

Merged
merged 31 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ff02744
fix
maypink Mar 17, 2023
987cce0
fix#2
maypink Mar 17, 2023
022363c
minor
maypink Mar 21, 2023
d24dfa2
Merge branch 'main' of https://github.com/aimclub/GOLEM into 66-singl…
maypink Mar 27, 2023
4900e5f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Mar 29, 2023
cc8729f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 4, 2023
f320cfa
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 10, 2023
f3ca604
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 21, 2023
b76b1c3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink May 3, 2023
13a76bf
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 1, 2023
df115e3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 9, 2023
4342573
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 13, 2023
56db3a7
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 15, 2023
5826890
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 16, 2023
33339ef
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 21, 2023
1863a1a
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jul 3, 2023
da4eaad
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jul 6, 2023
0959a02
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Aug 7, 2023
9543642
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 4, 2023
a9885f4
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 6, 2023
5d65d7d
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 12, 2023
5834a0f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 12, 2023
1a0e20f
save and load bandits
maypink Sep 7, 2023
3b39650
minor
maypink Sep 7, 2023
46241de
add tests
maypink Sep 8, 2023
3b9daa0
зуз8
maypink Sep 8, 2023
c6a1bfc
enhance path match
maypink Sep 11, 2023
8d664d2
modify saving
maypink Sep 11, 2023
72b8249
fixes after review
maypink Sep 11, 2023
b18f159
fixes after review #2
maypink Sep 11, 2023
23d640e
add comment
maypink Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions experiments/mab/mab_synthetic_experiment_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def log_action_values(next_pop: PopulationT, optimizer: EvoGraphOptimizer):
if 0 not in stats_action_value_log.keys():
stats_action_value_log[0] = []
stats_action_value_log[0].append(list(values))
# The MAB agent can be saved here -- the call is commented out so as not to clog up the memory
# optimizer.mutation.agent.save()

def log_action_values_with_clusters(next_pop: PopulationT, optimizer: EvoGraphOptimizer):
obs_contexts = optimizer.mutation.agent.get_context(next_pop)
Expand Down
42 changes: 41 additions & 1 deletion golem/core/optimisers/adaptive/mab_agents/mab_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os.path
import _pickle as pickle
import random
import re
from typing import Union, Sequence, Optional

from mabwiser.mab import MAB, LearningPolicy
Expand All @@ -7,13 +10,15 @@
from golem.core.dag.graph import Graph
from golem.core.dag.graph_node import GraphNode
from golem.core.optimisers.adaptive.operator_agent import OperatorAgent, ActType, ObsType, ExperienceBuffer
from golem.core.paths import default_data_dir


class MultiArmedBanditAgent(OperatorAgent):
def __init__(self,
actions: Sequence[ActType],
n_jobs: int = 1,
enable_logging: bool = True):
enable_logging: bool = True,
path_to_save: Optional[str] = None):
super().__init__(enable_logging)
self.actions = list(actions)
self._indices = list(range(len(actions)))
Expand All @@ -22,6 +27,7 @@ def __init__(self,
learning_policy=LearningPolicy.UCB1(alpha=1.25),
n_jobs=n_jobs)
self._initial_fit()
self._path_to_save = path_to_save

def _initial_fit(self):
n = len(self.actions)
Expand Down Expand Up @@ -51,3 +57,37 @@ def partial_fit(self, experience: ExperienceBuffer):
self._dbg_log(obs, actions, rewards)
arms = [self._arm_by_action[action] for action in actions]
self._agent.partial_fit(decisions=arms, rewards=rewards)

def save(self, path_to_save: Optional[str] = None):
    """ Saves bandit to the specified file or directory.

    The target is resolved in the following order: the ``path_to_save`` argument,
    the path passed to the constructor (``self._path_to_save``) and, if neither
    is set, the ``MAB`` folder inside ``default_data_dir()``.

    If the target ends with ``.pkl`` it is treated as a file path and is overwritten
    if it already exists. Otherwise it is treated as a directory and the bandit is
    written there as ``<N>_mab.pkl``, where ``N`` is one more than the largest index
    already present in that directory (``0`` if there are no saved bandits yet).

    :param path_to_save: path to a ``.pkl`` file or to a directory to save the bandit to
    """
    path_to_save = path_to_save or self._path_to_save

    # if path was not specified at all
    if not path_to_save:
        path_to_save = os.path.join(default_data_dir(), 'MAB')

    if path_to_save.endswith('.pkl'):
        path_to_dir = os.path.dirname(path_to_save)
        # a bare file name has an empty dirname, and os.makedirs('') raises FileNotFoundError
        if path_to_dir:
            os.makedirs(path_to_dir, exist_ok=True)
        path_to_file = path_to_save
    else:
        os.makedirs(path_to_save, exist_ok=True)
        # `\d+` so that multi-digit indices are counted too: with a single `\d`
        # '10_mab.pkl' was ignored and silently overwritten by every subsequent save
        saved_indices = [int(name.split('_')[0]) for name in os.listdir(path_to_save)
                         if re.fullmatch(r'\d+_mab\.pkl', name)]
        next_index = max(saved_indices) + 1 if saved_indices else 0
        path_to_file = os.path.join(path_to_save, f'{next_index}_mab.pkl')

    with open(path_to_file, 'wb') as f:
        pickle.dump(self, f)
    self._log.info(f"MAB was saved to {path_to_file}")

@staticmethod
def load(path: str):
    """ Restores a bandit from a pickle file and returns the unpickled object.

    NB: the file is read with ``pickle``, which can execute arbitrary code
    during loading, so only files from trusted sources should be passed here.

    :param path: path to the pickle file to read
    """
    with open(path, 'rb') as source:
        return pickle.load(source)
3 changes: 2 additions & 1 deletion golem/core/optimisers/genetic/operators/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def _init_operator_agent(graph_gen_params: GraphGenerationParams,
agent = RandomAgent(actions=parameters.mutation_types)
elif kind == MutationAgentTypeEnum.bandit:
agent = MultiArmedBanditAgent(actions=parameters.mutation_types,
n_jobs=requirements.n_jobs)
n_jobs=requirements.n_jobs,
path_to_save=requirements.agent_dir)
elif kind == MutationAgentTypeEnum.contextual_bandit:
agent = ContextualMultiArmedBanditAgent(
actions=parameters.mutation_types,
Expand Down
1 change: 1 addition & 0 deletions golem/core/optimisers/optimization_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class OptimizationParameters:

keep_history: bool = True
history_dir: Optional[str] = field(default_factory=default_data_dir)
agent_dir: Optional[str] = field(default_factory=default_data_dir)


@dataclass
Expand Down
42 changes: 42 additions & 0 deletions test/unit/adaptive/test_mab_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os.path
from pathlib import Path

import pytest
from mabwiser.mab import MAB

from golem.core.optimisers.adaptive.mab_agents.mab_agent import MultiArmedBanditAgent


@pytest.mark.parametrize('path_to_save, file_name',
                         [(os.path.join(Path(__file__).parent, 'test_mab.pkl'), 'test_mab.pkl'),
                          (os.path.join(Path(__file__).parent), '0_mab.pkl')])
def test_save_mab(path_to_save, file_name):
    """ Tests if MAB is saved both with and without specifying the file name. """
    # an explicit `.pkl` path is used as is, for a directory the expected file is `file_name` inside it
    # NOTE(review): '0_mab.pkl' assumes there are no other `<N>_mab.pkl` files in the test directory
    saved_file = path_to_save if path_to_save.endswith('.pkl') else os.path.join(path_to_save, file_name)
    mab = MultiArmedBanditAgent(actions=[0, 1, 2],
                                n_jobs=1,
                                path_to_save=path_to_save)
    try:
        mab.save()
        assert file_name in os.listdir(Path(__file__).parent)
    finally:
        # clean up even if the check fails so that a stale pickle does not affect the next runs
        if os.path.exists(saved_file):
            os.remove(saved_file)


def test_load_mab():
    """ Tests if a saved MAB can be loaded back with its state preserved. """
    file_name = 'test_mab.pkl'
    path_to_load = os.path.join(Path(__file__).parent, file_name)
    # save mab to load it later
    mab = MultiArmedBanditAgent(actions=[0, 1, 2],
                                n_jobs=1,
                                path_to_save=path_to_load)
    try:
        mab.save()

        loaded_mab = MultiArmedBanditAgent.load(path=path_to_load)
        assert isinstance(loaded_mab, MultiArmedBanditAgent)

        assert isinstance(loaded_mab._agent, MAB)
        # NOTE(review): if the agent class does not define `__eq__`, this call returns
        # `NotImplemented`, which is truthy, so the check would always pass -- confirm
        assert loaded_mab.__eq__(mab)
        assert loaded_mab.actions == mab.actions
        assert loaded_mab._enable_logging == mab._enable_logging
        assert loaded_mab._path_to_save == mab._path_to_save
    finally:
        # clean up even if one of the checks fails so that the pickle is not left in the test directory
        if os.path.exists(path_to_load):
            os.remove(path_to_load)