Skip to content

Commit

Permalink
[BugFix] Fix QValueModule multi_one_hot (pytorch#1439)
Browse files Browse the repository at this point in the history
  • Loading branch information
smorad authored Aug 1, 2023
1 parent bbf5545 commit 71fd4c2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
14 changes: 14 additions & 0 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,20 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5):
spec=action_spec,
)

@pytest.mark.parametrize(
    "action_space, var_nums, expected_action",
    (
        ("multi_one_hot", [2, 2, 2], [1, 0, 1, 0, 1, 0]),
        ("multi_one_hot", [2, 4], [1, 0, 1, 0, 0, 0]),
    ),
)
def test_qvalue_module_multi_one_hot(self, action_space, var_nums, expected_action):
    """Check QValueModule on a multi one-hot action space.

    A fixed 6-element value vector is fed to a module built with the
    parametrized ``var_nums``. The returned action must equal
    ``expected_action`` and the returned values must equal the input.
    """
    q_values = torch.tensor([1.0, 0, 2, 0, 1, 0])
    qvalue_module = QValueModule(action_space=action_space, var_nums=var_nums)

    # The module returns (action, values, chosen action value); the third
    # element is not checked by this test.
    action, values, _ = qvalue_module(q_values)

    expected = torch.tensor(expected_action, dtype=torch.long)
    assert (expected == action).all()
    assert (values == q_values).all()

@pytest.mark.parametrize(
"action_space, expected_action",
(
Expand Down
10 changes: 8 additions & 2 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,17 @@ def _one_hot(value: torch.Tensor) -> torch.Tensor:
def _categorical(value: torch.Tensor) -> torch.Tensor:
return torch.argmax(value, dim=-1).to(torch.long)

def _mult_one_hot(self, value: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
def _mult_one_hot(
self, value: torch.Tensor, support: torch.Tensor = None
) -> torch.Tensor:
if self.var_nums is None:
raise ValueError(
"var_nums must be provided to the constructor for multi one-hot action spaces."
)
values = value.split(self.var_nums, dim=-1)
return torch.cat(
[
QValueHook._one_hot(
self._one_hot(
_value,
)
for _value in values
Expand Down

0 comments on commit 71fd4c2

Please sign in to comment.