diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index f66a560f2de..ba6f3790090 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -1837,6 +1837,9 @@ def create_transformed_env():
         assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
         assert (data.get("recurrent_state_c") != 0.0).any()
 
+    @pytest.mark.skipif(
+        not _has_functorch, reason="vmap can only be used with functorch"
+    )
     def test_lstm_vmap_complex_model(self):
         # Tests that all ops in LSTM are compatible with VMAP (when built using
         # the PT backend).
@@ -1889,7 +1892,7 @@ def call(data, params):
             with params.to_module(training_model):
                 return training_model(data)
 
-        assert torch.vmap(call, (None, 0))(data, params).shape == torch.Size(
+        assert vmap(call, (None, 0))(data, params).shape == torch.Size(
             (2, 50, 11)
         )
 
@@ -2163,6 +2166,9 @@ def create_transformed_env():
         assert (data.get("recurrent_state") != 0.0).any()
         assert (data.get(("next", "recurrent_state")) != 0.0).all()
 
+    @pytest.mark.skipif(
+        not _has_functorch, reason="vmap can only be used with functorch"
+    )
     def test_gru_vmap_complex_model(self):
         # Tests that all ops in GRU are compatible with VMAP (when built using
         # the PT backend).
@@ -2215,7 +2221,7 @@ def call(data, params):
             with params.to_module(training_model):
                 return training_model(data)
 
-        assert torch.vmap(call, (None, 0))(data, params).shape == torch.Size(
+        assert vmap(call, (None, 0))(data, params).shape == torch.Size(
             (2, 50, 11)
         )
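
For reference, the pattern both patched tests exercise is vmapping a functional module call over a batch of parameter sets held in a TensorDict, now gated on a _has_functorch skipif marker. The sketch below is a minimal reconstruction under stated assumptions: the exact import guard in test_tensordictmodules.py may differ, and TinyModel is a hypothetical stand-in for the tests' "complex model"; only the params.to_module(...) context manager and the vmap(call, (None, 0)) call shape come from the diff itself.

import torch
import torch.nn as nn
from tensordict import TensorDict

# Import guard in the spirit of the new skipif markers; the real test file's
# guard may be shaped differently (assumption).
try:
    from functorch import vmap

    _has_functorch = True
except ImportError:
    _has_functorch = False
    from torch import vmap  # torch >= 2.0 exposes vmap directly


class TinyModel(nn.Module):
    # Hypothetical stand-in for the "complex model" built in the tests.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)


model = TinyModel()
data = torch.randn(50, 4)

# Snapshot the parameters as a TensorDict and stack two copies, giving the
# params a leading batch dimension of independent parameter sets.
params = TensorDict.from_module(model)
params = torch.stack([params.clone(), params.clone()]).contiguous()


def call(data, params):
    # Temporarily swap the batched parameters into the module, the same
    # pattern the patched assertions use.
    with params.to_module(model):
        return model(data)


# Map over the parameter batch only: data is shared (None), params batched (0).
out = vmap(call, (None, 0))(data, params)
assert out.shape == torch.Size((2, 50, 2))

Switching the assertion from torch.vmap to the bare vmap plausibly lets each test pick up whichever implementation the file-level guard resolved (functorch's vmap on older installs, torch.vmap on torch >= 2.0), which is consistent with gating both tests on _has_functorch.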