Skip to content

Commit

Permalink
[BugFix] Fix tanh normal mode (pytorch#2198)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 24, 2024
1 parent 1062e3e commit 00b7c2e
Show file tree
Hide file tree
Showing 20 changed files with 383 additions and 129 deletions.
8 changes: 4 additions & 4 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
out_keys=[("data", "action")],
distribution_class=TanhNormal,
distribution_kwargs={
"min": action_spec.space.low,
"max": action_spec.space.high,
"low": action_spec.space.low,
"high": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
Expand All @@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
out_keys=[("data", "action")],
distribution_class=TanhNormal,
distribution_kwargs={
"min": action_spec.space.low,
"max": action_spec.space.high,
"low": action_spec.space.low,
"high": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
Expand Down
27 changes: 23 additions & 4 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -13519,17 +13519,36 @@ def __init__(self):

def test_loss_exploration():
class DummyLoss(LossModule):
def forward(self, td):
assert exploration_type() == InteractionType.MODE
def forward(self, td, mode):
if mode is None:
mode = self.deterministic_sampling_mode
assert exploration_type() == mode
with set_exploration_type(ExplorationType.RANDOM):
assert exploration_type() == ExplorationType.RANDOM
assert exploration_type() == ExplorationType.MODE
assert exploration_type() == mode
return td

loss_fn = DummyLoss()
with set_exploration_type(ExplorationType.RANDOM):
assert exploration_type() == ExplorationType.RANDOM
loss_fn(None)
loss_fn(None, None)
assert exploration_type() == ExplorationType.RANDOM

with set_exploration_type(ExplorationType.RANDOM):
assert exploration_type() == ExplorationType.RANDOM
loss_fn(None, ExplorationType.DETERMINISTIC)
assert exploration_type() == ExplorationType.RANDOM

loss_fn.deterministic_sampling_mode = ExplorationType.MODE
with set_exploration_type(ExplorationType.RANDOM):
assert exploration_type() == ExplorationType.RANDOM
loss_fn(None, ExplorationType.MODE)
assert exploration_type() == ExplorationType.RANDOM

loss_fn.deterministic_sampling_mode = ExplorationType.MEAN
with set_exploration_type(ExplorationType.RANDOM):
assert exploration_type() == ExplorationType.RANDOM
loss_fn(None, ExplorationType.MEAN)
assert exploration_type() == ExplorationType.RANDOM


Expand Down
65 changes: 52 additions & 13 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def _map_all(*tensors_or_other, device):

class TestTanhNormal:
@pytest.mark.parametrize(
"min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
"low", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
)
@pytest.mark.parametrize(
"max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
"high", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
)
@pytest.mark.parametrize(
"vecs",
Expand All @@ -102,25 +102,64 @@ class TestTanhNormal:
)
@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])])
@pytest.mark.parametrize("device", get_default_devices())
def test_tanhnormal(self, min, max, vecs, upscale, shape, device):
min, max, vecs, upscale, shape = _map_all(
min, max, vecs, upscale, shape, device=device
def test_tanhnormal(self, low, high, vecs, upscale, shape, device):
torch.manual_seed(0)
low, high, vecs, upscale, shape = _map_all(
low, high, vecs, upscale, shape, device=device
)
torch.manual_seed(0)
d = TanhNormal(
*vecs,
upscale=upscale,
min=min,
max=max,
low=low,
high=high,
)
for _ in range(100):
a = d.rsample(shape)
assert a.shape[: len(shape)] == shape
assert (a >= d.min).all()
assert (a <= d.max).all()
assert (a >= d.low).all()
assert (a <= d.high).all()
lp = d.log_prob(a)
assert torch.isfinite(lp).all()

def test_tanhnormal_mode(self):
# Checks that the std of the mode computed by tanh normal is within a certain range
# when starting from close points

torch.manual_seed(0)
# 10 start points with 1000 jitters around that
# std of the loc is about 1e-4
loc = torch.randn(10) + torch.randn(1000, 10) / 10000

t = TanhNormal(loc=loc, scale=0.5, low=-1, high=1, event_dims=0)

mode = t.get_mode()
assert mode.shape == loc.shape
empirical_mode, empirical_mode_lp = torch.zeros_like(loc), -float("inf")
for v in torch.arange(-1, 1, step=0.01):
lp = t.log_prob(v.expand_as(t.loc))
empirical_mode = torch.where(lp > empirical_mode_lp, v, empirical_mode)
empirical_mode_lp = torch.where(
lp > empirical_mode_lp, lp, empirical_mode_lp
)
assert abs(empirical_mode - mode).max() < 0.1, abs(empirical_mode - mode).max()
assert mode.shape == loc.shape
assert (mode.std(0).max() < 0.1).all(), mode.std(0)

@pytest.mark.parametrize("event_dims", [0, 1, 2])
def test_tanhnormal_event_dims(self, event_dims):
scale = 1
loc = torch.randn(1, 2, 3, 4)
t = TanhNormal(loc=loc, scale=scale, event_dims=event_dims)
sample = t.sample()
assert sample.shape == loc.shape
exp_shape = loc.shape[:-event_dims] if event_dims > 0 else loc.shape
assert t.log_prob(sample).shape == exp_shape, (
t.log_prob(sample).shape,
event_dims,
exp_shape,
)


class TestTruncatedNormal:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -159,13 +198,13 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
a = d.rsample(shape)
assert a.device == device
assert a.shape[: len(shape)] == shape
assert (a >= d.min).all()
assert (a <= d.max).all()
assert (a >= d.low).all()
assert (a <= d.high).all()
lp = d.log_prob(a)
assert torch.isfinite(lp).all()
oob_min = d.min.expand((*d.batch_shape, *d.event_shape)) - 1e-2
oob_min = d.low.expand((*d.batch_shape, *d.event_shape)) - 1e-2
assert not torch.isfinite(d.log_prob(oob_min)).any()
oob_max = d.max.expand((*d.batch_shape, *d.event_shape)) + 1e-2
oob_max = d.high.expand((*d.batch_shape, *d.event_shape)) + 1e-2
assert not torch.isfinite(d.log_prob(oob_max)).any()

@pytest.mark.skipif(not _has_scipy, reason="scipy not installed")
Expand Down
2 changes: 1 addition & 1 deletion test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def test_gsde(
wrapper = NormalParamWrapper(model)
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
distribution_class = TanhNormal
distribution_kwargs = {"min": -bound, "max": bound}
distribution_kwargs = {"low": -bound, "high": bound}
spec = BoundedTensorSpec(
-torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
).to(device)
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,8 +1416,8 @@ def test_dt_inference_wrapper(self, online):
)
dist_class = TanhDelta
dist_kwargs = {
"min": -1.0,
"max": 1.0,
"low": -1.0,
"high": 1.0,
}
actor = ProbabilisticActor(
in_keys=in_keys,
Expand Down
18 changes: 9 additions & 9 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ class SyncDataCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
return_same_td (bool, optional): if ``True``, the same TensorDict
will be returned at each iteration, with its values
updated. This feature should be used cautiously: if the same
Expand Down Expand Up @@ -1336,9 +1336,9 @@ class _MultiDataCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
Expand Down Expand Up @@ -2635,9 +2635,9 @@ class aSyncDataCollector(MultiaSyncDataCollector):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
Expand Down
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,9 @@ class DistributedDataCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
Expand Down
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ class RayCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (Python class): a collector class to be remotely instantiated. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
Expand Down
5 changes: 3 additions & 2 deletions torchrl/collectors/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,9 @@ class RPCDataCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collector_class (type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
Expand Down
6 changes: 3 additions & 3 deletions torchrl/collectors/distributed/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ class DistributedSyncDataCollector(DataCollectorBase):
information.
Defaults to ``False``.
exploration_type (ExplorationType, optional): interaction mode to be used when
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
Expand Down
Loading

0 comments on commit 00b7c2e

Please sign in to comment.