Skip to content

Commit

Permalink
[BugFix] Fix env.shape regex matches (pytorch#1940)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 20, 2024
1 parent ca42794 commit 93885b6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_mb_env_batch_lock(self, device, seed=0):
mb_env.step(td)

with pytest.raises(
RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape")
RuntimeError, match=re.escape("Expected a tensordict with shape==env.batch_size")
):
mb_env.step(td_expanded)

Expand Down Expand Up @@ -1615,7 +1615,7 @@ def test_batch_locked(device):
_ = env.step(td)

with pytest.raises(
RuntimeError, match="Expected a tensordict with shape==env.shape, "
RuntimeError, match="Expected a tensordict with shape==env.batch_size, "
):
env.step(td_expanded)

Expand Down
4 changes: 2 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8231,7 +8231,7 @@ def test_batch_locked_transformed(device):
env.step(td)

with pytest.raises(
RuntimeError, match="Expected a tensordict with shape==env.shape, "
RuntimeError, match="Expected a tensordict with shape==env.batch_size, "
):
env.step(td_expanded)

Expand Down Expand Up @@ -8275,7 +8275,7 @@ def test_batch_unlocked_with_batch_size_transformed(device):
td_expanded = td.expand(2, 2).reshape(-1).to_tensordict()

with pytest.raises(
RuntimeError, match="Expected a tensordict with shape==env.shape, "
RuntimeError, match="Expected a tensordict with shape==env.batch_size, "
):
env.step(td_expanded)

Expand Down

0 comments on commit 93885b6

Please sign in to comment.