Skip to content

Commit

Permalink
Explicit CatFrame for pixels experiments and REDQ from pixels (pytorc…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 23, 2022
1 parent 2664f3e commit 9d2c853
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 33 deletions.
4 changes: 3 additions & 1 deletion examples/ddpg/configs/humanoid_pixels.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ total_frames=5000000

from_pixels
activation=elu
lr=0.001
lr=0.0002
weight_decay=2e-5

catframes=4
26 changes: 26 additions & 0 deletions examples/redq/configs/humanoid_pixels.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
env_name=humanoid
env_task=walk
env_library=dm_control
async_collection
record_video
prb
normalize_rewards_online
normalize_rewards_online_scale=10
exp_name=humanoid
tanh_loc

num_workers=32
env_per_collector=8

frame_skip=2
frames_per_batch=500
optim_steps_per_batch=80
batch_size=128
total_frames=5000000

from_pixels
activation=elu
lr=0.0002
weight_decay=2e-5

catframes=4
29 changes: 23 additions & 6 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def _assert_keys_match(td, expeceted_keys):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("noisy", [tuple(), ("--noisy",)])
@pytest.mark.parametrize("distributional", [tuple(), ("--distributional",)])
@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)])
@pytest.mark.parametrize(
"from_pixels", [tuple(), ("--from_pixels", "--catframes", "4")]
)
def test_dqn_maker(device, noisy, distributional, from_pixels):
flags = list(noisy + distributional + from_pixels) + ["--env_name=CartPole-v1"]
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -84,7 +86,9 @@ def test_dqn_maker(device, noisy, distributional, from_pixels):

@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)])
@pytest.mark.parametrize(
"from_pixels", [tuple(), ("--from_pixels", "--catframes", "4")]
)
@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)])
@pytest.mark.parametrize("exploration", ["random", "mode"])
def test_ddpg_maker(device, from_pixels, gsde, exploration):
Expand Down Expand Up @@ -145,7 +149,9 @@ def test_ddpg_maker(device, from_pixels, gsde, exploration):

@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)])
@pytest.mark.parametrize(
"from_pixels", [tuple(), ("--from_pixels", "--catframes", "4")]
)
@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)])
@pytest.mark.parametrize("shared_mapping", [tuple(), ("--shared_mapping",)])
@pytest.mark.parametrize("exploration", ["random", "mode"])
Expand Down Expand Up @@ -260,6 +266,8 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):
parser = parser_env_args(parser)
parser = parser_model_args_continuous(parser, algorithm="SAC")
args = parser.parse_args(flags)
if from_pixels:
args.catframes = 4

env_maker = (
ContinuousActionConvMockEnvNumpy if from_pixels else ContinuousActionVecMockEnv
Expand Down Expand Up @@ -337,7 +345,9 @@ def test_sac_make(device, gsde, tanh_loc, from_pixels, exploration):


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("from_pixels", [tuple()])
@pytest.mark.parametrize(
"from_pixels", [tuple(), ("--from_pixels", "--catframes", "4")]
)
@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)])
@pytest.mark.skipif(not _has_gym, reason="No gym library found")
@pytest.mark.parametrize("exploration", ["random", "mode"])
Expand Down Expand Up @@ -372,14 +382,18 @@ def test_redq_make(device, from_pixels, gsde, exploration):
actor(td)
expected_keys = [
"done",
"observation_vector",
"action",
"sample_log_prob",
"loc",
"scale",
]
if len(gsde):
expected_keys += ["_eps_gSDE"]
if from_pixels:
expected_keys += ["hidden", "pixels"]
else:
expected_keys += ["observation_vector"]

