Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix CompositeSpec.to_numpy method #931

Merged
merged 7 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ def test_is_in(self, is_complete, device, dtype):
r = ts.rand()
assert ts.is_in(r)

def test_to_numpy(self, is_complete, device, dtype):
# Checks that the composite spec's `to_numpy` agrees, key by key, with
# calling `to_numpy` on each sub-spec directly.
ts = self._composite_spec(is_complete, device, dtype)
# Repeat the comparison over 100 independent random draws from the spec.
for _ in range(100):
r = ts.rand()
for key, value in ts.to_numpy(r).items():
spec = ts[key]
# Element-wise equality between the per-key conversion and the composite result.
assert (spec.to_numpy(r[key]) == value).all()

@pytest.mark.parametrize("shape", [[], [3]])
def test_project(self, is_complete, device, dtype, shape):
ts = self._composite_spec(is_complete, device, dtype)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,7 +1947,7 @@ def clone(self) -> CompositeSpec:
)

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
# NOTE(review): this span comes from a diff listed as "1 addition & 1 deletion";
# the first `return` appears to be the removed line and the second its replacement,
# so only the second belongs in the resulting file.
# Before: delegated each entry to the sub-spec's private `_to_numpy`.
return {key: self[key]._to_numpy(val) for key, val in val.items()}
# After: delegates each entry to the sub-spec's public `to_numpy`.
# `safe` is not forwarded to the sub-spec call. The comprehension variable `val`
# reuses the parameter name; the outer `val.items()` is evaluated before rebinding.
return {key: self[key].to_numpy(val) for key, val in val.items()}

def zero(self, shape=None) -> TensorDictBase:
if shape is None:
Expand Down
57 changes: 11 additions & 46 deletions torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@
# Libraries necessary for MultiThreadedEnv
import envpool

try:
import gym
except ModuleNotFoundError:
import gymnasium as gym

import treevalue

_has_envpool = True
Expand Down Expand Up @@ -1181,66 +1176,36 @@ def _get_input_spec(self) -> TensorSpec:
# DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones.

# Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
action_spec = self._add_shape_to_spec(self._env.spec.action_space)
transformed_spec = _gym_to_torchrl_spec_transform(
action_spec,
action_spec = _gym_to_torchrl_spec_transform(
self._env.spec.action_space,
device=self.device,
categorical_action_encoding=True,
)
if not transformed_spec.shape:
transformed_spec.shape = (self.num_workers,)
action_spec = self._add_shape_to_spec(action_spec)
return CompositeSpec(
action=transformed_spec,
shape=transformed_spec.shape,
action=action_spec,
shape=(self.num_workers,),
)

def _get_observation_spec(self) -> TensorSpec:
# NOTE(review): old and new diff lines are interleaved below without +/- markers.
# The [removed]/[added] tags are inferred from diff ordering and the stated
# addition/deletion counts -- confirm against the actual commit.
# Gym specs produced by EnvPool don't contain batch_size, we add it to satisfy checks in EnvBase
# [removed] the gym space was batched *before* conversion to a torchrl spec
obs_spec = self._add_shape_to_spec(self._env.spec.observation_space)
observation_spec = _gym_to_torchrl_spec_transform(
# [removed] argument: the pre-batched gym space
obs_spec,
# [added] argument: the raw gym observation space
self._env.spec.observation_space,
device=self.device,
categorical_action_encoding=True,
)
# [removed] shape was overwritten in place, and only for CompositeSpec results
if isinstance(observation_spec, CompositeSpec):
observation_spec.shape = (self.num_workers,)
# [added] the converted spec is batched afterwards via `_add_shape_to_spec`
observation_spec = self._add_shape_to_spec(observation_spec)
return CompositeSpec(
observation=observation_spec,
# [removed] outer shape copied from the inner spec
shape=observation_spec.shape,
# [added] outer shape fixed to a single leading dimension of size `num_workers`
shape=(self.num_workers,),
)

# NOTE(review): this appears to be the pre-change implementation deleted by the
# diff (it is followed by a one-line replacement with the same name) -- confirm.
# It operates on gym spaces and returns a batched copy with a leading
# dimension of size `num_workers`.
def _add_shape_to_spec(
self, spec: gym.spaces.space.Space
) -> gym.spaces.space.Space:
# Box: repeat the low/high bounds `num_workers` times along a new first axis.
if isinstance(spec, gym.spaces.Box):
return gym.spaces.Box(
low=np.stack([spec.low] * self.num_workers),
high=np.stack([spec.high] * self.num_workers),
dtype=spec.dtype,
shape=(self.num_workers, *spec.shape),
)
# Dict: batch each Box entry; the result is a plain Python dict, not a gym Dict.
if isinstance(spec, gym.spaces.dict.Dict):
spec_dict = {}
for key in spec.keys():
if isinstance(spec[key], gym.spaces.Box):
spec_dict[key] = gym.spaces.Box(
low=np.stack([spec[key].low] * self.num_workers),
high=np.stack([spec[key].high] * self.num_workers),
dtype=spec[key].dtype,
shape=(self.num_workers, *spec[key].shape),
)
elif isinstance(spec[key], gym.spaces.dict.Dict):
# If needed, we could add support by applying this function recursively
raise TypeError("Nested specs with depth > 1 are not supported.")
# Entries that are neither Box nor Dict match no branch above and are
# therefore left out of `spec_dict`.
return spec_dict
if isinstance(spec, gym.spaces.discrete.Discrete):
# Discrete spec in Gym doesn't have shape, so nothing to change
return spec
# Any other space type is rejected.
raise TypeError(f"Unsupported spec type {spec.__class__}.")
def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
# Returns the spec expanded to its existing shape prefixed with a leading
# dimension of size `num_workers`.
return spec.expand((self.num_workers, *spec.shape))

def _get_reward_spec(self) -> TensorSpec:
# Builds an unbounded continuous reward spec on this env's device.
return UnboundedContinuousTensorSpec(
# NOTE(review): the two argument lines below are the before/after of a one-line
# diff change (presumably removed first, added second); only one belongs in the
# resulting file -- confirm against the actual commit.
# Before: shape (num_workers,)
device=self.device, shape=(self.num_workers,)
# After: shape (num_workers, 1), i.e. an extra trailing dimension of size 1
device=self.device, shape=(self.num_workers, 1)
)

def __repr__(self) -> str:
Expand Down