[Refactor] Faster envs (2) (pytorch#1457)
vmoens authored Sep 1, 2023
1 parent 7ae6140 commit e3b3879
Showing 22 changed files with 495 additions and 466 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_libs/scripts_gym/run_test.sh
@@ -22,6 +22,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_

export DISPLAY=':99.0'
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym" --error-for-skips
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips
coverage combine
coverage xml -i
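
Note on the filter change above: the -k expression now also excludes IsaacGym tests from the gym selection. As a rough Python equivalent of that keyword filter (a hypothetical invocation, assuming it runs from the repository root; the actual CI uses the shell command shown above):

    import pytest

    # select tests whose id contains "gym" but skip anything mentioning "isaac"
    exit_code = pytest.main(["test/test_libs.py", "-v", "-k", "gym and not isaac"])
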
3 changes: 2 additions & 1 deletion .circleci/unittest/linux_libs/scripts_gym/setup_env.sh
@@ -79,7 +79,8 @@ conda env config vars set \
NVIDIA_PATH=/usr/src/nvidia-470.63.01 \
MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \
MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin
# LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin

# LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/src/nvidia-470.63.01 \

38 changes: 29 additions & 9 deletions examples/decision_transformer/utils.py
@@ -74,22 +74,25 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
transformed_env = TransformedEnv(base_env)
transformed_env.append_transform(
RewardScaling(
loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False
loc=0,
scale=env_cfg.reward_scaling,
in_keys=["reward"],
standard_normal=False,
)
)
if train:
transformed_env.append_transform(
TargetReturn(
env_cfg.collect_target_return * env_cfg.reward_scaling,
out_keys=["return_to_go"],
out_keys=["return_to_go_single"],
mode=env_cfg.target_return_mode,
)
)
else:
transformed_env.append_transform(
TargetReturn(
env_cfg.eval_target_return * env_cfg.reward_scaling,
out_keys=["return_to_go"],
out_keys=["return_to_go_single"],
mode=env_cfg.target_return_mode,
)
)
@@ -107,7 +110,11 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
)
transformed_env.append_transform(obsnorm)
transformed_env.append_transform(
UnsqueezeTransform(-2, in_keys=["observation", "action", "return_to_go"])
UnsqueezeTransform(
-2,
in_keys=["observation", "action", "return_to_go_single"],
out_keys=["observation", "action", "return_to_go"],
)
)
transformed_env.append_transform(
CatFrames(
@@ -158,6 +165,8 @@ def make_collector(cfg, policy):
exclude_target_return = ExcludeTransform(
"return_to_go",
("next", "return_to_go"),
"return_to_go_single",
("next", "return_to_go_single"),
("next", "action"),
("next", "observation"),
"scale",
@@ -183,9 +192,15 @@ def make_collector(cfg, policy):


def make_offline_replay_buffer(rb_cfg, reward_scaling):
r2g = Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go"])
r2g = Reward2GoTransform(
gamma=1.0, in_keys=["reward"], out_keys=["return_to_go_single"]
)
reward_scale = RewardScaling(
loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False
loc=0,
scale=reward_scaling,
in_keys="return_to_go_single",
out_keys=["return_to_go"],
standard_normal=False,
)
crop_seq = RandomCropTensorDict(sub_seq_len=rb_cfg.stacked_frames, sample_dim=-1)

@@ -230,12 +245,17 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):


def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001):
r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"])
r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go_single"])
reward_scale = RewardScaling(
loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False
loc=0,
scale=reward_scaling,
in_keys=["return_to_go_single"],
out_keys=["return_to_go"],
standard_normal=False,
)
catframes = CatFrames(
in_keys=["return_to_go"],
in_keys=["return_to_go_single"],
out_keys=["return_to_go"],
N=rb_cfg.stacked_frames,
dim=-2,
padding="zeros",
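
Note on the key renaming in this file: the raw per-step return-to-go is now written to an intermediate "return_to_go_single" entry, and only the rescaled (and, on the env side, unsqueezed/stacked) value ends up in "return_to_go". A minimal sketch of the replay-buffer side of that flow, copying the constructor calls from the diff above; the import path and the reward_scaling value are illustrative assumptions, not part of this commit:

    from torchrl.envs.transforms import Compose, Reward2GoTransform, RewardScaling

    reward_scaling = 0.001  # hypothetical stand-in for the config value

    rb_transforms = Compose(
        # undiscounted reward-to-go, written to the intermediate key
        Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go_single"]),
        # rescale it into the key the model actually reads
        RewardScaling(
            loc=0,
            scale=reward_scaling,
            in_keys=["return_to_go_single"],
            out_keys=["return_to_go"],
            standard_normal=False,
        ),
    )
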
32 changes: 11 additions & 21 deletions test/mocking_classes.py
@@ -203,11 +203,7 @@ def _step(self, tensordict):
done = self.counter >= self.max_val
done = torch.tensor([done], dtype=torch.bool, device=self.device)
return TensorDict(
{
"next": TensorDict(
{"reward": n, "done": done, "observation": n.clone()}, batch_size=[]
)
},
{"reward": n, "done": done, "observation": n.clone()},
batch_size=[],
)

@@ -338,13 +334,7 @@ def _step(self, tensordict):
device=self.device,
)
return TensorDict(
{
"next": TensorDict(
{"reward": n, "done": done, "observation": n},
tensordict.batch_size,
device=self.device,
)
},
{"reward": n, "done": done, "observation": n},
batch_size=tensordict.batch_size,
device=self.device,
)
@@ -501,7 +491,7 @@ def _step(
done = torch.zeros_like(done).all(-1).unsqueeze(-1)
tensordict.set("reward", reward.to(torch.get_default_dtype()))
tensordict.set("done", done)
return tensordict.select().set("next", tensordict)
return tensordict


class ContinuousActionVecMockEnv(_MockEnv):
@@ -603,7 +593,7 @@ def _step(
done = reward = done.unsqueeze(-1)
tensordict.set("reward", reward.to(torch.get_default_dtype()))
tensordict.set("done", done)
return tensordict.select().set("next", tensordict)
return tensordict

def _obs_step(self, obs, a):
return obs + a / self.maxstep
@@ -1044,7 +1034,7 @@ def _step(
batch_size=self.batch_size,
device=self.device,
)
return tensordict.select().set("next", tensordict)
return tensordict


class NestedCountingEnv(CountingEnv):
@@ -1167,7 +1157,7 @@ def _step(self, td):
td = td.clone()
td["data"].batch_size = self.batch_size
td[self.action_key] = td[self.action_key].max(-2)[0]
td_root = super()._step(td)
next_td = super()._step(td)
if self.nested_obs_action:
td[self.action_key] = (
td[self.action_key]
@@ -1176,7 +1166,7 @@ def _step(self, td):
)
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
td = td_root["next"]
td = next_td
if self.nested_done:
td[self.done_key] = (
td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
@@ -1196,7 +1186,7 @@ def _step(self, td):
del td["reward"]
if "data" in td.keys():
td["data"].batch_size = (*self.batch_size, self.nested_dim)
return td_root
return td


class CountingBatchedEnv(EnvBase):
@@ -1290,7 +1280,7 @@ def _step(
batch_size=self.batch_size,
device=self.device,
)
return tensordict.select().set("next", tensordict)
return tensordict


class HeteroCountingEnvPolicy:
@@ -1479,7 +1469,7 @@ def _step(
self.count > self.max_steps, self.done_spec.shape
)

return td.select().set("next", td)
return td

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)
@@ -1713,7 +1703,7 @@ def _step(
td.update(reward)

assert td.batch_size == self.batch_size
return td.select().set("next", td)
return td

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)
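
The recurring change in this file is the updated _step contract: a mock env's _step now returns its output (reward, done, observation) at the root of the returned tensordict instead of nesting it under "next"; as these mock updates suggest, the base step() call takes care of the "next" nesting. A minimal sketch of the new convention, using only the tensordict package (the helper name is made up for illustration):

    import torch
    from tensordict import TensorDict

    def step_output(obs: torch.Tensor, reward: torch.Tensor, done: torch.Tensor) -> TensorDict:
        # what a _step implementation returns after this refactor: a flat
        # tensordict, with no manual "next" wrapping
        return TensorDict(
            {"observation": obs, "reward": reward, "done": done},
            batch_size=[],
        )
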
101 changes: 49 additions & 52 deletions test/test_collector.py
@@ -353,52 +353,52 @@ def make_env():
assert _data["next", "reward"].sum(-2).min() == -21


@pytest.mark.parametrize("num_env", [1, 2])
@pytest.mark.parametrize("env_name", ["vec"])
def test_collector_done_persist(num_env, env_name, seed=5):
if num_env == 1:

def env_fn(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

else:

def env_fn(seed):
def make_env(seed):
env = MockSerialEnv(device="cpu")
env.set_seed(seed)
return env

env = ParallelEnv(
num_workers=num_env,
create_env_fn=make_env,
create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)],
allow_step_when_done=True,
)
env.set_seed(seed)
return env

policy = make_policy(env_name)

collector = SyncDataCollector(
create_env_fn=env_fn,
create_env_kwargs={"seed": seed},
policy=policy,
frames_per_batch=200 * num_env,
max_frames_per_traj=2000,
total_frames=20000,
device="cpu",
reset_when_done=False,
)
for _, d in enumerate(collector): # noqa
break

assert (d["done"].sum(-2) >= 1).all()
assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1

del collector
# Deprecated reset_when_done
# @pytest.mark.parametrize("num_env", [1, 2])
# @pytest.mark.parametrize("env_name", ["vec"])
# def test_collector_done_persist(num_env, env_name, seed=5):
# if num_env == 1:
#
# def env_fn(seed):
# env = MockSerialEnv(device="cpu")
# env.set_seed(seed)
# return env
#
# else:
#
# def env_fn(seed):
# def make_env(seed):
# env = MockSerialEnv(device="cpu")
# env.set_seed(seed)
# return env
#
# env = ParallelEnv(
# num_workers=num_env,
# create_env_fn=make_env,
# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)],
# )
# env.set_seed(seed)
# return env
#
# policy = make_policy(env_name)
#
# collector = SyncDataCollector(
# create_env_fn=env_fn,
# create_env_kwargs={"seed": seed},
# policy=policy,
# frames_per_batch=200 * num_env,
# max_frames_per_traj=2000,
# total_frames=20000,
# device="cpu",
# reset_when_done=False,
# )
# for _, d in enumerate(collector): # noqa
# break
#
# assert (d["done"].sum(-2) >= 1).all()
# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1
#
# del collector


@pytest.mark.parametrize("frames_per_batch", [200, 10])
@@ -424,7 +424,6 @@ def make_env(seed):
num_workers=num_env,
create_env_fn=make_env,
create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)],
allow_step_when_done=True,
)
env.set_seed(seed)
return env
@@ -1656,11 +1655,9 @@ def _step(
self.state += action
return TensorDict(
{
"next": {
"state": self.state.clone(),
"reward": self.reward_spec.zero(),
"done": self.done_spec.zero(),
}
"state": self.state.clone(),
"reward": self.reward_spec.zero(),
"done": self.done_spec.zero(),
},
self.batch_size,
)
16 changes: 12 additions & 4 deletions test/test_env.py
@@ -5,6 +5,7 @@

import argparse
import os.path
import re
from collections import defaultdict
from functools import partial

@@ -247,9 +248,16 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0):
else:
env = SerialEnv(3, envs)
env.set_seed(100)
# out = env._single_rollout(100, break_when_any_done=False)
out = env.rollout(100, break_when_any_done=False)
assert out.names[-1] == "time"
assert out.shape == torch.Size([3, 100])
assert (
out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
).all()
assert (
out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20])
).all()
assert (
out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
).all()
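
The two new assertions rely on the step-count convention: in a rollout, the root "step_count" holds the count before each step while ("next", "step_count") holds the count after it, hence the off-by-one (19/9/19 vs 20/10/20). A small illustration, assuming a gym backend is installed and using a StepCounter-wrapped Pendulum env rather than the test's counting envs:

    import torch
    from torchrl.envs import GymEnv, StepCounter, TransformedEnv

    env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=10))
    td = env.rollout(5, break_when_any_done=False)
    # root "step_count": value before the step; ("next", "step_count"): value after it
    assert (td["step_count"].squeeze(-1) == torch.arange(5)).all()
    assert (td["next", "step_count"].squeeze(-1) == torch.arange(1, 6)).all()
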
@@ -322,7 +330,9 @@ def test_mb_env_batch_lock(self, device, seed=0):
td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict()
mb_env.step(td)

with pytest.raises(RuntimeError, match="Expected a tensordict with shape"):
with pytest.raises(
RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape")
):
mb_env.step(td_expanded)

mb_env = DummyModelBasedEnvBase(
@@ -1576,9 +1586,7 @@ def test_batch_unlocked_with_batch_size(device):
td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()
td = env.step(td)

with pytest.raises(
RuntimeError, match="Expected a tensordict with shape==env.shape, "
):
with pytest.raises(RuntimeError, match="Expected a tensordict with shape"):
env.step(td_expanded)


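
On the two pytest.raises updates above: the match argument is treated as a regular expression (searched with re.search), so matching a longer literal message is safest with re.escape, while the second change simply matches a shorter, metacharacter-free prefix. A standalone illustration with a made-up error message:

    import re

    import pytest

    def fail():
        raise RuntimeError("Expected a tensordict with shape==env.shape, got [2, 2]")

    # escape the literal text so the "." in "env.shape" is matched verbatim
    with pytest.raises(RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape")):
        fail()

    # or just match a shorter prefix that contains no regex metacharacters
    with pytest.raises(RuntimeError, match="Expected a tensordict with shape"):
        fail()
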