Skip to content

Commit

Permalink
[Refactor] Refactor DQN (pytorch#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 25, 2023
1 parent ce88a95 commit 32339da
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 59 deletions.
85 changes: 80 additions & 5 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.data import DiscreteTensorSpec, OneHotDiscreteTensorSpec
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
)
from torchrl.modules import MLP, SafeModule
from torchrl.modules.tensordict_module.actors import (
_process_action_space_spec,
ActorValueOperator,
DistributionalQValueActor,
DistributionalQValueHook,
Expand All @@ -29,14 +35,44 @@

class TestQValue:
def test_qvalue_hook_wrong_action_space(self):
with pytest.raises(ValueError) as exc:
with pytest.raises(
ValueError, match="action_space was not specified/not compatible"
):
QValueHook(action_space="wrong_value")
assert "action_space must be one of" in str(exc.value)

def test_distributional_qvalue_hook_wrong_action_space(self):
with pytest.raises(ValueError) as exc:
with pytest.raises(
ValueError, match="action_space was not specified/not compatible"
):
DistributionalQValueHook(action_space="wrong_value", support=None)
assert "action_space must be one of" in str(exc.value)

def test_distributional_qvalue_hook_conflicting_spec(self):
    """Exercise ``_process_action_space_spec`` with matching and conflicting inputs.

    Compatible ``(action_space, spec)`` pairs are expected to be processed
    without raising; each incompatible pair is expected to raise the error
    whose message is matched below.
    """
    one_hot_spec = OneHotDiscreteTensorSpec(3)
    # Pairs that must be accepted for a one-hot spec (or no spec at all).
    for space, action_spec in (
        ("one-hot", one_hot_spec),
        ("one_hot", one_hot_spec),
        ("one_hot", None),
        (None, one_hot_spec),
    ):
        _process_action_space_spec(space, action_spec)

    # A multi-one-hot action space contradicts a one-hot spec.
    with pytest.raises(
        ValueError, match="The action spec and the action space do not match"
    ):
        _process_action_space_spec("multi-one-hot", one_hot_spec)

    multi_one_hot_spec = MultiOneHotDiscreteTensorSpec([3, 3])
    _process_action_space_spec("multi-one-hot", multi_one_hot_spec)
    _process_action_space_spec(multi_one_hot_spec, multi_one_hot_spec)

    # A spec passed as action_space alongside a different spec is rejected.
    with pytest.raises(
        ValueError, match="Passing an action_space as a TensorSpec and a spec"
    ):
        _process_action_space_spec(OneHotDiscreteTensorSpec(3), multi_one_hot_spec)

    # A composite spec is never a valid action_space.
    with pytest.raises(
        ValueError, match="action_space cannot be of type CompositeSpec"
    ):
        _process_action_space_spec(CompositeSpec(), multi_one_hot_spec)

    # An empty composite spec carries no "action" entry to read from.
    with pytest.raises(KeyError, match="action could not be found in the spec"):
        _process_action_space_spec(None, CompositeSpec())

    # Nothing at all to infer the action space from.
    with pytest.raises(
        ValueError, match="Neither action_space nor spec was defined"
    ):
        _process_action_space_spec(None, None)

@pytest.mark.parametrize(
"action_space, expected_action",
Expand Down Expand Up @@ -406,6 +442,45 @@ def make_net():
assert (action.sum(-1) == 1).all()


@pytest.mark.parametrize(
    "spec", [None, OneHotDiscreteTensorSpec(3), MultiOneHotDiscreteTensorSpec([3, 2])]
)
@pytest.mark.parametrize(
    "action_space", [None, "one-hot", "one_hot", "mult-one-hot", "mult_one_hot"]
)
def test_qvalactor_construct(
    spec,
    action_space,
):
    """Check ``QValueActor`` construction for each ``spec``/``action_space`` pair.

    Construction is expected to raise ``ValueError`` when neither argument is
    given or when the spec type and the action-space name disagree; every
    other combination is expected to build without error.
    """
    # Only forward the arguments that were actually provided.
    ctor_kwargs = {
        name: value
        for name, value in (("spec", spec), ("action_space", action_space))
        if value is not None
    }
    ctor_kwargs["module"] = TensorDictModule(
        lambda x: x, in_keys=["x"], out_keys=["action_value"]
    )

    # Exact type comparison: the two spec classes are told apart by identity.
    multi_mismatch = type(
        spec
    ) is MultiOneHotDiscreteTensorSpec and action_space not in (
        "mult-one-hot",
        "mult_one_hot",
        None,
    )
    one_hot_mismatch = type(spec) is OneHotDiscreteTensorSpec and action_space not in (
        "one-hot",
        "one_hot",
        None,
    )

    if spec is None and action_space is None:
        expected_error = "Neither action_space nor spec was defined"
    elif multi_mismatch or one_hot_mismatch:
        expected_error = "The action spec and the action space do not match"
    else:
        expected_error = None

    if expected_error is None:
        QValueActor(**ctor_kwargs)
        return
    with pytest.raises(ValueError, match=expected_error):
        QValueActor(**ctor_kwargs)


@pytest.mark.parametrize("device", get_available_devices())
def test_value_based_policy_categorical(device):
torch.manual_seed(0)
Expand Down
41 changes: 22 additions & 19 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def _create_mock_actor(
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
elif action_spec_type == "nd_bounded":
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
# elif action_spec_type == "nd_bounded":
# action_spec = BoundedTensorSpec(
# -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
# )
else:
raise ValueError(f"Wrong {action_spec_type}")

Expand All @@ -162,6 +162,7 @@ def _create_mock_actor(
chosen_action_value=None,
shape=[],
),
action_space=action_spec_type,
module=module,
).to(device)
return actor
Expand All @@ -178,8 +179,12 @@ def _create_mock_distributional_actor(
is_nn_module=False,
):
# Actor
var_nums = None
if action_spec_type == "mult_one_hot":
action_spec = MultiOneHotDiscreteTensorSpec([atoms] * action_dim)
action_spec = MultiOneHotDiscreteTensorSpec(
[action_dim // 2, action_dim // 2]
)
var_nums = action_spec.nvec
elif action_spec_type == "one_hot":
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
Expand All @@ -201,9 +206,8 @@ def _create_mock_distributional_actor(
),
module=module,
support=support,
action_space="categorical"
if isinstance(action_spec, DiscreteTensorSpec)
else "one_hot",
action_space=action_spec_type,
var_nums=var_nums,
)
return actor

Expand All @@ -230,7 +234,7 @@ def _create_mock_data_dqn(

if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
action = torch.argmax(action, -1, keepdim=False)
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1, dtype=torch.bool)
td = TensorDict(
Expand Down Expand Up @@ -274,13 +278,16 @@ def _create_seq_mock_data_dqn(
action_value = torch.randn(batch, T, action_dim, device=device)
action = (action_value == action_value.max(-1, True)[0]).to(torch.long)

if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
# action_value = action_value.unsqueeze(-1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=False)
action = action.masked_fill_(~mask, 0.0)
else:
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
Expand All @@ -291,17 +298,15 @@ def _create_seq_mock_data_dqn(
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
"action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
"action": action,
"action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
)
return td

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
"action_spec_type", ("nd_bounded", "one_hot", "categorical")
)
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_dqn(self, delay_value, device, action_spec_type, td_est):
torch.manual_seed(self.seed)
Expand Down Expand Up @@ -344,9 +349,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est):
@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize(
"action_spec_type", ("nd_bounded", "one_hot", "categorical")
)
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(
Expand Down
Loading

0 comments on commit 32339da

Please sign in to comment.