Skip to content

Commit

Permalink
[BugFix] action_spec_unbatched whenever necessary
Browse files Browse the repository at this point in the history
ghstack-source-id: ec87794dabaf5023dac85cfc898a7c000e93331d
Pull Request resolved: #2592
  • Loading branch information
vmoens committed Nov 20, 2024
1 parent a47b32c commit d30599e
Show file tree
Hide file tree
Showing 25 changed files with 191 additions and 116 deletions.
4 changes: 2 additions & 2 deletions examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec.space.low,
"high": env.action_spec.space.high,
"low": env.action_spec_unbatched.space.low,
"high": env.action_spec_unbatched.space.high,
},
return_log_prob=True,
)
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low.to(device),
"high": proof_environment.action_spec.space.high.to(device),
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
"safe_tanh": True,
}
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg):
def make_cql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

action_spec = train_env.action_spec
action_spec = train_env.action_spec_unbatched

actor_net, q_net = make_cql_modules_state(model_cfg, eval_env)
in_keys = ["observation"]
Expand Down
4 changes: 1 addition & 3 deletions sota-implementations/crossq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device):
"""Make CrossQ agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.actor_hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def make_dt_model(cfg):
make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1
)

action_spec = proof_environment.action_spec
action_spec = proof_environment.action_spec_unbatched
for key, value in proof_environment.observation_spec.items():
if key == "observation":
state_dim = value.shape[-1]
Expand Down
10 changes: 9 additions & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)

# mixed precision training
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
Expand Down Expand Up @@ -321,6 +321,14 @@ def compile_rssms(module):

t_collect_init = time.time()

test_env.close()
train_env.close()
collector.shutdown()

del test_env
del train_env
del collector


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
4 changes: 1 addition & 3 deletions sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
model_cfg = cfg.model

in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)

out_keys = ["loc", "scale"]
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821
# Policy
net = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
Expand All @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
),
)

Expand Down
8 changes: 4 additions & 4 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "param")],
out_keys=[env.action_key],
distribution_class=TanhDelta,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=False,
)

policy_explore = TensorDictSequential(
policy,
AdditiveGaussianModule(
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
device=cfg.train.device,
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821
# Policy
net = MultiAgentMLP(
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
n_agent_outputs=env.action_spec.space.n,
n_agent_outputs=env.full_action_spec["agents", "action"].space.n,
n_agents=env.n_agents,
centralised=False,
share_params=cfg.model.shared_parameters,
Expand All @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
Expand All @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
),
)

Expand Down
12 changes: 6 additions & 6 deletions sota-implementations/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821

policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.unbatched_action_spec[("agents", "action")].space.low,
"high": env.unbatched_action_spec[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
Expand Down Expand Up @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
spec=env.unbatched_action_spec,
spec=env.full_action_spec_unbatched,
in_keys=[("agents", "logits")],
out_keys=[env.action_key],
distribution_class=OneHotCategorical
Expand Down Expand Up @@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
action_spec=env.unbatched_action_spec,
action_spec=env.full_action_spec_unbatched,
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
Expand All @@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821
qvalue_network=value_module,
delay_qvalue=True,
num_actions=env.action_spec.space.n,
action_space=env.unbatched_action_spec,
action_space=env.full_action_spec_unbatched,
)
loss_module.set_keys(
action_value=("agents", "action_value"),
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
}

# Define input keys
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec.space.low,
"high": proof_environment.action_spec.space.high,
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
}

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def make_redq_model(
default_policy_scale = cfg.network.default_policy_scale
gSDE = cfg.exploration.gSDE

action_spec = proof_environment.action_spec
action_spec = proof_environment.action_spec_unbatched

if actor_net_kwargs is None:
actor_net_kwargs = {}
Expand Down
4 changes: 1 addition & 3 deletions sota-implementations/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
"""Make SAC agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 2 * action_spec.shape[-1],
Expand Down
Loading

0 comments on commit d30599e

Please sign in to comment.