Skip to content

Commit

Permalink
[Refactor] change decode_type_multistart to multistart_decode_type
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Feb 26, 2024
1 parent 20269a9 commit 1d48733
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions rl4co/models/rl/reinforce/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def wrap_dataset(self, dataset):

def set_decode_type_multistart(self, phase: str):
"""Set decode type to `multistart` for train, val and test in policy.
For example, if the decode type is `greedy`, it will be set to `greedy_multistart`.
For example, if the decode type is `greedy`, it will be set to `multistart_greedy`.
Args:
phase: Phase to set decode type for. Must be one of `train`, `val` or `test`.
Expand All @@ -148,12 +148,12 @@ def set_decode_type_multistart(self, phase: str):
attr_get = getattr(self.policy, attribute)
# If does not exist, log error
if attr_get is None:
log.error(f"Decode type for {phase} is None. Cannot add `_multistart`.")
log.error(f"Decode type for {phase} is None. Cannot prepend `multistart_`.")
return
elif "multistart" in attr_get:
return
else:
setattr(self.policy, attribute, f"{attr_get}_multistart")
setattr(self.policy, attribute, f"multistart_{attr_get}")

@classmethod
def load_from_checkpoint(
Expand Down
2 changes: 1 addition & 1 deletion rl4co/models/zoo/eas/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward_eas(
best_solutions,
iter_count: int = 0,
env: Union[str, RL4COEnvBase] = None,
decode_type: str = "sampling_multistart",
decode_type: str = "multistart_sampling",
num_starts: int = None,
softmax_temp: float = None,
**unused_kwargs,
Expand Down
14 changes: 7 additions & 7 deletions rl4co/tasks/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class GreedyMultiStartEval(EvalBase):
num_starts (int): Number of greedy multistarts to use
"""

name = "greedy_multistart"
name = "multistart_greedy"

def __init__(self, env, num_starts=None, **kwargs):
check_unused_kwargs(self, kwargs)
Expand All @@ -207,7 +207,7 @@ def _inner(self, policy, td):
td_init = td.clone()
out = policy(
td.clone(),
decode_type="greedy_multistart",
decode_type="multistart_greedy",
num_starts=self.num_starts,
return_actions=True,
)
Expand Down Expand Up @@ -236,7 +236,7 @@ class GreedyMultiStartAugmentEval(EvalBase):
force_dihedral_8: If True, force the use of 8 augmentations (rotations and flips) as in POMO
"""

name = "greedy_multistart_augment"
name = "multistart_greedy_augment"

def __init__(
self, env, num_starts=None, num_augment=8, force_dihedral_8=False, **kwargs
Expand All @@ -262,7 +262,7 @@ def _inner(self, policy, td, num_augment=None):
td = self.augmentation(td)
out = policy(
td.clone(),
decode_type="greedy_multistart",
decode_type="multistart_greedy",
num_starts=self.num_starts,
return_actions=True,
)
Expand Down Expand Up @@ -335,7 +335,7 @@ def evaluate_policy(
"func": SamplingEval,
"kwargs": {"samples": 100, "softmax_temp": 1.0},
},
"greedy_multistart": {
"multistart_greedy": {
"func": GreedyMultiStartEval,
"kwargs": {"num_starts": num_loc},
},
Expand All @@ -344,15 +344,15 @@ def evaluate_policy(
"kwargs": {"num_augment": 8, "force_dihedral_8": True},
},
"augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}},
"greedy_multistart_augment_dihedral_8": {
"multistart_greedy_augment_dihedral_8": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {
"num_augment": 8,
"force_dihedral_8": True,
"num_starts": num_loc,
},
},
"greedy_multistart_augment": {
"multistart_greedy_augment": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {"num_augment": 8, "num_starts": num_loc},
},
Expand Down
2 changes: 1 addition & 1 deletion tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_base_policy_multistart(env_name, size=20, batch_size=2):
td = env.reset(x)
policy = AutoregressivePolicy(env.name)
num_starts = size // 2 if env.name in ["pdp"] else size
out = policy(td, env, decode_type="greedy_multistart", num_starts=num_starts)
out = policy(td, env, decode_type="multistart_greedy", num_starts=num_starts)
assert out["reward"].shape == (
batch_size * num_starts,
) # to evaluate, we could just unbatchify
Expand Down

0 comments on commit 1d48733

Please sign in to comment.