Merge remote-tracking branch 'origin/main' into fix-env-nontensor

vmoens committed Jun 19, 2024
2 parents b443095 + 47a2627 commit e1b38b3

Showing 19 changed files with 456 additions and 92 deletions.
6 changes: 6 additions & 0 deletions docs/source/reference/data.rst
@@ -56,6 +56,12 @@ Here are a few examples, starting with the generic :class:`~torchrl.data.replay_
>>> rb.add("a string!") # first element will be a string
>>> rb.extend([30, None]) # element [1] is an int, [2] is None

The main entry points to write onto a buffer are :meth:`~torchrl.data.ReplayBuffer.add` and
:meth:`~torchrl.data.ReplayBuffer.extend`.
One can also use :meth:`~torchrl.data.ReplayBuffer.__setitem__`, in which case the data is written
where indicated without updating the length or cursor of the buffer. This can be useful when sampling
items from the buffer and then updating their values in-place.
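
For instance (a minimal sketch, assuming the list-backed buffer above and that the
sampled indices are exposed under the ``"index"`` entry of the info dict returned by
:meth:`~torchrl.data.ReplayBuffer.sample` when called with ``return_info=True``):

>>> data, info = rb.sample(1, return_info=True)
>>> rb[info["index"]] = "a replacement string!" # written in place
>>> len(rb) # length and write cursor are unchanged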

Using a :class:`~torchrl.data.replay_buffers.TensorStorage` we tell our RB that
we want the storage to be contiguous, which is far more efficient but also
more restrictive:
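
A minimal sketch of such a contiguous buffer (assuming a
:class:`~torchrl.data.LazyTensorStorage` with a fixed capacity; the stored items
must now share a consistent tensor structure):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=100))
>>> rb.extend(TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10]))
>>> rb.sample(4)["obs"].shape # a contiguous batch of observations
torch.Size([4, 3])
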
1 change: 1 addition & 0 deletions sota-implementations/cql/cql_offline.py
@@ -162,6 +162,7 @@ def main(cfg: "DictConfig"):  # noqa: F821

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
eval_env.close()


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions sota-implementations/cql/cql_online.py
@@ -227,6 +227,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")

collector.shutdown()
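# close both environments to release their resources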
eval_env.close()
train_env.close()


if __name__ == "__main__":
117 changes: 116 additions & 1 deletion test/test_collector.py
@@ -15,6 +15,7 @@
import torch

from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
@@ -42,7 +43,13 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict
from tensordict import (
assert_allclose_td,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential

from torch import nn
@@ -57,7 +64,9 @@
from torchrl.data import (
CompositeSpec,
LazyTensorStorage,
NonTensorSpec,
ReplayBuffer,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
@@ -67,6 +76,7 @@
ParallelEnv,
SerialEnv,
StepCounter,
Transform,
)
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend
from torchrl.envs.transforms import TransformedEnv, VecNorm
@@ -2641,6 +2651,111 @@ def test_dynamic_multiasync_collector(self):
assert data.names[-1] == "time"


@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
class TestCollectorsNonTensor:
class AddNonTensorData(Transform):
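# Writes a string ("nt") as NonTensorData at every step and on reset,
# so that collectors must propagate non-tensor entries correctly.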
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict["nt"] = f"a string! - {tensordict.get('step_count').item()}"
return tensordict

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
return tensordict_reset.set("nt", NonTensorData("reset!"))

def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
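# declare the non-tensor "nt" entry so the env's observation spec matches the data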
observation_spec["nt"] = NonTensorSpec(shape=())
return observation_spec

@classmethod
def make_env(cls):
return (
GymEnv(CARTPOLE_VERSIONED())
.append_transform(StepCounter())
.append_transform(cls.AddNonTensorData())
)

def test_simple(self):
torch.manual_seed(0)
env = self.make_env()
env.set_seed(0)
collector = SyncDataCollector(env, frames_per_batch=10, total_frames=200)
result = []
for data in collector:
result.append(data)
result = torch.cat(result)
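# consecutive "nt" entries must either advance the step count by one
# or mark an episode boundary with "reset!"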
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i

@pytest.mark.parametrize("use_buffers", [True, False])
def test_sync(self, use_buffers):
torch.manual_seed(0)
collector = MultiSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
cat_results="stack",
use_buffers=use_buffers,
)
try:
result = []
for data in collector:
result.append(data)
results = torch.cat(result)
for result in results.unbind(0):
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector

@pytest.mark.parametrize("use_buffers", [True, False])
def test_async(self, use_buffers):
torch.manual_seed(0)
collector = MultiaSyncDataCollector(
[self.make_env, self.make_env],
frames_per_batch=10,
total_frames=200,
use_buffers=use_buffers,
)
try:
results = []
for data in collector:
results.append(data)
for result in results:
for i, val in enumerate(result["nt"][1:]):
if val == "a string! - 1":
assert result["nt"][i] == "reset!"
elif val.startswith("a string!"):
assert result["next", "nt"][i] == val
int1 = int(val.split(" - ")[-1])
int0 = int(result["nt"][i].split(" - ")[-1])
assert int0 + 1 == int1
elif val == "reset!":
assert result["next", "nt"][i + 1] == "a string! - 1", i
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
35 changes: 20 additions & 15 deletions test/test_libs.py
@@ -1966,12 +1966,13 @@ def test_env_device(self, env_name, frame_skip, transformed_out, device):


@pytest.mark.skipif(not _has_brax, reason="brax not installed")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("envname", ["fast"])
class TestBrax:
@pytest.mark.parametrize("requires_grad", [False, True])
def test_brax_constructor(self, envname, requires_grad):
env0 = BraxEnv(envname, requires_grad=requires_grad)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad)
def test_brax_constructor(self, envname, requires_grad, device):
env0 = BraxEnv(envname, requires_grad=requires_grad, device=device)
env1 = BraxWrapper(env0._env, requires_grad=requires_grad, device=device)

env0.set_seed(0)
torch.manual_seed(0)
@@ -1994,12 +1995,12 @@ def test_brax_constructor(self, envname, requires_grad):
assert r1.requires_grad == requires_grad
assert_allclose_td(r0.data, r1.data)

def test_brax_seeding(self, envname):
def test_brax_seeding(self, envname, device):
final_seed = []
tdreset = []
tdrollout = []
for _ in range(2):
env = BraxEnv(envname)
env = BraxEnv(envname, device=device)
torch.manual_seed(0)
np.random.seed(0)
final_seed.append(env.set_seed(0))
@@ -2012,8 +2013,8 @@ def test_brax_seeding(self, envname):
assert_allclose_td(*tdrollout)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_batch_size(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_batch_size(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
tdreset = env.reset()
tdrollout = env.rollout(max_steps=50)
@@ -2023,8 +2024,8 @@ def test_brax_batch_size(self, envname, batch_size):
assert tdrollout.batch_size[:-1] == batch_size

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_spec_rollout(self, envname, batch_size):
env = BraxEnv(envname, batch_size=batch_size)
def test_brax_spec_rollout(self, envname, batch_size, device):
env = BraxEnv(envname, batch_size=batch_size, device=device)
env.set_seed(0)
check_env_specs(env)

@@ -2036,7 +2037,7 @@ def test_brax_spec_rollout(self, envname, batch_size):
False,
],
)
def test_brax_consistency(self, envname, batch_size, requires_grad):
def test_brax_consistency(self, envname, batch_size, requires_grad, device):
import jax
import jax.numpy as jnp
from torchrl.envs.libs.jax_utils import (
@@ -2045,7 +2046,9 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
_tree_flatten,
)

env = BraxEnv(envname, batch_size=batch_size, requires_grad=requires_grad)
env = BraxEnv(
envname, batch_size=batch_size, requires_grad=requires_grad, device=device
)
env.set_seed(1)
rollout = env.rollout(10)

@@ -2064,9 +2067,9 @@ def test_brax_consistency(self, envname, batch_size, requires_grad):
torch.testing.assert_close(t1, t2)

@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
def test_brax_grad(self, envname, batch_size):
def test_brax_grad(self, envname, batch_size, device):
batch_size = (1,)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True)
env = BraxEnv(envname, batch_size=batch_size, requires_grad=True, device=device)
env.set_seed(0)
td1 = env.reset()
action = torch.randn(env.action_spec.shape)
@@ -2080,10 +2083,12 @@ def test_brax_grad(self, envname, batch_size):
@pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)])
@pytest.mark.parametrize("parallel", [False, True])
def test_brax_parallel(
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, n=1
self, envname, batch_size, parallel, maybe_fork_ParallelEnv, device, n=1
):
def make_brax():
env = BraxEnv(envname, batch_size=batch_size, requires_grad=False)
env = BraxEnv(
envname, batch_size=batch_size, requires_grad=False, device=device
)
env.set_seed(1)
return env
