From 69d44f5cf4bf84eab0f21b0eea98112651f7f9a1 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Mon, 12 Feb 2024 11:37:49 +0100 Subject: [PATCH] [Feature] Replace RewardClipping with SignTransform in Atari examples (#1870) --- examples/a2c/utils_atari.py | 4 ++-- examples/dqn/utils_atari.py | 4 ++-- examples/impala/utils.py | 4 ++-- examples/ppo/utils_atari.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py index 89a51f7e64b..0ddcd79123e 100644 --- a/examples/a2c/utils_atari.py +++ b/examples/a2c/utils_atari.py @@ -20,8 +20,8 @@ NoopResetEnv, ParallelEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -73,7 +73,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) if not is_test: - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env diff --git a/examples/dqn/utils_atari.py b/examples/dqn/utils_atari.py index 24b6509147c..b9805659e63 100644 --- a/examples/dqn/utils_atari.py +++ b/examples/dqn/utils_atari.py @@ -14,8 +14,8 @@ GymEnv, NoopResetEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -42,7 +42,7 @@ def make_env(env_name, frame_skip, device, is_test=False): env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: env.append_transform(EndOfLifeTransform()) - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(ToTensorImage()) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 2983f8a0193..b365dca3867 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -16,8 +16,8 @@ GymEnv, NoopResetEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -46,7 +46,7 @@ def make_env(env_name, device, is_test=False): env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: env.append_transform(EndOfLifeTransform()) - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py index eaef640ebb0..5cb838cac47 100644 --- a/examples/ppo/utils_atari.py +++ b/examples/ppo/utils_atari.py @@ -19,8 +19,8 @@ NoopResetEnv, ParallelEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -71,7 +71,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) if not is_test: - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env