Skip to content

Commit

Permalink
[BugFix, Test] Fix torch.vmap call in RNN tests (pytorch#1749)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 15, 2023
1 parent b3d2aa6 commit 08f0bed
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,9 @@ def create_transformed_env():
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
assert (data.get("recurrent_state_c") != 0.0).any()

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
)
def test_lstm_vmap_complex_model(self):
# Tests that all ops in GRU are compatible with VMAP (when build using
# the PT backend).
Expand Down Expand Up @@ -1889,7 +1892,7 @@ def call(data, params):
with params.to_module(training_model):
return training_model(data)

assert torch.vmap(call, (None, 0))(data, params).shape == torch.Size(
assert vmap(call, (None, 0))(data, params).shape == torch.Size(
(2, 50, 11)
)

Expand Down Expand Up @@ -2163,6 +2166,9 @@ def create_transformed_env():
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
)
def test_gru_vmap_complex_model(self):
# Tests that all ops in GRU are compatible with VMAP (when build using
# the PT backend).
Expand Down Expand Up @@ -2215,7 +2221,7 @@ def call(data, params):
with params.to_module(training_model):
return training_model(data)

assert torch.vmap(call, (None, 0))(data, params).shape == torch.Size(
assert vmap(call, (None, 0))(data, params).shape == torch.Size(
(2, 50, 11)
)

Expand Down

0 comments on commit 08f0bed

Please sign in to comment.