Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save and Load MABs #185

Merged
merged 31 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ff02744
fix
maypink Mar 17, 2023
987cce0
fix#2
maypink Mar 17, 2023
022363c
minor
maypink Mar 21, 2023
d24dfa2
Merge branch 'main' of https://github.com/aimclub/GOLEM into 66-singl…
maypink Mar 27, 2023
4900e5f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Mar 29, 2023
cc8729f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 4, 2023
f320cfa
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 10, 2023
f3ca604
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 21, 2023
b76b1c3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink May 3, 2023
13a76bf
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 1, 2023
df115e3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 9, 2023
4342573
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 13, 2023
56db3a7
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 15, 2023
5826890
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 16, 2023
33339ef
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 21, 2023
1863a1a
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jul 3, 2023
da4eaad
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jul 6, 2023
0959a02
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Aug 7, 2023
9543642
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 4, 2023
a9885f4
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 6, 2023
5d65d7d
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 12, 2023
5834a0f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Sep 12, 2023
1a0e20f
save and load bandits
maypink Sep 7, 2023
3b39650
minor
maypink Sep 7, 2023
46241de
add tests
maypink Sep 8, 2023
3b9daa0
зуз8
maypink Sep 8, 2023
c6a1bfc
enhance path match
maypink Sep 11, 2023
8d664d2
modify saving
maypink Sep 11, 2023
72b8249
fixes after review
maypink Sep 11, 2023
b18f159
fixes after review #2
maypink Sep 11, 2023
23d640e
add comment
maypink Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests
  • Loading branch information
maypink committed Sep 12, 2023
commit 46241dedc8ae5bcf5cde9a397a22f8c6d4e2fa5c
20 changes: 12 additions & 8 deletions golem/core/optimisers/adaptive/mab_agents/mab_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import pickle
import random
from pathlib import Path
from typing import Union, Sequence, Optional

from mabwiser.mab import MAB, LearningPolicy
Expand Down Expand Up @@ -58,18 +59,21 @@ def partial_fit(self, experience: ExperienceBuffer):

def save(self):
""" Saves bandit to specified file. """
path_to_save = os.path.join(self._path_to_save, 'MAB')
os.makedirs(path_to_save, exist_ok=True)
# to get file name
mabs_num = [int(e.split('_')[0]) for e in os.listdir(path_to_save) if e.split('_')[0].isdigit()]
if not mabs_num:
max_saved_mab = 0
if not self._path_to_save.endswith('.pkl'):
path_to_save = os.path.join(self._path_to_save, 'MAB')
os.makedirs(path_to_save, exist_ok=True)
mabs_num = [int(e.split('_')[0]) for e in os.listdir(path_to_save) if e.split('_')[0].isdigit()]
if not mabs_num:
max_saved_mab = 0
else:
max_saved_mab = max(mabs_num) + 1
path_to_file = os.path.join(path_to_save, f'{max_saved_mab}_mab.pkl')
else:
max_saved_mab = max(mabs_num) + 1
path_to_file = os.path.join(path_to_save, f'{max_saved_mab}_mab.pkl')
path_to_file = self._path_to_save
with open(path_to_file, 'wb') as f:
pickle.dump(self, f)
YamLyubov marked this conversation as resolved.
Show resolved Hide resolved
self._log.info(f"MAB was save to {self._path_to_save}")
self._log.info(f"MAB was saved to {self._path_to_save}")
YamLyubov marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def load(path: str):
Expand Down
24 changes: 24 additions & 0 deletions test/unit/adaptive/test_mab_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os.path
from pathlib import Path

from golem.core.optimisers.adaptive.mab_agents.mab_agent import MultiArmedBanditAgent


def test_save_mab():
""" Tests if MAB is saved. """
file_name = 'test_mab.pkl'
path_to_save = os.path.join(Path(__file__).parent, file_name)
mab = MultiArmedBanditAgent(actions=[0, 1, 2],
n_jobs=1,
path_to_save=path_to_save)
mab.save()
assert file_name in os.listdir(Path(__file__).parent)


def test_load_mab():
""" Tests if MAB is loaded. """
file_name = 'test_mab.pkl'
path_to_load = os.path.join(Path(__file__).parent, file_name)
mab = MultiArmedBanditAgent.load(path=path_to_load)
assert isinstance(mab, MultiArmedBanditAgent)
YamLyubov marked this conversation as resolved.
Show resolved Hide resolved
os.remove(path_to_load)