Skip to content

Commit

Permalink
[Test] Fix tests for older pytorch versions (pytorch#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 17, 2023
1 parent c087963 commit fae0e03
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_errs(self):
probs=torch.tensor(()), mask=torch.tensor(()), indices=torch.tensor(())
)

@pytest.mark.parametrize("neg_inf", [-10, -float("inf")])
@pytest.mark.parametrize("neg_inf", [-float(10.0), -float("inf")])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("sparse", [True, False])
@pytest.mark.parametrize("logits", [True, False])
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_construction(self, neg_inf, sparse, logits, device):
else:
assert (dist.log_prob(torch.ones_like(sample)) > -float("inf")).all()

@pytest.mark.parametrize("neg_inf", [-10, -float("inf")])
@pytest.mark.parametrize("neg_inf", [-float(10.0), -float("inf")])
@pytest.mark.parametrize("sparse", [True, False])
@pytest.mark.parametrize("logits", [True, False])
def test_backprop(self, neg_inf, sparse, logits):
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_sample(self, neg_inf: float) -> None:
logits = torch.randn(4)
probs = F.softmax(logits, dim=-1)
mask = torch.tensor([True, False, True, True])
ref_probs = torch.where(mask, probs, 0.0)
ref_probs = probs.masked_fill(~mask, 0.0)
ref_probs /= ref_probs.sum(dim=-1, keepdim=True)

dist = MaskedCategorical(
Expand All @@ -331,7 +331,7 @@ def test_sample_sparse(self, neg_inf: float) -> None:
probs = F.softmax(logits, dim=-1)
mask = torch.tensor([True, False, True, True])
indices = torch.tensor([0, 2, 3])
ref_probs = torch.where(mask, probs, 0.0)
ref_probs = probs.masked_fill(~mask, 0.0)
ref_probs /= ref_probs.sum(dim=-1, keepdim=True)

dist = MaskedCategorical(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _mask_logits(
return logits

if not sparse_mask:
return torch.where(mask, logits, neg_inf)
return logits.masked_fill(~mask, neg_inf)

if padding_value is not None:
padding_mask = mask == padding_value
Expand Down

0 comments on commit fae0e03

Please sign in to comment.