Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] TorchRL2Gym conversion #1795

Merged
merged 44 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c4c46f9
init
vmoens Jan 13, 2024
1204c80
amend
vmoens Jan 14, 2024
f1dae58
amend
vmoens Jan 15, 2024
71dfaf6
amend
vmoens Jan 15, 2024
d04e7b6
Merge remote-tracking branch 'origin/main' into torchrl-to-gym-env
vmoens Jan 17, 2024
5b9eade
Merge remote-tracking branch 'origin/main' into torchrl-to-gym-env
vmoens Jan 17, 2024
0a0184a
amend
vmoens Jan 17, 2024
fe60daf
amend
vmoens Jan 17, 2024
ec53047
fix
vmoens Jan 17, 2024
a197774
amend
vmoens Jan 17, 2024
8a5eba5
Merge remote-tracking branch 'origin/main' into torchrl-to-gym-env
vmoens Jan 17, 2024
173e60b
amend
vmoens Jan 17, 2024
7a0a0f6
amend
vmoens Jan 17, 2024
c99802c
amend
vmoens Jan 17, 2024
f04d51c
amend
vmoens Jan 17, 2024
8047f08
empty
vmoens Jan 17, 2024
9480fa0
amend
vmoens Jan 18, 2024
4751806
amend
vmoens Jan 18, 2024
6ba8d40
lint
vmoens Jan 18, 2024
8fa951a
amend
vmoens Jan 18, 2024
ba1a973
amend
vmoens Jan 18, 2024
8890328
amend
vmoens Jan 18, 2024
2df683c
amend
vmoens Jan 18, 2024
6ef6753
amend
vmoens Jan 18, 2024
d10fc4e
lint
vmoens Jan 18, 2024
3f984d1
lint
vmoens Jan 18, 2024
403a8ab
lint
vmoens Jan 18, 2024
2df6001
amend
vmoens Jan 18, 2024
1f94d52
amend
vmoens Jan 18, 2024
3f6cbac
amend
vmoens Jan 18, 2024
6cb3d37
amend
vmoens Jan 18, 2024
6e13317
Merge remote-tracking branch 'origin/main' into torchrl-to-gym-env
vmoens Jan 18, 2024
07ab06c
amend
vmoens Jan 18, 2024
6652aff
amend
vmoens Jan 19, 2024
5f4ff19
amend
vmoens Jan 19, 2024
7e91b1a
amend
vmoens Jan 19, 2024
2e7b886
Merge remote-tracking branch 'origin/main' into torchrl-to-gym-env
vmoens Jan 19, 2024
6084e7b
amend
vmoens Jan 19, 2024
b033acc
amend
vmoens Jan 19, 2024
5a87498
amend
vmoens Jan 19, 2024
6b3c323
amend
vmoens Jan 19, 2024
ec3ed57
amend
vmoens Jan 19, 2024
95e2890
amend
vmoens Jan 19, 2024
62dcee6
amend
vmoens Jan 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init
  • Loading branch information
vmoens committed Jan 13, 2024
commit c4c46f90b160fbea3f2015cea2329e0dc683d81e
12 changes: 12 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,14 @@ def __init__(
shape, ContinuousBox(low, high, device=device), device, dtype, "continuous"
)

@property
def low(self):
return self.space.low

@property
def high(self):
return self.space.high

