diff --git a/test/test_specs.py b/test/test_specs.py index 25aec510c94..5cc2ed97226 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1907,6 +1907,22 @@ def test_unboundeddiscrete( assert spec == torch.stack(spec.unbind(0), 0) assert spec == torch.stack(spec.unbind(-1), -1) + def test_composite_encode_err(self): + c = CompositeSpec( + a=UnboundedContinuousTensorSpec( + 1, + ), + b=UnboundedContinuousTensorSpec( + 2, + ), + ) + with pytest.raises(KeyError, match="The CompositeSpec instance with keys"): + c.encode({"c": 0}) + with pytest.raises( + RuntimeError, match="raised a RuntimeError. Scroll up to know more" + ): + c.encode({"a": 0, "b": 0}) + @pytest.mark.parametrize( "device", diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a6ccf0dbeaf..42fc851d2a6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3979,6 +3979,10 @@ def encode( raise KeyError( f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key." ) + except RuntimeError as err: + raise RuntimeError( + f"Encoding key {key} raised a RuntimeError. Scroll up to know more." + ) from err return out def __repr__(self) -> str: