diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py
index f33b1193..477750c4 100644
--- a/rl4co/models/rl/reinforce/reinforce.py
+++ b/rl4co/models/rl/reinforce/reinforce.py
@@ -139,7 +139,7 @@ def wrap_dataset(self, dataset):
 
     def set_decode_type_multistart(self, phase: str):
         """Set decode type to `multistart` for train, val and test in policy.
-        For example, if the decode type is `greedy`, it will be set to `greedy_multistart`.
+        For example, if the decode type is `greedy`, it will be set to `multistart_greedy`.
 
         Args:
             phase: Phase to set decode type for. Must be one of `train`, `val` or `test`.
@@ -148,12 +148,12 @@ def set_decode_type_multistart(self, phase: str):
         attr_get = getattr(self.policy, attribute)
         # If does not exist, log error
         if attr_get is None:
-            log.error(f"Decode type for {phase} is None. Cannot add `_multistart`.")
+            log.error(f"Decode type for {phase} is None. Cannot prepend `multistart_`.")
             return
         elif "multistart" in attr_get:
             return
         else:
-            setattr(self.policy, attribute, f"{attr_get}_multistart")
+            setattr(self.policy, attribute, f"multistart_{attr_get}")
 
     @classmethod
     def load_from_checkpoint(
diff --git a/rl4co/models/zoo/eas/decoder.py b/rl4co/models/zoo/eas/decoder.py
index 9e646632..0e7c9acf 100644
--- a/rl4co/models/zoo/eas/decoder.py
+++ b/rl4co/models/zoo/eas/decoder.py
@@ -57,7 +57,7 @@ def forward_eas(
     best_solutions,
     iter_count: int = 0,
     env: Union[str, RL4COEnvBase] = None,
-    decode_type: str = "sampling_multistart",
+    decode_type: str = "multistart_sampling",
     num_starts: int = None,
     softmax_temp: float = None,
     **unused_kwargs,
diff --git a/rl4co/tasks/eval.py b/rl4co/tasks/eval.py
index bb4ff368..1000cca1 100644
--- a/rl4co/tasks/eval.py
+++ b/rl4co/tasks/eval.py
@@ -194,7 +194,7 @@ class GreedyMultiStartEval(EvalBase):
         num_starts (int): Number of greedy multistarts to use
     """
 
-    name = "greedy_multistart"
+    name = "multistart_greedy"
 
     def __init__(self, env, num_starts=None, **kwargs):
         check_unused_kwargs(self, kwargs)
@@ -207,7 +207,7 @@ def _inner(self, policy, td):
         td_init = td.clone()
         out = policy(
             td.clone(),
-            decode_type="greedy_multistart",
+            decode_type="multistart_greedy",
             num_starts=self.num_starts,
             return_actions=True,
         )
@@ -236,7 +236,7 @@ class GreedyMultiStartAugmentEval(EvalBase):
         force_dihedral_8: If True, force the use of 8 augmentations (rotations and flips) as in POMO
     """
 
-    name = "greedy_multistart_augment"
+    name = "multistart_greedy_augment"
 
     def __init__(
         self, env, num_starts=None, num_augment=8, force_dihedral_8=False, **kwargs
@@ -262,7 +262,7 @@ def _inner(self, policy, td, num_augment=None):
         td = self.augmentation(td)
         out = policy(
             td.clone(),
-            decode_type="greedy_multistart",
+            decode_type="multistart_greedy",
             num_starts=self.num_starts,
             return_actions=True,
         )
@@ -335,7 +335,7 @@ def evaluate_policy(
             "func": SamplingEval,
             "kwargs": {"samples": 100, "softmax_temp": 1.0},
         },
-        "greedy_multistart": {
+        "multistart_greedy": {
            "func": GreedyMultiStartEval,
            "kwargs": {"num_starts": num_loc},
        },
@@ -344,7 +344,7 @@ def evaluate_policy(
             "kwargs": {"num_augment": 8, "force_dihedral_8": True},
         },
         "augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}},
-        "greedy_multistart_augment_dihedral_8": {
+        "multistart_greedy_augment_dihedral_8": {
             "func": GreedyMultiStartAugmentEval,
             "kwargs": {
                 "num_augment": 8,
@@ -352,7 +352,7 @@ def evaluate_policy(
                 "num_starts": num_loc,
             },
         },
-        "greedy_multistart_augment": {
+        "multistart_greedy_augment": {
             "func": GreedyMultiStartAugmentEval,
             "kwargs": {"num_augment": 8, "num_starts": num_loc},
         },
diff --git a/tests/test_policy.py b/tests/test_policy.py
index 13e74780..f387b19a 100644
--- a/tests/test_policy.py
+++ b/tests/test_policy.py
@@ -25,7 +25,7 @@ def test_base_policy_multistart(env_name, size=20, batch_size=2):
     td = env.reset(x)
     policy = AutoregressivePolicy(env.name)
     num_starts = size // 2 if env.name in ["pdp"] else size
-    out = policy(td, env, decode_type="greedy_multistart", num_starts=num_starts)
+    out = policy(td, env, decode_type="multistart_greedy", num_starts=num_starts)
     assert out["reward"].shape == (
         batch_size * num_starts,
     )  # to evaluate, we could just unbatchify