Skip to content

Commit

Permalink
[Feature] Heterogeneous Environments compatibility (pytorch#1411)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
  • Loading branch information
matteobettini and vmoens authored Aug 4, 2023
1 parent 83dfff3 commit b210665
Show file tree
Hide file tree
Showing 6 changed files with 459 additions and 37 deletions.
195 changes: 194 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from tensordict.utils import expand_right, NestedKey

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
Expand All @@ -19,6 +19,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.common import ModelBasedEnvBase

Expand Down Expand Up @@ -1290,3 +1291,195 @@ def _step(
device=self.device,
)
return tensordict.select().set("next", tensordict)


class HeteroCountingEnvPolicy:
    """Deterministic policy that writes a constant action into a tensordict.

    The action is the zero sample of ``full_action_spec``; when ``count`` is
    true every entry of that sample is incremented by one before being written.

    Args:
        full_action_spec: spec from which the zero action sample is drawn.
        count: if ``True`` (default), add one to every action entry.
    """

    def __init__(self, full_action_spec: TensorSpec, count: bool = True):
        self.full_action_spec = full_action_spec
        self.count = count

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        """Update ``td`` with the constant action and return the result."""
        sampled = self.full_action_spec.zero()
        if not self.count:
            # Non-counting mode: the untouched zero action is written as-is.
            return td.update(sampled)
        sampled.apply_(lambda value: value + 1)
        return td.update(sampled)


class HeteroCountingEnv(EnvBase):
    """A heterogeneous, counting Env.

    The environment holds ``n_nested_dim`` (fixed to 3) agents whose
    observation and action specs differ from one another. All agent entries
    live under the ``"lazy"`` key of the observation, action, reward and done
    specs. A per-batch integer counter ``count`` is incremented at each step
    in which at least one agent's first action component is non-zero, and the
    observations returned are zero samples of the observation spec shifted by
    that counter.

    Args:
        max_steps: the done flag is set once ``count`` exceeds this value.
        start_val: value ``count`` is (re)set to at construction and on reset.
        **kwargs: forwarded to ``EnvBase.__init__``.
    """

    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.n_nested_dim = 3
        self.max_steps = max_steps
        self.start_val = start_val

        # One counter per batch element, with a trailing singleton dimension.
        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
        count[:] = self.start_val

        self.register_buffer("count", count)

        # Build one spec per agent and stack them along a new leading dim.
        agent_indices = range(self.n_nested_dim)
        obs_specs = torch.stack(
            [self.get_agent_obs_spec(index) for index in agent_indices], dim=0
        )
        # NOTE(review): consolidate_spec presumably makes the stacked
        # heterogeneous spec uniform across agents -- confirm against
        # torchrl.data.utils.consolidate_spec.
        obs_spec_unlazy = consolidate_spec(obs_specs)
        action_specs = torch.stack(
            [self.get_agent_action_spec(index) for index in agent_indices], dim=0
        )

        self.unbatched_observation_spec = CompositeSpec(
            lazy=obs_spec_unlazy,
            state=UnboundedContinuousTensorSpec(
                shape=(
                    64,
                    64,
                    3,
                )
            ),
        )

        self.unbatched_action_spec = CompositeSpec(
            lazy=action_specs,
        )
        self.unbatched_reward_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "reward": UnboundedContinuousTensorSpec(
                            shape=(self.n_nested_dim, 1)
                        )
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )
        self.unbatched_done_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "done": DiscreteTensorSpec(
                            n=2,
                            shape=(self.n_nested_dim, 1),
                            dtype=torch.bool,
                        ),
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )

        # Batched specs: the unbatched specs with the env batch size prepended.
        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def get_agent_obs_spec(self, i):
        """Return the observation spec of agent ``i``.

        Every agent has a ``"camera"`` and a ``"vector"`` entry; the vector is
        3-dimensional for agent 0 and 2-dimensional for agents 1 and 2. Agents
        0 and 1 also have a ``"lidar"`` entry, and each agent has its own
        ``"tensor_<i>"`` entry with a different shape.

        Args:
            i: agent index, one of 0, 1 or 2.

        Returns:
            A ``CompositeSpec`` describing the observations of agent ``i``.

        Raises:
            ValueError: if ``i`` is not 0, 1 or 2.
        """
        camera = BoundedTensorSpec(minimum=0, maximum=200, shape=(7, 7, 3))
        vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
        vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
        lidar = BoundedTensorSpec(minimum=0, maximum=5, shape=(8,))

        tensor_0 = UnboundedContinuousTensorSpec(shape=(1,))
        tensor_1 = BoundedTensorSpec(minimum=0, maximum=3, shape=(1, 2))
        tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3))

        if i == 0:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_3d,
                    "tensor_0": tensor_0,
                }
            )
        elif i == 1:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_2d,
                    "tensor_1": tensor_1,
                }
            )
        elif i == 2:
            return CompositeSpec(
                {
                    "camera": camera,
                    "vector": vector_2d,
                    "tensor_2": tensor_2,
                }
            )
        else:
            raise ValueError(f"Index {i} undefined for index 3")

    def get_agent_action_spec(self, i):
        """Return the action spec of agent ``i``.

        Agent 0 has a 3-dimensional action, agents 1 and 2 a 2-dimensional
        one; all are bounded in ``[-1, 1]``.

        Args:
            i: agent index, one of 0, 1 or 2.

        Returns:
            A ``CompositeSpec`` with a single ``"action"`` entry.

        Raises:
            ValueError: if ``i`` is not 0, 1 or 2.
        """
        action_3d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(3,))
        action_2d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(2,))

        # Some have 2d action and some 3d
        # TODO Introduce composite heterogeneous actions
        if i == 0:
            ret = action_3d
        elif i == 1:
            ret = action_2d
        elif i == 2:
            ret = action_2d
        else:
            raise ValueError(f"Index {i} undefined for index 3")

        return CompositeSpec({"action": ret})

    def _reset(
        self,
        tensordict: TensorDictBase = None,
        **kwargs,
    ) -> TensorDictBase:
        """Reset the counter and return the initial observation tensordict.

        If ``tensordict`` carries a ``"_reset"`` entry, only the batch elements
        flagged in it are reset; otherwise the whole counter is reset.

        Args:
            tensordict: optional input that may contain a ``"_reset"`` mask.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            A tensordict holding the observations (zeros shifted by ``count``)
            and a zeroed done entry.
        """
        if tensordict is not None and "_reset" in tensordict.keys():
            # Drop the trailing dim, then reduce over the last remaining dim
            # (presumably the agent dim -- TODO confirm) so that one flagged
            # entry resets the whole batch element.
            _reset = tensordict.get("_reset").squeeze(-1).any(-1)
            self.count[_reset] = self.start_val
        else:
            self.count[:] = self.start_val

        reset_td = self.observation_spec.zero()
        # Every observation entry takes the value of the counter.
        reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
        reset_td.update(self.output_spec["_done_spec"].zero())

        assert reset_td.batch_size == self.batch_size

        return reset_td

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Advance the counter and return the next-step tensordict.

        Args:
            tensordict: input holding each agent's action under
                ``("lazy", "action")``.

        Returns:
            An empty tensordict whose ``"next"`` entry holds the observations
            (zeros shifted by ``count``), a zeroed reward and the done flag
            (``count > max_steps``).
        """
        actions = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool)
        for i in range(self.n_nested_dim):
            action = tensordict["lazy"][..., i]["action"]
            # Only the first action component matters: non-zero means "count".
            action = action[..., 0].to(torch.bool)
            # Accumulating into a bool tensor records whether any agent so far
            # asked to count.
            actions += action

        # The counter grows by at most one per step, however many agents act.
        self.count += actions.unsqueeze(-1).to(torch.int)

        td = self.observation_spec.zero()
        td.apply_(lambda x: x + expand_right(self.count, x.shape))
        td.update(self.output_spec["_done_spec"].zero())
        td.update(self.output_spec["_reward_spec"].zero())

        assert td.batch_size == self.batch_size
        td[self.done_key] = expand_right(
            self.count > self.max_steps, self.done_spec.shape
        )

        return td.select().set("next", td)

    def _set_seed(self, seed: Optional[int]):
        """Seed torch's global RNG; do nothing when ``seed`` is ``None``.

        Args:
            seed: the seed to use, or ``None`` to leave the RNG untouched.
        """
        # torch.manual_seed converts its argument with int() and therefore
        # raises TypeError on None, which the Optional annotation allows.
        if seed is not None:
            torch.manual_seed(seed)
Loading

0 comments on commit b210665

Please sign in to comment.