From 711a4ee6446e000af020c001652ffe689482960c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 30 Apr 2024 18:02:09 +0100 Subject: [PATCH] [BugFix] Fix async gym when all reset (#2144) --- examples/envs/gym-async-info-reader.py | 2 ++ torchrl/collectors/collectors.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/envs/gym-async-info-reader.py b/examples/envs/gym-async-info-reader.py index 39f93131bdb..3f98e039290 100644 --- a/examples/envs/gym-async-info-reader.py +++ b/examples/envs/gym-async-info-reader.py @@ -15,6 +15,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--use_wrapper", action="store_true") + # Create the dummy environment class CustomEnv(gym.Env): def __init__(self, render_mode=None): @@ -71,6 +72,7 @@ def step(self, action): # Create an info reader: this object will read the info and write its content to the tensordict def reader(info, tensordict): return tensordict.set("field1", np.stack(info["field1"])) + env.set_info_dict_reader(info_dict_reader=reader) # Print the info readers (there should be 2: one to read the terminal states and another to read the 'field1') diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b17a0fbe736..6b8fbe0dcfe 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2249,7 +2249,7 @@ class MultiaSyncDataCollector(_MultiDataCollector): See https://docs.python.org/3/library/multiprocessing.html for more info. - Examples: + Examples: >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c6583349948..037e3c8a483 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6996,7 +6996,7 @@ def _reset( ) # if not reset.any(), we don't need to do anything. # if reset.all(), we don't either (bc GymWrapper will call a plain reset). - if reset is not None and reset.any() and not reset.all(): + if reset is not None and reset.any(): saved_next = self._memo["saved_next"] # reset = reset.view(tensordict.shape) # we have a data container from the previous call to step