[BugFix] fix trunc normal device (pytorch#1931)
vmoens authored Feb 20, 2024
1 parent 4fd0343 commit 78b31a9
Showing 3 changed files with 26 additions and 34 deletions.
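
In short: TruncatedNormal now casts min, max and upscale to tensors on loc's device, and that device is threaded down to the underlying truncated-normal base distribution, so the distribution, its samples and its log-probs all report the same device. A minimal sketch of the behaviour the updated test exercises (the public import path and the keyword defaults are assumed, not taken from this diff; a CUDA device makes the placement check meaningful):

import torch
from torchrl.modules import TruncatedNormal  # assumed public import path

# pick a non-default device when available so the placement actually matters
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
loc = torch.zeros(3, device=device)
scale = torch.ones(3, device=device)

# min/max arrive as plain Python floats; after this commit they are moved to loc's device
d = TruncatedNormal(loc, scale, min=-1.0, max=1.0)
assert d.device == device

sample = d.rsample((10,))
assert sample.device == device
assert torch.isfinite(d.log_prob(sample)).all()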
28 changes: 10 additions & 18 deletions test/test_distributions.py
@@ -140,38 +140,30 @@ def test_tanhnormal(self, min, max, vecs, upscale, shape, device):
 class TestTruncatedNormal:
     def test_truncnormal(self, min, max, vecs, upscale, shape, device):
         torch.manual_seed(0)
-        min, max, vecs, upscale, shape = _map_all(
-            min, max, vecs, upscale, shape, device=device
+        *vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map(
+            lambda t: torch.as_tensor(t, device=device),
+            (*vecs, min, max, vecs, upscale),
         )
+        assert all(t.device == device for t in vecs)
         d = TruncatedNormal(
             *vecs,
             upscale=upscale,
             min=min,
             max=max,
         )
+        assert d.device == device
         for _ in range(100):
             a = d.rsample(shape)
+            assert a.device == device
             assert a.shape[: len(shape)] == shape
             assert (a >= d.min).all()
             assert (a <= d.max).all()
             lp = d.log_prob(a)
             assert torch.isfinite(lp).all()
-        assert not torch.isfinite(
-            d.log_prob(
-                torch.as_tensor(d.min, device=device).expand(
-                    (*d.batch_shape, *d.event_shape)
-                )
-                - 1e-2
-            )
-        ).any()
-        assert not torch.isfinite(
-            d.log_prob(
-                torch.as_tensor(d.max, device=device).expand(
-                    (*d.batch_shape, *d.event_shape)
-                )
-                + 1e-2
-            )
-        ).any()
+        oob_min = d.min.expand((*d.batch_shape, *d.event_shape)) - 1e-2
+        assert not torch.isfinite(d.log_prob(oob_min)).any()
+        oob_max = d.max.expand((*d.batch_shape, *d.event_shape)) + 1e-2
+        assert not torch.isfinite(d.log_prob(oob_max)).any()
 
     def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device):
         torch.manual_seed(0)
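
The rewritten test moves a mixed bag of parameters (floats, sequences, tensors) onto the target device in a single call via torch.utils._pytree.tree_map. A minimal sketch of that pattern (tree_map and tree_flatten are private PyTorch utilities; their usage here just mirrors the test above):

import torch
from torch.utils._pytree import tree_flatten, tree_map

device = torch.device("cpu")
# tree_map applies the function to every leaf of a nested structure,
# so scalars, lists and tensors are all converted in one pass
args = (0.1, [1.0, 2.0], torch.tensor(3.0))
args = tree_map(lambda t: torch.as_tensor(t, device=device), args)

leaves, _ = tree_flatten(args)
assert all(leaf.device == device for leaf in leaves)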
24 changes: 10 additions & 14 deletions torchrl/modules/distributions/continuous.py
@@ -194,6 +194,8 @@ class TruncatedNormal(D.Independent):
 
     num_params: int = 2
 
+    base_dist: _TruncatedNormal
+
     arg_constraints = {
         "loc": constraints.real,
         "scale": constraints.greater_than(1e-6),
@@ -231,20 +233,10 @@ def __init__(
         self.tanh_loc = tanh_loc
 
         self.device = loc.device
-        self.upscale = (
-            upscale
-            if not isinstance(upscale, torch.Tensor)
-            else upscale.to(self.device)
-        )
+        self.upscale = torch.as_tensor(upscale, device=self.device)
 
-        if isinstance(max, torch.Tensor):
-            max = max.to(self.device)
-        else:
-            max = torch.as_tensor(max, device=self.device)
-        if isinstance(min, torch.Tensor):
-            min = min.to(self.device)
-        else:
-            min = torch.as_tensor(min, device=self.device)
+        max = torch.as_tensor(max, device=self.device)
+        min = torch.as_tensor(min, device=self.device)
         self.min = min
         self.max = max
         self.update(loc, scale)
@@ -258,7 +250,11 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         self.scale = scale
 
         base_dist = _TruncatedNormal(
-            loc, scale, self.min.expand_as(loc), self.max.expand_as(scale)
+            loc,
+            scale,
+            self.min.expand_as(loc),
+            self.max.expand_as(scale),
+            device=self.device,
         )
         super().__init__(base_dist, 1, validate_args=False)
 
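
The dropped isinstance branches were redundant because torch.as_tensor already covers both inputs: it wraps Python scalars as new tensors and moves (or keeps) existing tensors on the requested device. A small sketch of that behaviour, which is what the simplified __init__ above relies on:

import torch

device = torch.device("cpu")
t_from_float = torch.as_tensor(1.0, device=device)                  # Python scalar -> tensor on device
t_from_tensor = torch.as_tensor(torch.tensor(1.0), device=device)   # tensor -> moved/kept on device
assert t_from_float.device == device
assert t_from_tensor.device == device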
8 changes: 6 additions & 2 deletions torchrl/modules/distributions/truncated_normal.py
@@ -33,8 +33,10 @@ class TruncatedStandardNormal(Distribution):
     has_rsample = True
     eps = 1e-6
 
-    def __init__(self, a, b, validate_args=None):
+    def __init__(self, a, b, validate_args=None, device=None):
         self.a, self.b = broadcast_all(a, b)
+        self.a = self.a.to(device)
+        self.b = self.b.to(device)
         if isinstance(a, Number) and isinstance(b, Number):
             batch_shape = torch.Size()
         else:
@@ -139,9 +141,11 @@ class TruncatedNormal(TruncatedStandardNormal):
 
     has_rsample = True
 
-    def __init__(self, loc, scale, a, b, validate_args=None):
+    def __init__(self, loc, scale, a, b, validate_args=None, device=None):
         scale = scale.clamp_min(self.eps)
         self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+        a = a.to(device)
+        b = b.to(device)
         self._non_std_a = a
         self._non_std_b = b
         a = (a - self.loc) / self.scale
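
With the new device keyword, the low-level classes place their bounds explicitly, which is what lets TruncatedNormal.update in continuous.py pass device=self.device when rebuilding the base distribution. A minimal sketch of the updated constructor in use (module path taken from the file header above; everything else assumed unchanged):

import torch
from torchrl.modules.distributions.truncated_normal import TruncatedStandardNormal

# swap in a CUDA device to see the effect of the fix; CPU keeps the sketch runnable anywhere
device = torch.device("cpu")
dist = TruncatedStandardNormal(torch.tensor(-1.0), torch.tensor(1.0), device=device)

x = dist.rsample((4,))
assert x.device == device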
