diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 5811580826d..22f9835303b 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -94,6 +94,16 @@ def _vmas_to_torchrl_spec_transform( device=device, ) ) + elif isinstance(spec, gym_spaces.Dict): + spec_out = {} + for key in spec.keys(): + spec_out[key] = _vmas_to_torchrl_spec_transform( + spec[key], + device=device, + categorical_action_encoding=categorical_action_encoding, + ) + # the batch-size must be set later + return Composite(spec_out, device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for vmas"