Skip to content

Commit

Permalink
[Minor] Remove ya gymnasium deprecation warning in vectorized envs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 27, 2023
1 parent 434fe58 commit 7f42576
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,22 @@ def _is_batched(self):
self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)
)

@implement_for("gym", None, "0.27")
def _get_batch_size(self, env):
if hasattr(env, "num_envs"):
batch_size = torch.Size([env.num_envs, *self.batch_size])
else:
batch_size = self.batch_size
return batch_size

@implement_for("gymnasium", "0.27", None) # gymnasium wants the unwrapped env
def _get_batch_size(self, env): # noqa: F811
if hasattr(env, "num_envs"):
batch_size = torch.Size([env.unwrapped.num_envs, *self.batch_size])
else:
batch_size = self.batch_size
return batch_size

def _check_kwargs(self, kwargs: Dict):
if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
Expand Down

0 comments on commit 7f42576

Please sign in to comment.