Skip to content

Commit

Permalink
[Feature] CatFrames for offline data (#1122)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 5, 2023
1 parent 0452133 commit 09f71b1
Show file tree
Hide file tree
Showing 12 changed files with 371 additions and 149 deletions.
22 changes: 11 additions & 11 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20

# With batched environments
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
optim.device=cuda:0 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
total_frames=48 \
init_random_frames=10 \
Expand Down Expand Up @@ -86,17 +97,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
env.num_envs=1 \
collector.total_frames=48 \
collector.frames_per_batch=16 \
collector.collector_device=cuda:0 \
optim.device=cuda:0 \
loss.mini_batch_size=10 \
loss.ppo_epochs=1 \
logger.backend= \
logger.log_interval=4 \
optim.lr_scheduler=False
python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \
total_frames=200 \
init_random_frames=10 \
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/nightly_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ jobs:
- name: Install test dependencies
run: |
python3 -mpip install numpy pytest --no-cache-dir
- name: Install tensordict
run: |
python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
- name: Download built wheels
uses: actions/download-artifact@v2
with:
Expand Down Expand Up @@ -324,6 +327,9 @@ jobs:
shell: bash
run: |
python3 -mpip install numpy pytest --no-cache-dir
- name: Install tensordict
run: |
python3 -mpip install git+https://github.com/pytorch-labs/tensordict.git
- name: Download built wheels
uses: actions/download-artifact@v2
with:
Expand Down
9 changes: 0 additions & 9 deletions examples/a2c/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,7 @@ def make_transformed_env_pixels(base_env, env_cfg):
double_to_float_list += [
"reward",
]
double_to_float_list += [
"action",
]
double_to_float_inv_list += ["action"] # DMControl requires double-precision
double_to_float_list += ["observation_vector"]
else:
double_to_float_list += ["observation_vector"]
env.append_transform(
DoubleToFloat(
in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
Expand Down Expand Up @@ -152,9 +146,6 @@ def make_transformed_env_states(base_env, env_cfg):
double_to_float_list += [
"reward",
]
double_to_float_list += [
"action",
]
double_to_float_inv_list += ["action"] # DMControl requires double-precision
double_to_float_list += ["observation_vector"]
else:
Expand Down
1 change: 0 additions & 1 deletion examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def make_env_transforms(
if env_library is DMControlEnv:
double_to_float_list += [
"reward",
"action",
]
float_to_double_list += ["action"] # DMControl requires double-precision
env.append_transform(
Expand Down
9 changes: 6 additions & 3 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,13 @@ def main(cfg: "DictConfig"): # noqa: F821
# Main loop
r0 = None
l0 = None
frame_skip = cfg.env.frame_skip
mini_batch_size = cfg.loss.mini_batch_size
ppo_epochs = cfg.loss.ppo_epochs
for data in collector:

frames_in_batch = data.numel()
collected_frames += frames_in_batch * cfg.env.frame_skip
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())
data_view = data.reshape(-1)

Expand All @@ -93,8 +96,8 @@ def main(cfg: "DictConfig"): # noqa: F821
"reward_training", episode_rewards.mean().item(), collected_frames
)

for _ in range(cfg.loss.ppo_epochs):
for _ in range(frames_in_batch // cfg.loss.mini_batch_size):
for _ in range(ppo_epochs):
for _ in range(frames_in_batch // mini_batch_size):

# Get a data batch
batch = data_buffer.sample().to(model_device)
Expand Down
9 changes: 0 additions & 9 deletions examples/ppo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,7 @@ def make_transformed_env_pixels(base_env, env_cfg):
double_to_float_list += [
"reward",
]
double_to_float_list += [
"action",
]
double_to_float_inv_list += ["action"] # DMControl requires double-precision
double_to_float_list += ["observation_vector"]
else:
double_to_float_list += ["observation_vector"]
env.append_transform(
DoubleToFloat(
in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
Expand Down Expand Up @@ -153,9 +147,6 @@ def make_transformed_env_states(base_env, env_cfg):
double_to_float_list += [
"reward",
]
double_to_float_list += [
"action",
]
double_to_float_inv_list += ["action"] # DMControl requires double-precision
double_to_float_list += ["observation_vector"]
else:
Expand Down
59 changes: 39 additions & 20 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,27 @@ def create_env_fn():
return GymEnv(env_name, frame_skip=frame_skip, device=device)

else:
if env_name == "ALE/Pong-v5":
if env_name == PONG_VERSIONED:

def create_env_fn():
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
in_keys = list(base_env.observation_spec.keys(True, True))[:1]
return TransformedEnv(
GymEnv(env_name, frame_skip=frame_skip, device=device),
Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]),
base_env,
Compose(*[ToTensorImage(in_keys=in_keys), RewardClipping(0, 0.1)]),
)

else:

def create_env_fn():

base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
in_keys = list(base_env.observation_spec.keys(True, True))[:1]

return TransformedEnv(
GymEnv(env_name, frame_skip=frame_skip, device=device),
base_env,
Compose(
ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1),
ObservationNorm(in_keys=in_keys, loc=0.5, scale=1.1),
RewardClipping(0, 0.1),
),
)
Expand All @@ -179,8 +185,14 @@ def create_env_fn():
env_parallel = ParallelEnv(N, create_env_fn, create_env_kwargs=kwargs)
env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)

for key in env0.observation_spec.keys(True, True):
obs_key = key
break
else:
obs_key = None

if transformed_out:
t_out = get_transform_out(env_name, transformed_in)
t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key)

env0 = TransformedEnv(
env0,
Expand Down Expand Up @@ -223,7 +235,7 @@ def _make_multithreaded_env(

torch.manual_seed(0)
multithreaded_kwargs = (
{"frame_skip": frame_skip} if env_name == "ALE/Pong-v5" else {}
{"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {}
)
env_multithread = MultiThreadedEnv(
N,
Expand All @@ -233,46 +245,53 @@ def _make_multithreaded_env(
)

if transformed_out:
for key in env_multithread.observation_spec.keys(True, True):
obs_key = key
break
else:
obs_key = None
env_multithread = TransformedEnv(
env_multithread,
get_transform_out(env_name, transformed_in=False)(),
get_transform_out(env_name, transformed_in=False, obs_key=obs_key)(),
)
return env_multithread


def get_transform_out(env_name, transformed_in):
def get_transform_out(env_name, transformed_in, obs_key=None):

if env_name == "ALE/Pong-v5":
if env_name == PONG_VERSIONED:
if obs_key is None:
obs_key = "pixels"

def t_out():
return (
Compose(*[ToTensorImage(), RewardClipping(0, 0.1)])
Compose(*[ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1)])
if not transformed_in
else Compose(*[ObservationNorm(in_keys=["pixels"], loc=0, scale=1)])
else Compose(*[ObservationNorm(in_keys=[obs_key], loc=0, scale=1)])
)

elif env_name == "CheetahRun-v1":
elif env_name == HALFCHEETAH_VERSIONED:
if obs_key is None:
obs_key = ("observation", "velocity")

def t_out():
return Compose(
ObservationNorm(
in_keys=[("observation", "velocity")], loc=0.5, scale=1.1
),
ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
RewardClipping(0, 0.1),
)

else:
if obs_key is None:
obs_key = "observation"

def t_out():
return (
Compose(
ObservationNorm(in_keys=["observation"], loc=0.5, scale=1.1),
ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
RewardClipping(0, 0.1),
)
if not transformed_in
else Compose(
ObservationNorm(in_keys=["observation"], loc=1.0, scale=1.0)
)
else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0))
)

return t_out
5 changes: 3 additions & 2 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_jumanji_consistency(self, envname, batch_size):
"Acrobot-v1",
CARTPOLE_VERSIONED,
]
ENVPOOL_ATARI_ENVS = [PONG_VERSIONED]
ENVPOOL_ATARI_ENVS = [] # PONG_VERSIONED]
ENVPOOL_GYM_ENVS = ENVPOOL_CLASSIC_CONTROL_ENVS + ENVPOOL_ATARI_ENVS
ENVPOOL_DM_ENVS = ["CheetahRun-v1"]
ENVPOOL_ALL_ENVS = ENVPOOL_GYM_ENVS + ENVPOOL_DM_ENVS
Expand Down Expand Up @@ -558,6 +558,7 @@ def test_specs(self, env_name, frame_skip, transformed_out, T=10, N=3):
def test_env_basic_operation(
self, env_name, frame_skip, transformed_out, T=10, N=3
):
torch.manual_seed(0)
env_multithreaded = _make_multithreaded_env(
env_name,
frame_skip,
Expand Down Expand Up @@ -737,7 +738,7 @@ def test_multithreaded_env_seed(

# Check that results are different if seed is different
# Skip Pong, since there different actions can lead to the same result
if env_name != "ALE/Pong-v5":
if env_name != PONG_VERSIONED:
env.set_seed(
seed=seed + 10,
)
Expand Down
16 changes: 7 additions & 9 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,14 +840,10 @@ def test_insert_transform():
def test_smoke_replay_buffer_transform(transform):
rb = ReplayBuffer(transform=transform(in_keys="observation"), batch_size=1)

# td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, [])
td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
rb.add(td)
if not isinstance(rb._transform[0], (CatFrames,)):
rb.sample()
else:
with pytest.raises(NotImplementedError):
rb.sample()
return
rb.sample()

rb._transform = mock.MagicMock()
rb._transform.__len__ = lambda *args: 3
Expand All @@ -856,7 +852,7 @@ def test_smoke_replay_buffer_transform(transform):


transforms = [
partial(DiscreteActionProjection, num_actions_effective=1, max_actions=1),
partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3),
FiniteTensorDictCheck,
gSDENoise,
PinMemoryTransform,
Expand All @@ -865,13 +861,15 @@ def test_smoke_replay_buffer_transform(transform):

@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform_no_inkeys(transform):
if PinMemoryTransform is PinMemoryTransform and not torch.cuda.is_available():
if transform == PinMemoryTransform and not torch.cuda.is_available():
raise pytest.skip("No CUDA device detected, skipping PinMemory")
rb = ReplayBuffer(
collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1
)

td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, [])
action = torch.zeros(3)
action[..., 0] = 1
td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, [])
rb.add(td)
rb.sample()

Expand Down
5 changes: 3 additions & 2 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
except ImportError:
_has_tb = False

from _utils_internal import PONG_VERSIONED
from tensordict import TensorDict
from torchrl.data import (
LazyMemmapStorage,
Expand Down Expand Up @@ -836,7 +837,7 @@ def test_subsampler_state_dict(self):
class TestRecorder:
def _get_args(self):
args = Namespace()
args.env_name = "ALE/Pong-v5"
args.env_name = PONG_VERSIONED
args.env_task = ""
args.grayscale = True
args.env_library = "gym"
Expand Down Expand Up @@ -894,7 +895,7 @@ def test_recorder(self, N=8):
},
)
ea.Reload()
img = ea.Images("tmp_ALE/Pong-v5_video")
img = ea.Images(f"tmp_{PONG_VERSIONED}_video")
try:
assert len(img) == N // args.record_interval
break
Expand Down
Loading

0 comments on commit 09f71b1

Please sign in to comment.