diff --git a/examples/envs/gym-async-info-reader.py b/examples/envs/gym-async-info-reader.py new file mode 100644 index 00000000000..1e9ef96fd07 --- /dev/null +++ b/examples/envs/gym-async-info-reader.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A toy example of executing a Gym environment asynchronously and gathering the info properly. +""" +import argparse + +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +parser = argparse.ArgumentParser() +parser.add_argument("--use_wrapper", action="store_true") + +# Create the dummy environment +class CustomEnv(gym.Env): + def __init__(self, render_mode=None): + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,)) + self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,)) + + def _get_info(self): + return {"field1": self.state**2} + + def _get_obs(self): + return self.state.copy() + + def reset(self, seed=None, options=None): + # We need the following line to seed self.np_random + super().reset(seed=seed) + self.state = np.zeros(self.observation_space.shape) + observation = self._get_obs() + info = self._get_info() + return observation, info + + def step(self, action): + self.state += action.item() + truncated = False + terminated = False + reward = 1 if terminated else 0 # Binary sparse rewards + observation = self._get_obs() + info = self._get_info() + return observation, reward, terminated, truncated, info + + +if __name__ == "__main__": + import torch + from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + from torchrl.envs import check_env_specs, GymEnv, GymWrapper + + args = parser.parse_args() + + num_envs = 10 + + if args.use_wrapper: + # Option 1: using GymWrapper + env = gym.vector.AsyncVectorEnv([lambda: CustomEnv() for _ in range(num_envs)]) + env = GymWrapper(env, device="cpu") + else: + # Option 2: using GymEnv directly, no need to call AsyncVectorEnv + gym.register("Custom-v0", CustomEnv) + env = GymEnv("Custom-v0", num_envs=num_envs) + + keys = ["field1"] + specs = [ + UnboundedContinuousTensorSpec(shape=(num_envs, 3), dtype=torch.float64), + ] + + # Create an info reader: this object will read the info and write its content to the tensordict + reader = lambda info, tensordict: tensordict.set("field1", np.stack(info["field1"])) + env.set_info_dict_reader(info_dict_reader=reader) + + # Print the info readers (there should be 2: one to read the terminal states and another to read the 'field1') + print("readers", env.info_dict_reader) + + # We need to unlock the specs to make them writable + env.observation_spec.unlock_() + env.observation_spec["field1"] = specs[0] + env.observation_spec.lock_() + + # Check that we did a good job + check_env_specs(env) + + td = env.reset() + print("reset data", td) + print("content of field1 (should be a 10x3 tensor)", td["field1"]) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 49604f7024e..6f8443e45b0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1597,12 +1597,18 @@ def __init__( if high is not None: raise TypeError(self.CONFLICTING_KWARGS.format("high", "maximum")) high = kwargs.pop("maximum") - warnings.warn("Maximum is deprecated since v0.4.0, using high instead.", category=DeprecationWarning) + warnings.warn( + "Maximum is deprecated since v0.4.0, using high instead.", + category=DeprecationWarning, + ) if "minimum" in kwargs: if low is not None: raise TypeError(self.CONFLICTING_KWARGS.format("low", "minimum")) low = kwargs.pop("minimum") - warnings.warn("Minimum is deprecated since v0.4.0, using low instead.", category=DeprecationWarning) + warnings.warn( + "Minimum is deprecated since v0.4.0, using low instead.", + category=DeprecationWarning, + ) domain = kwargs.pop("domain", "continuous") if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 9cbec79211d..6fe778a9a2c 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -162,8 +162,9 @@ class GymLikeEnv(_EnvWrapper): @classmethod def __new__(cls, *args, **kwargs): - cls._info_dict_reader = [] - return super().__new__(cls, *args, _batch_locked=True, **kwargs) + self = super().__new__(cls, *args, _batch_locked=True, **kwargs) + self._info_dict_reader = [] + return self def read_action(self, action): """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment. diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index d8043cb9ef7..d225edfa79e 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -576,6 +576,7 @@ def __call__(cls, *args, **kwargs): ) add_info_dict = False if add_info_dict: + print("adding info dict reader", instance.observation_spec) instance.set_info_dict_reader( terminal_obs_reader(instance.observation_spec, backend=backend) ) @@ -1538,6 +1539,9 @@ def __call__(self, info_dict, tensordict): for i, obs in enumerate(terminal_obs): # writes final_obs inplace with terminal_obs content self._read_obs(obs, key[-1], final_obs_buffer, index=i) + print("self.batch_size", tensordict.shape) + print("final_obs_buffer", final_obs_buffer) + print("spec", item) tensordict.set(key, final_obs_buffer) return tensordict