[BugFix] Parametric collectors (#1303)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
Co-authored-by: Matteo Bettini <matbet@meta.com>
3 people authored Jun 29, 2023
1 parent f6efd77 commit 36d2478
Showing 5 changed files with 325 additions and 42 deletions.
169 changes: 146 additions & 23 deletions test/mocking_classes.py
@@ -7,13 +7,16 @@
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey

from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import EnvBase
@@ -941,6 +944,15 @@ def forward(self, observation, action):
return self.linear(torch.cat([observation, action], dim=-1))


class CountingEnvCountPolicy:
def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
self.action_spec = action_spec
self.action_key = action_key

def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.set(self.action_key, self.action_spec.zero() + 1)
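
A minimal usage sketch (not part of the diff; it assumes test/mocking_classes.py is on the import path, e.g. when running from the test directory): the policy writes action_spec.zero() + 1 under the configured action key, so a CountingEnv driven by it advances its count by one at every step.

# Hypothetical usage sketch, not part of this commit.
from mocking_classes import CountingEnv, CountingEnvCountPolicy

env = CountingEnv(max_steps=5)
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
rollout = env.rollout(5, policy)
# the policy always emits an action of ones, so the count grows by 1 per step
assert (rollout["action"] == 1).all()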


class CountingEnv(EnvBase):
"""An env that is done after a given number of steps.
@@ -1011,7 +1023,7 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get("action")
action = tensordict.get(self.action_key)
self.count += action.to(torch.int).to(self.device)
tensordict = TensorDict(
source={
@@ -1025,38 +1037,149 @@ def _step(
return tensordict.select().set("next", tensordict)


class NestedRewardEnv(CountingEnv):
class NestedCountingEnv(CountingEnv):
# an env with nested reward and done states
def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
def __init__(
self,
max_steps: int = 5,
start_val: int = 0,
nest_obs_action: bool = True,
nest_done: bool = True,
nest_reward: bool = True,
nested_dim: int = 3,
**kwargs,
):
super().__init__(max_steps=max_steps, start_val=start_val, **kwargs)
self.observation_spec = CompositeSpec(
{("data", "states"): self.observation_spec["observation"].clone()},
shape=self.batch_size,
)
self.reward_spec = CompositeSpec(
{("data", "reward"): self.reward_spec.clone()}, shape=self.batch_size
)
self.done_spec = CompositeSpec(
{("data", "done"): self.done_spec.clone()}, shape=self.batch_size
)

self.nested_dim = nested_dim

self.nested_obs_action = nest_obs_action
self.nested_done = nest_done
self.nested_reward = nest_reward

if self.nested_obs_action:
self.observation_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"states": self.observation_spec["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
},
shape=(
*self.batch_size,
self.nested_dim,
),
)
},
shape=self.batch_size,
)
self.action_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"action": self.action_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(
*self.batch_size,
self.nested_dim,
),
)
},
shape=self.batch_size,
)

if self.nested_reward:
self.reward_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"reward": self.reward_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(
*self.batch_size,
self.nested_dim,
),
)
},
shape=self.batch_size,
)

if self.nested_done:
self.done_spec = CompositeSpec(
{
"data": CompositeSpec(
{
"done": self.done_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
)
},
shape=(
*self.batch_size,
self.nested_dim,
),
)
},
shape=self.batch_size,
)

def _reset(self, td):
if self.nested_done and td is not None and "_reset" in td.keys():
td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
td = super()._reset(td)
td[self.done_key] = td["done"]
del td["done"]
td["data", "states"] = td["observation"]
del td["observation"]
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["done"]
if self.nested_obs_action:
td["data", "states"] = (
td["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
)
del td["observation"]
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
return td

def _step(self, td):
if self.nested_obs_action:
td["data"].batch_size = self.batch_size
td[self.action_key] = td[self.action_key].max(-2)[0]
td_root = super()._step(td)
if self.nested_obs_action:
td[self.action_key] = (
td[self.action_key]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
)
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
td = td_root["next"]
td[self.reward_key] = td["reward"]
del td["reward"]
td[self.done_key] = td["done"]
del td["done"]
td["data", "states"] = td["observation"]
del td["observation"]
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["done"]
if self.nested_obs_action:
td["data", "states"] = (
td["observation"]
.unsqueeze(-1)
.expand(*self.batch_size, self.nested_dim, 1)
)
del td["observation"]
if self.nested_reward:
td[self.reward_key] = (
td["reward"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
)
del td["reward"]
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
return td_root
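
For orientation, a hedged sketch of how the nested variant behaves (names are taken from the diff above; the exact shapes are an assumption, not asserted by this commit): with nesting enabled, observation, action, reward and done all live under a "data" sub-tensordict whose batch size gains an extra nested_dim dimension.

# Hypothetical sketch, not part of this commit.
from mocking_classes import CountingEnvCountPolicy, NestedCountingEnv

env = NestedCountingEnv(nested_dim=3)
td = env.reset()
# observation and done (and, after a step, reward) sit under the nested "data" entry
assert ("data", "states") in td.keys(True)
assert td["data"].batch_size == (*env.batch_size, 3)

policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
rollout = env.rollout(4, policy)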


126 changes: 124 additions & 2 deletions test/test_collector.py
@@ -14,16 +14,18 @@
ContinuousActionVecMockEnv,
CountingBatchedEnv,
CountingEnv,
CountingEnvCountPolicy,
DiscreteActionConvMockEnv,
DiscreteActionConvPolicy,
DiscreteActionVecMockEnv,
DiscreteActionVecPolicy,
MockSerialEnv,
NestedCountingEnv,
)
from tensordict.nn import TensorDictModule
from tensordict.tensordict import assert_allclose_td, TensorDict
from torch import nn
from torchrl._utils import seed_generator
from torchrl._utils import prod, seed_generator
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import (
_Interruptor,
@@ -33,7 +35,14 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase, EnvCreator, ParallelEnv, SerialEnv, StepCounter
from torchrl.envs import (
EnvBase,
EnvCreator,
InitTracker,
ParallelEnv,
SerialEnv,
StepCounter,
)
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm
from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule
@@ -1346,6 +1355,119 @@ def test_reset_heterogeneous_envs():
).all()


class TestNestedEnvsCollector:
def test_multi_collector_nested_env_consistency(self, seed=1):
env = NestedCountingEnv()
torch.manual_seed(seed)
env_fn = lambda: TransformedEnv(env, InitTracker())
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)

ccollector = MultiaSyncDataCollector(
create_env_fn=[env_fn],
policy=policy,
frames_per_batch=20,
total_frames=100,
device="cpu",
)
for i, d in enumerate(ccollector):
if i == 0:
c1 = d
elif i == 1:
c2 = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(c1, c2)
ccollector.shutdown()

ccollector = MultiSyncDataCollector(
create_env_fn=[env_fn],
policy=policy,
frames_per_batch=20,
total_frames=100,
device="cpu",
)
for i, d in enumerate(ccollector):
if i == 0:
d1 = d
elif i == 1:
d2 = d
else:
break
assert d.names[-1] == "time"
with pytest.raises(AssertionError):
assert_allclose_td(d1, d2)
ccollector.shutdown()

assert_allclose_td(c1, d1)
assert_allclose_td(c2, d2)

@pytest.mark.parametrize("nested_obs_action", [True, False])
@pytest.mark.parametrize("nested_done", [True, False])
@pytest.mark.parametrize("nested_reward", [True, False])
def test_collector_nested_env_combinations(
self,
nested_obs_action,
nested_done,
nested_reward,
seed=1,
frames_per_batch=20,
):
env = NestedCountingEnv(
nest_reward=nested_reward,
nest_done=nested_done,
nest_obs_action=nested_obs_action,
)
torch.manual_seed(seed)
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
ccollector = SyncDataCollector(
create_env_fn=env,
policy=policy,
frames_per_batch=frames_per_batch,
total_frames=100,
device="cpu",
)

for _td in ccollector:
break
ccollector.shutdown()

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 2)])
def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20):
from mocking_classes import CountingEnvCountPolicy, NestedCountingEnv

env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
env_fn = lambda: NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
torch.manual_seed(0)
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
policy(env.reset())
ccollector = SyncDataCollector(
create_env_fn=env_fn,
policy=policy,
frames_per_batch=frames_per_batch,
total_frames=100,
device="cpu",
)

for _td in ccollector:
break
ccollector.shutdown()

# assert ("data","reward") not in td.keys(True) # this can be activates once step_mdp is fixed for nested keys
assert _td.batch_size == (*batch_size, frames_per_batch // prod(batch_size))
assert _td["data"].batch_size == (
*batch_size,
frames_per_batch // prod(batch_size),
nested_dim,
)
assert _td["next", "data"].batch_size == (
*batch_size,
frames_per_batch // prod(batch_size),
nested_dim,
)
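
A hedged reading of the shape assertions above: the collector splits frames_per_batch across the prod(batch_size) sub-environments of the batched env, so each yielded tensordict carries frames_per_batch // prod(batch_size) time steps on top of the env batch size, and the nested "data" entry additionally carries nested_dim.

# Illustration with assumed values (batch_size=(5, 2), nested_dim=5), not part of the test.
from torchrl._utils import prod

batch_size, nested_dim, frames_per_batch = (5, 2), 5, 20
time_steps = frames_per_batch // prod(batch_size)    # 20 // 10 == 2
root_shape = (*batch_size, time_steps)                # (5, 2, 2)
nested_shape = (*batch_size, time_steps, nested_dim)  # (5, 2, 2, 5)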


@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
class TestUpdateParams:
class DummyEnv(EnvBase):