Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fix #184

Merged
merged 7 commits into from
Sep 7, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
modify contextual mab param
  • Loading branch information
maypink committed Aug 15, 2023
commit e1a7790065404f28bda401238c560c263c540684
19 changes: 14 additions & 5 deletions golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from functools import partial
from typing import Union, Sequence, Optional, List
from typing import Union, Sequence, Optional, List, Callable

import numpy as np
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
Expand All @@ -14,10 +14,18 @@

class ContextualMultiArmedBanditAgent(OperatorAgent):
""" Contextual Multi-Armed bandit. Observations can be encoded with simple context agent without
using NN to guarantee convergence. """
using NN to guarantee convergence.

:param actions: types of mutations
:param context_agent: function to convert observation to its embedding. Can be specified as
ContextAgentTypeEnum or as Callable function.
:param available_operations: available operations
:param n_jobs: n_jobs
:param enable_logging: bool logging flag
"""

def __init__(self, actions: Sequence[ActType],
context_agent_type: ContextAgentTypeEnum,
context_agent: Union[ContextAgentTypeEnum, Callable],
available_operations: List[str],
n_jobs: int = 1,
enable_logging: bool = True):
Expand All @@ -29,8 +37,9 @@ def __init__(self, actions: Sequence[ActType],
learning_policy=LearningPolicy.UCB1(alpha=1.25),
neighborhood_policy=NeighborhoodPolicy.Clusters(),
n_jobs=n_jobs)
self._context_agent = partial(ContextAgentsRepository.agent_class_by_id(context_agent_type),
available_operations=available_operations)
self._context_agent = context_agent if isinstance(context_agent, Callable) else \
partial(ContextAgentsRepository.agent_class_by_id(context_agent),
available_operations=available_operations)
self._is_fitted = False

def _initial_fit(self, obs: ObsType):
Expand Down