try:
_assert_keys_match(td, expected_keys)
except AssertionError:
Expand All @@ -397,7 +411,6 @@ def test_redq_make(device, from_pixels, gsde, exploration):
qvalue(td)
expected_keys = [
"done",
"observation_vector",
"action",
"sample_log_prob",
"state_action_value",
Expand All @@ -406,6 +419,10 @@ def test_redq_make(device, from_pixels, gsde, exploration):
]
if len(gsde):
expected_keys += ["_eps_gSDE"]
if from_pixels:
expected_keys += ["hidden", "pixels"]
else:
expected_keys += ["observation_vector"]
try:
_assert_keys_match(td, expected_keys)
except AssertionError:
Expand Down
3 changes: 2 additions & 1 deletion torchrl/modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,8 @@ def __init__(
ddpg_init_last_layer(self.mlp[-1], 6e-4)

def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
value = self.mlp(torch.cat([self.convnet(observation), action], -1))
hidden = torch.cat([self.convnet(observation), action], -1)
value = self.mlp(hidden)
return value


Expand Down
7 changes: 6 additions & 1 deletion torchrl/trainers/helpers/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,15 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
if args.noops:
env.append_transform(NoopResetEnv(env, args.noops))
if from_pixels:
if not args.catframes:
    # Frame stacking is required on the pixel path, so reject a missing or
    # zero `catframes` up front. The two literals below are concatenated by
    # the parser: the trailing space keeps the words from running together.
    raise RuntimeError(
        "this env builder currently only accepts positive catframes values "
        "when pixels are being used."
    )
env.append_transform(ToTensorImage())
env.append_transform(Resize(84, 84))
env.append_transform(GrayScale())
env.append_transform(CatFrames(keys=["next_pixels"]))
env.append_transform(CatFrames(N=args.catframes, keys=["next_pixels"]))
if stats is None:
obs_stats = {"loc": 0.0, "scale": 1.0}
else:
Expand Down
85 changes: 61 additions & 24 deletions torchrl/trainers/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,26 +1025,63 @@ def make_redq_model(
if qvalue_net_kwargs is None:
qvalue_net_kwargs = {}

if in_keys is None:
in_keys = ["observation_vector"]
linear_layer_class = torch.nn.Linear if not args.noisy else NoisyLinear

actor_net_kwargs_default = {
"num_cells": [args.actor_cells, args.actor_cells],
"out_features": (2 - gSDE) * action_spec.shape[-1],
"activation_class": ACTIVATIONS[args.activation],
}
actor_net_kwargs_default.update(actor_net_kwargs)
actor_net = MLP(**actor_net_kwargs_default)
out_features_actor = (2 - gSDE) * action_spec.shape[-1]
if args.from_pixels:
if in_keys is None:
in_keys_actor = ["pixels"]
else:
in_keys_actor = in_keys
actor_net_kwargs_default = {
"mlp_net_kwargs": {
"layer_class": linear_layer_class,
"activation_class": ACTIVATIONS[args.activation],
},
"conv_net_kwargs": {"activation_class": ACTIVATIONS[args.activation]},
}
actor_net_kwargs_default.update(actor_net_kwargs)
actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default)
gSDE_state_key = "hidden"
out_keys_actor = ["param", "hidden"]

qvalue_net_kwargs_default = {
"num_cells": [args.qvalue_cells, args.qvalue_cells],
"out_features": 1,
"activation_class": ACTIVATIONS[args.activation],
}
qvalue_net_kwargs_default.update(qvalue_net_kwargs)
qvalue_net = MLP(
**qvalue_net_kwargs_default,
)
value_net_default_kwargs = {
"mlp_net_kwargs": {
"layer_class": linear_layer_class,
"activation_class": ACTIVATIONS[args.activation],
},
"conv_net_kwargs": {"activation_class": ACTIVATIONS[args.activation]},
}
value_net_default_kwargs.update(qvalue_net_kwargs)

in_keys_qvalue = ["pixels", "action"]
qvalue_net = DdpgCnnQNet(**value_net_default_kwargs)
else:
if in_keys is None:
in_keys_actor = ["observation_vector"]
else:
in_keys_actor = in_keys

actor_net_kwargs_default = {
"num_cells": [args.actor_cells, args.actor_cells],
"out_features": out_features_actor,
"activation_class": ACTIVATIONS[args.activation],
}
actor_net_kwargs_default.update(actor_net_kwargs)
actor_net = MLP(**actor_net_kwargs_default)
out_keys_actor = ["param"]
gSDE_state_key = in_keys_actor[0]

qvalue_net_kwargs_default = {
"num_cells": [args.qvalue_cells, args.qvalue_cells],
"out_features": 1,
"activation_class": ACTIVATIONS[args.activation],
}
qvalue_net_kwargs_default.update(qvalue_net_kwargs)
qvalue_net = MLP(
**qvalue_net_kwargs_default,
)
in_keys_qvalue = in_keys_actor + ["action"]

dist_class = TanhNormal
dist_kwargs = {
Expand All @@ -1059,17 +1096,17 @@ def make_redq_model(
scale_mapping=f"biased_softplus_{default_policy_scale}",
scale_lb=args.scale_lb,
)
in_keys_actor = in_keys
actor_module = TDModule(
actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"]
actor_net,
in_keys=in_keys_actor,
out_keys=["loc", "scale"] + out_keys_actor[1:],
)

else:
gSDE_state_key = in_keys[0]
actor_module = TDModule(
actor_net,
in_keys=in_keys,
out_keys=["action"], # will be overwritten
in_keys=in_keys_actor,
out_keys=["action"] + out_keys_actor[1:], # will be overwritten
)

if action_spec.domain == "continuous":
Expand Down Expand Up @@ -1103,7 +1140,7 @@ def make_redq_model(
return_log_prob=True,
)
qvalue = ValueOperator(
in_keys=["action"] + in_keys,
in_keys=in_keys_qvalue,
module=qvalue_net,
)
model = nn.ModuleList([actor, qvalue]).to(device)
Expand Down

0 comments on commit 9d2c853

Please sign in to comment.