[Refactor] Nested reward and done specs (#1115)
vmoens authored May 29, 2023
1 parent 737f614 commit f3e9a1d
Showing 30 changed files with 1,295 additions and 581 deletions.
20 changes: 12 additions & 8 deletions docs/source/reference/envs.rst
@@ -26,19 +26,23 @@ Each env will have the following attributes:
This is especially useful for transforms (see below). For parametric environments (e.g.
model-based environments), the device does represent the hardware that will be used to
compute the operations.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"action"` and others).
- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the output keys (:obj:`"observation"`, :obj:`"reward"` and :obj:`"done"`).
- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object
containing all the observation key-spec pairs.
This is a pointer to ``env.output_spec["observation"]``.
- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object
containing all the input key-spec pairs (except action). For most stateful
environments, this container will be empty.
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
representing the action spec. This is a pointer to ``env.input_spec["action"]``.
representing the action spec.
- :obj:`env.reward_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the reward spec. This is a pointer to ``env.output_spec["reward"]``.
the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the done-flag spec. This is a pointer to ``env.output_spec["done"]``.
the done-flag spec.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the input keys (:obj:`"_action_spec"` and :obj:`"_state_spec"`).
It is locked and should not be modified directly.
- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
all the output keys (:obj:`"_observation_spec"`, :obj:`"_reward_spec"` and :obj:`"_done_spec"`).
It is locked and should not be modified directly.

Importantly, the environment spec shapes should contain the batch size, e.g.
an environment with :obj:`env.batch_size == torch.Size([4])` should have
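For orientation (this block is not part of the diff): a minimal sketch of the flat spec accessors described in the docs hunk above, assuming a Gym/Gymnasium backend is installed; the environment name is only an example.

# Sketch only, not from this commit: inspecting the spec accessors described
# in the docs hunk above. Assumes gym/gymnasium is installed; the env name is
# an arbitrary example.
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")

print(env.action_spec)       # spec of the "action" entry
print(env.reward_spec)       # spec of the "reward" entry
print(env.done_spec)         # spec of the "done" entry
print(env.observation_spec)  # CompositeSpec with all observation entries
print(env.state_spec)        # non-action inputs; often empty for stateful envs

# input_spec / output_spec are locked composite containers: read, don't mutate.
print(env.input_spec)
print(env.output_spec)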
30 changes: 14 additions & 16 deletions examples/a2c/utils.py
@@ -270,22 +270,21 @@ def make_a2c_models(cfg):
def make_a2c_modules_state(proof_environment):

# Define input shape
env_specs = proof_environment.specs
input_shape = env_specs["output_spec"]["observation"]["observation_vector"].shape
input_shape = proof_environment.observation_spec["observation_vector"].shape

# Define distribution class and kwargs
continuous_actions = False
if isinstance(env_specs["input_spec"]["action"].space, DiscreteBox):
num_outputs = env_specs["input_spec"]["action"].space.n
if isinstance(proof_environment.action_spec.space, DiscreteBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
continuous_actions = True
num_outputs = env_specs["input_spec"]["action"].shape[-1] * 2
num_outputs = proof_environment.action_spec.shape[-1] * 2
distribution_class = TanhNormal
distribution_kwargs = {
"min": env_specs["input_spec"]["action"].space.minimum,
"max": env_specs["input_spec"]["action"].space.maximum,
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"tanh_loc": False,
}

@@ -313,7 +312,7 @@ def make_a2c_modules_state(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["loc", "scale"] if continuous_actions else ["logits"],
spec=CompositeSpec(action=env_specs["input_spec"]["action"]),
spec=CompositeSpec(action=proof_environment.action_spec),
safe=True,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
@@ -340,20 +339,19 @@ def make_a2c_modules_state(proof_environment):
def make_a2c_modules_pixels(proof_environment):

# Define input shape
env_specs = proof_environment.specs
input_shape = env_specs["output_spec"]["observation"]["pixels"].shape
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(env_specs["input_spec"]["action"].space, DiscreteBox):
num_outputs = env_specs["input_spec"]["action"].space.n
if isinstance(proof_environment.action_spec.space, DiscreteBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = env_specs["input_spec"]["action"].shape
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"min": env_specs["input_spec"]["action"].space.minimum,
"max": env_specs["input_spec"]["action"].space.maximum,
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
}

# Define input keys
@@ -399,7 +397,7 @@ def make_a2c_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=CompositeSpec(action=env_specs["input_spec"]["action"]),
spec=CompositeSpec(action=proof_environment.action_spec),
safe=True,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
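The recurring change in this file (and repeated in examples/ppo/utils.py below) is swapping the nested env_specs["input_spec"]["action"] lookups for the flat action_spec property. A hedged sketch of that pattern in isolation; the helper names are illustrative, not code from the commit.

# Illustrative helpers, not from the commit: the accessor migration applied in
# this file, written against a generic TorchRL env object.
from torchrl.data import CompositeSpec

def continuous_dist_kwargs(env):
    # old (pre-#1115): spec = env.specs["input_spec"]["action"]
    spec = env.action_spec          # new flat accessor
    return {
        "min": spec.space.minimum,  # lower action bound
        "max": spec.space.maximum,  # upper action bound
    }

def actor_action_spec(env):
    # the CompositeSpec handed to ProbabilisticActor(spec=...) above
    return CompositeSpec(action=env.action_spec)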
30 changes: 14 additions & 16 deletions examples/ppo/utils.py
@@ -278,22 +278,21 @@ def make_ppo_models(cfg):
def make_ppo_modules_state(proof_environment):

# Define input shape
env_specs = proof_environment.specs
input_shape = env_specs["output_spec"]["observation"]["observation_vector"].shape
input_shape = proof_environment.observation_spec["observation_vector"].shape

# Define distribution class and kwargs
continuous_actions = False
if isinstance(env_specs["input_spec"]["action"].space, DiscreteBox):
num_outputs = env_specs["input_spec"]["action"].space.n
if isinstance(proof_environment.action_spec.space, DiscreteBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
continuous_actions = True
num_outputs = env_specs["input_spec"]["action"].shape[-1] * 2
num_outputs = proof_environment.action_spec.shape[-1] * 2
distribution_class = TanhNormal
distribution_kwargs = {
"min": env_specs["input_spec"]["action"].space.minimum,
"max": env_specs["input_spec"]["action"].space.maximum,
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"tanh_loc": False,
}

@@ -332,7 +331,7 @@ def make_ppo_modules_state(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["loc", "scale"] if continuous_actions else ["logits"],
spec=CompositeSpec(action=env_specs["input_spec"]["action"]),
spec=CompositeSpec(action=proof_environment.action_spec),
safe=True,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
@@ -353,20 +352,19 @@ def make_ppo_modules_pixels(proof_environment):
def make_ppo_modules_pixels(proof_environment):

# Define input shape
env_specs = proof_environment.specs
input_shape = env_specs["output_spec"]["observation"]["pixels"].shape
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(env_specs["input_spec"]["action"].space, DiscreteBox):
num_outputs = env_specs["input_spec"]["action"].space.n
if isinstance(proof_environment.action_spec.space, DiscreteBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = env_specs["input_spec"]["action"].shape
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"min": env_specs["input_spec"]["action"].space.minimum,
"max": env_specs["input_spec"]["action"].space.maximum,
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
}

# Define input keys
@@ -412,7 +410,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=CompositeSpec(action=env_specs["input_spec"]["action"]),
spec=CompositeSpec(action=proof_environment.action_spec),
safe=True,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
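The observation lookups follow the same pattern: both example files now read the network input shape directly from observation_spec. A tiny sketch, with the entry name depending on the environment and transforms in use (hypothetical here).

# Illustrative only: sizing a network input from the observation spec.
# The entry name ("observation_vector", "pixels", ...) depends on the env.
def network_input_shape(env, key="observation_vector"):
    # old (pre-#1115): env.specs["output_spec"]["observation"][key].shape
    return env.observation_spec[key].shape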
1 change: 0 additions & 1 deletion examples/td3/td3.py
@@ -266,7 +266,6 @@ def main(cfg: "DictConfig"): # noqa: F821
q_loss = None

for i, tensordict in enumerate(collector):

# update weights of the inference policy
collector.update_policy_weights_()