def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
Expand Down Expand Up @@ -2373,6 +2381,10 @@ def __init__(
super().__init__(shape, space, device, dtype, domain="discrete")
self.update_mask(mask)

@property
def n(self):
return self.space.n

def update_mask(self, mask):
if mask is not None:
try:
Expand Down
120 changes: 120 additions & 0 deletions torchrl/envs/libs/_gym_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from torch.utils._pytree import tree_map

from torchrl._utils import implement_for
from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform


class TorchRLGymWrapper(gymnasium.Env):
def __init__(self, env_cls, to_numpy=False, **kwargs):
self.torchrl_env = env_cls(**kwargs)
super().__init__()
self.action_space = _torchrl_to_gym_spec_transform(
self.torchrl_env.action_spec,
categorical_action_encoding=self.torchrl_env.categorical_action_encoding,
)
self.observation_space = _torchrl_to_gym_spec_transform(
self.torchrl_env.observation_spec,
categorical_action_encoding=self.torchrl_env.categorical_action_encoding,
)
self.to_numpy = to_numpy

@implement_for("gymnasium")
def step(self, action):
self._tensordict.set("action", action)
self.torchrl_env.step(self._tensordict)
_tensordict = step_mdp(self._tensordict)
keys = list(self.torchrl_env.observation_spec.keys())
observation = (
self._tensordict.get("next")
.select(*self.torchrl_env.observation_spec.keys())
.to_dict()
)
reward = self._tensordict.get(("next", "reward"))
terminated = self._tensordict.get(("next", "terminated"))
truncated = self._tensordict.get(("next", "truncated"))
info = {}
self._tensordict = _tensordict
out = (observation, reward, terminated, truncated, info)
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gym", "0.26", None)
def step(self, action):
self._tensordict.set("action", action)
self.torchrl_env.step(self._tensordict)
_tensordict = step_mdp(self._tensordict)
keys = list(self.torchrl_env.observation_spec.keys())
observation = (
self._tensordict.get("next")
.select(*self.torchrl_env.observation_spec.keys())
.to_dict()
)
reward = self._tensordict.get(("next", "reward"))
terminated = self._tensordict.get(("next", "terminated"))
truncated = self._tensordict.get(("next", "truncated"))
info = {}
self._tensordict = _tensordict
out = (observation, reward, terminated, truncated, info)
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gym", None, "0.26")
def step(self, action):
self._tensordict.set("action", action)
self.torchrl_env.step(self._tensordict)
_tensordict = step_mdp(self._tensordict)
keys = list(self.torchrl_env.observation_spec.keys())
observation = (
self._tensordict.get("next")
.select(*self.torchrl_env.observation_spec.keys())
.to_dict()
)
reward = self._tensordict.get(("next", "reward"))
done = self._tensordict.get(("next", "done"))
info = {}
self._tensordict = _tensordict
out = (observation, reward, done, info)
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gymnasium")
def reset(self):
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict.select(
*self.torchrl_env.observation_spec.keys()
).to_dict()
out = observation, {}
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gym", None, "0.26")
def reset(self):
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict.select(
*self.torchrl_env.observation_spec.keys()
).to_dict()
out = observation
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gym", "0.26", None)
def reset(self):
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict.select(
*self.torchrl_env.observation_spec.keys()
).to_dict()
out = observation, {}
if self.to_numpy():
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out
59 changes: 58 additions & 1 deletion torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
TensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict
from torchrl.envs.batched_envs import CloudpickleWrapper
from torchrl.envs.common import _EnvPostInit

Expand Down Expand Up @@ -332,6 +332,63 @@ def _gym_to_torchrl_spec_transform(
)


def _torchrl_to_gym_spec_transform(
spec,
batch_size=torch.Size([]),
categorical_action_encoding=False,
remap_state_to_observation: bool = True,
) -> TensorSpec:
"""Maps TorchRL specs to gym spaces.

Args:
spec: the torchrl spec to transform.
batch_size (torch.Size): batch-size of the input specs.
categorical_action_encoding: whether discrete spaces should be mapped to categorical or one-hot.
Defaults to one-hot.
remap_state_to_observation: whether to rename the 'state' key of Dict specs to "observation". Default is true.

"""
gym_spaces = gym_backend("spaces")
shape = spec.shape[len(batch_size) :]
if isinstance(spec, MultiDiscreteTensorSpec):
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)
if isinstance(spec, MultiOneHotDiscreteTensorSpec):
return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
if isinstance(spec, DiscreteTensorSpec):
return gym_spaces.discrete.Discrete(
spec.n
) # dtype=torch_to_numpy_dtype_dict[spec.dtype])
if isinstance(spec, OneHotDiscreteTensorSpec):
return gym_spaces.discrete.Discrete(spec.n)
if isinstance(spec, BinaryDiscreteTensorSpec):
return gym_spaces.multi_binary.MultiBinary(
spec.n, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)
if isinstance(spec, UnboundedContinuousTensorSpec):
return gym_spaces.Box(-float("inf"), float("inf"), shape)
if isinstance(spec, BoundedTensorSpec):
return gym_spaces.Box(
spec.low.detach().cpu().numpy(), spec.high.detach().cpu().numpy(), shape
)
if isinstance(spec, CompositeSpec):
return dict(
**{
key: _torchrl_to_gym_spec_transform(
val,
batch_size=spec.shape,
categorical_action_encoding=categorical_action_encoding,
)
for key, val in spec.items()
}
)
else:
raise NotImplementedError(
f"spec of type {type(spec).__name__} is currently unaccounted for"
)


def _get_envs(to_dict=False) -> List:
if not _has_gym:
raise ImportError("Gym(nasium) could not be found in your virtual environment.")
Expand Down
Loading