Commit 0043d57: Added test for RewardRescale transform (#252)
nicolas-dufour authored Jul 7, 2022
1 parent bd8a23d commit 0043d57

Showing 2 changed files with 41 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .gitignore

@@ -162,3 +162,6 @@ cython_debug/
 # vim
 *.swp
 *.swo
+
+# Vscode
+.vscode
41 changes: 38 additions & 3 deletions test/test_transforms.py
@@ -11,7 +11,11 @@
 from torch import Tensor
 from torch import multiprocessing as mp
 from torchrl import prod
-from torchrl.data import NdBoundedTensorSpec, CompositeSpec
+from torchrl.data import (
+    NdBoundedTensorSpec,
+    CompositeSpec,
+    UnboundedContinuousTensorSpec,
+)
 from torchrl.data import TensorDict
 from torchrl.envs import EnvCreator, SerialEnv
 from torchrl.envs import GymEnv, ParallelEnv
@@ -26,6 +30,7 @@
     DoubleToFloat,
     CatTensors,
     FlattenObservation,
+    RewardScaling,
 )
 from torchrl.envs.libs.gym import _has_gym
 from torchrl.envs.transforms import VecNorm, TransformedEnv
@@ -872,9 +877,39 @@ def test_noop_reset_env(self, random, device, compose):
     def test_binerized_reward(self, device):
         pass
 
+    @pytest.mark.parametrize("batch", [[], [2], [2, 4]])
+    @pytest.mark.parametrize("scale", [0.1, 10])
+    @pytest.mark.parametrize("loc", [1, 5])
+    @pytest.mark.parametrize("keys", [None, ["reward_1"]])
     @pytest.mark.parametrize("device", get_available_devices())
-    def test_reward_scaling(self, device):
-        pass
+    def test_reward_scaling(self, batch, scale, loc, keys, device):
+        torch.manual_seed(0)
+        if keys is None:
+            keys_total = set([])
+        else:
+            keys_total = set(keys)
+        reward_scaling = RewardScaling(keys_in=keys, scale=scale, loc=loc)
+        td = TensorDict(
+            {
+                **{key: torch.randn(*batch, 1, device=device) for key in keys_total},
+                "reward": torch.randn(*batch, 1, device=device),
+            },
+            batch,
+        )
+        td.set("dont touch", torch.randn(*batch, 1, device=device))
+        td_copy = td.clone()
+        reward_scaling(td)
+        for key in keys_total:
+            assert (td.get(key) == td_copy.get(key).mul_(scale).add_(loc)).all()
+        assert (td.get("dont touch") == td_copy.get("dont touch")).all()
+        if len(keys_total) == 0:
+            assert (
+                td.get("reward") == td_copy.get("reward").mul_(scale).add_(loc)
+            ).all()
+        elif len(keys_total) == 1:
+            reward_spec = UnboundedContinuousTensorSpec(device=device)
+            reward_spec = reward_scaling.transform_reward_spec(reward_spec)
+            assert reward_spec.shape == torch.Size([1])
 
     @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device found")
     @pytest.mark.parametrize("device", get_available_devices())
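For context, a minimal standalone sketch (not part of the commit) of what the new test exercises: RewardScaling applies an affine rescaling, reward <- reward * scale + loc, to a reward entry in a TensorDict. The call signature and the keys_in=None fallback to the "reward" key are inferred from the test's assertions, and the import path is assumed from the neighboring transforms imports, not taken from library documentation.

    # Hypothetical usage sketch based on the assertions in test_reward_scaling.
    import torch
    from torchrl.data import TensorDict
    from torchrl.envs.transforms import RewardScaling  # assumed import path

    # scale/loc values are arbitrary; keys_in=None is assumed to target the
    # default "reward" key, mirroring the keys=None branch of the test.
    reward_scaling = RewardScaling(keys_in=None, scale=2.0, loc=1.0)
    td = TensorDict({"reward": torch.randn(4, 1)}, [4])
    before = td.get("reward").clone()
    reward_scaling(td)  # rescales the reward entry in place
    assert torch.allclose(td.get("reward"), before * 2.0 + 1.0)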
