DMControl: bugfix in env device (pytorch#166)
vmoens authored May 25, 2022
1 parent 5c51f15 commit 476ca4e
Showing 8 changed files with 57 additions and 17 deletions.
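Summary of the change: before this commit the DMControl specs were created without a device argument, so they fell back to the default (CPU) even when the env itself was placed on another device. The fix threads device=self.device through _dmcontrol_to_torchrl_spec_transform and caches the resulting specs; the example scripts also gain an EnvCreator branch, and a typo in the rollout device cast is fixed. A rough usage sketch of the intended effect — the DMControlEnv("cheetah", "run", device=...) constructor call below is an assumption about the API at this commit, not taken from the diff:

# Hedged sketch, not from the diff: constructor arguments are assumed.
import torch
from torchrl.envs.libs.dm_control import DMControlEnv

device = "cuda:0" if torch.cuda.is_available() else "cpu"
env = DMControlEnv("cheetah", "run", device=device)

# After this fix the specs are created with device=self.device (and cached),
# so they should report the env device rather than a CPU default.
print(env.action_spec.device)
print(env.observation_spec)
print(env.reward_spec.device)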
6 changes: 4 additions & 2 deletions examples/ddpg/ddpg.py
@@ -6,7 +6,7 @@
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode

 try:
@@ -177,8 +177,10 @@ def main(args):
     if isinstance(create_env_fn, ParallelEnv):
         recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
         create_env_fn.close()
-    else:
+    elif isinstance(create_env_fn, EnvCreator):
         recorder_rm.load_state_dict(create_env_fn().state_dict())
+    else:
+        recorder_rm.load_state_dict(create_env_fn.state_dict())

     # reset reward scaling
     for t in recorder.transform:
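The same three-way branch is repeated in the dqn, ppo, redq and sac examples below. A minimal standalone sketch of the dispatch it implements — the helper name is hypothetical, the branch bodies mirror the diff:

from torchrl.envs import ParallelEnv, EnvCreator

def load_recorder_state_dict(recorder_rm, create_env_fn):
    # Hypothetical helper mirroring the branch added in these example scripts.
    if isinstance(create_env_fn, ParallelEnv):
        # A running ParallelEnv exposes per-worker state dicts; reuse worker0's.
        recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
        create_env_fn.close()
    elif isinstance(create_env_fn, EnvCreator):
        # An EnvCreator is a factory: call it to instantiate an env first.
        recorder_rm.load_state_dict(create_env_fn().state_dict())
    else:
        # Already an instantiated single env.
        recorder_rm.load_state_dict(create_env_fn.state_dict())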
6 changes: 4 additions & 2 deletions examples/dqn/dqn.py
@@ -6,7 +6,7 @@
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode

 try:
@@ -159,8 +159,10 @@ def main(args):
     if isinstance(create_env_fn, ParallelEnv):
         recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
         create_env_fn.close()
-    else:
+    elif isinstance(create_env_fn, EnvCreator):
         recorder_rm.load_state_dict(create_env_fn().state_dict())
+    else:
+        recorder_rm.load_state_dict(create_env_fn.state_dict())
     # reset reward scaling
     for t in recorder.transform:
         if isinstance(t, RewardScaling):
6 changes: 4 additions & 2 deletions examples/ppo/ppo.py
@@ -6,7 +6,7 @@
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode

 try:
@@ -147,8 +147,10 @@ def main(args):
     if isinstance(create_env_fn, ParallelEnv):
         recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
         create_env_fn.close()
-    else:
+    elif isinstance(create_env_fn, EnvCreator):
         recorder_rm.load_state_dict(create_env_fn().state_dict())
+    else:
+        recorder_rm.load_state_dict(create_env_fn.state_dict())

     # reset reward scaling
     for t in recorder.transform:
6 changes: 4 additions & 2 deletions examples/redq/redq.py
@@ -6,7 +6,7 @@
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode

 try:
@@ -177,8 +177,10 @@ def main(args):
     if isinstance(create_env_fn, ParallelEnv):
         recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
         create_env_fn.close()
-    else:
+    elif isinstance(create_env_fn, EnvCreator):
         recorder_rm.load_state_dict(create_env_fn().state_dict())
+    else:
+        recorder_rm.load_state_dict(create_env_fn.state_dict())

     # reset reward scaling
     for t in recorder.transform:
6 changes: 4 additions & 2 deletions examples/sac/sac.py
@@ -6,7 +6,7 @@
 import uuid
 from datetime import datetime

-from torchrl.envs import ParallelEnv
+from torchrl.envs import ParallelEnv, EnvCreator
 from torchrl.envs.utils import set_exploration_mode

 try:
@@ -175,8 +175,10 @@ def main(args):
     if isinstance(create_env_fn, ParallelEnv):
         recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"])
         create_env_fn.close()
-    else:
+    elif isinstance(create_env_fn, EnvCreator):
         recorder_rm.load_state_dict(create_env_fn().state_dict())
+    else:
+        recorder_rm.load_state_dict(create_env_fn.state_dict())

     # reset reward scaling, as it was just overwritten by state_dict load
     for t in recorder.transform:
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -446,7 +446,7 @@ def policy(td):
         if not self.is_done:
             for i in range(n_steps):
                 if auto_cast_to_device:
-                    tensordicts = auto_cast_to_device.to(policy_device)
+                    tensordict = tensordict.to(policy_device)
                 tensordict = policy(tensordict)
                 if auto_cast_to_device:
                     tensordict = tensordict.to(env_device)
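The removed line called .to() on the auto_cast_to_device flag (a bool) instead of on the tensordict, and assigned the result to an unused tensordicts name. A tiny standalone illustration of why the old line could only fail at runtime:

# auto_cast_to_device is a boolean flag inside rollout(); bools have no .to().
auto_cast_to_device = True
try:
    auto_cast_to_device.to("cpu")
except AttributeError as err:
    print(err)  # 'bool' object has no attribute 'to'
# The fix moves the tensordict instead: tensordict = tensordict.to(policy_device)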
41 changes: 35 additions & 6 deletions torchrl/envs/libs/dm_control.py
@@ -38,11 +38,13 @@


 def _dmcontrol_to_torchrl_spec_transform(
-    spec, dtype: Optional[torch.dtype] = None
+    spec,
+    dtype: Optional[torch.dtype] = None,
+    device: DEVICE_TYPING = None,
 ) -> TensorSpec:
     if isinstance(spec, collections.OrderedDict):
         spec = {
-            "next_" + k: _dmcontrol_to_torchrl_spec_transform(item)
+            "next_" + k: _dmcontrol_to_torchrl_spec_transform(item, device=device)
             for k, item in spec.items()
         }
         return CompositeSpec(**spec)
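A hedged usage sketch for the updated converter: suite.load is the standard dm_control entry point, and the spec types handled below appear in this diff, but calling the private _dmcontrol_to_torchrl_spec_transform directly like this is for illustration only — the real call sites are the spec properties further down.

from dm_control import suite

from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform

dm_env_instance = suite.load("cheetah", "run")
# A BoundedArray action spec and an OrderedDict observation spec both receive the device.
action_spec = _dmcontrol_to_torchrl_spec_transform(dm_env_instance.action_spec(), device="cpu")
observation_spec = _dmcontrol_to_torchrl_spec_transform(dm_env_instance.observation_spec(), device="cpu")
print(action_spec)
print(observation_spec)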
@@ -54,11 +56,14 @@ def _dmcontrol_to_torchrl_spec_transform(
             minimum=spec.minimum,
             maximum=spec.maximum,
             dtype=dtype,
+            device=device,
         )
     elif isinstance(spec, dm_env.specs.Array):
         if dtype is None:
             dtype = numpy_to_torch_dtype_dict[spec.dtype]
-        return NdUnboundedContinuousTensorSpec(shape=spec.shape, dtype=dtype)
+        return NdUnboundedContinuousTensorSpec(
+            shape=spec.shape, dtype=dtype, device=device
+        )
     else:
         raise NotImplementedError

@@ -186,12 +191,36 @@ def _output_transform(

     @property
     def action_spec(self) -> TensorSpec:
-        return _dmcontrol_to_torchrl_spec_transform(self._env.action_spec())
+        if self._action_spec is None:
+            self._action_spec = _dmcontrol_to_torchrl_spec_transform(
+                self._env.action_spec(), device=self.device
+            )
+        return self._action_spec
+
+    @action_spec.setter
+    def action_spec(self, value: TensorSpec) -> None:
+        self._action_spec = value

     @property
     def observation_spec(self) -> TensorSpec:
-        return _dmcontrol_to_torchrl_spec_transform(self._env.observation_spec())
+        if self._observation_spec is None:
+            self._observation_spec = _dmcontrol_to_torchrl_spec_transform(
+                self._env.observation_spec(), device=self.device
+            )
+        return self._observation_spec
+
+    @observation_spec.setter
+    def observation_spec(self, value: TensorSpec) -> None:
+        self._observation_spec = value

     @property
     def reward_spec(self) -> TensorSpec:
-        return _dmcontrol_to_torchrl_spec_transform(self._env.reward_spec())
+        if self._reward_spec is None:
+            self._reward_spec = _dmcontrol_to_torchrl_spec_transform(
+                self._env.reward_spec(), device=self.device
+            )
+        return self._reward_spec
+
+    @reward_spec.setter
+    def reward_spec(self, value: TensorSpec) -> None:
+        self._reward_spec = value
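Beyond the device argument, the properties now build each spec lazily and cache it — presumably to avoid re-converting the dm_control spec on every access — while the new setters let callers (e.g. transforms) override the cached value. A stripped-down sketch of that pattern; the class and the dict stand-in for a TensorSpec are illustrative, not the real DMControlEnv:

class _LazySpecExample:
    """Illustrative only: lazy, cached spec construction keyed to a device."""

    def __init__(self, device="cpu"):
        self.device = device
        self._action_spec = None  # built on first access

    @property
    def action_spec(self):
        if self._action_spec is None:
            # expensive construction happens once, on the requested device
            self._action_spec = self._build_action_spec(device=self.device)
        return self._action_spec

    @action_spec.setter
    def action_spec(self, value):
        # allow callers to replace the cached spec explicitly
        self._action_spec = value

    def _build_action_spec(self, device):
        return {"shape": (6,), "device": device}  # placeholder for a TensorSpec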
1 change: 1 addition & 0 deletions torchrl/trainers/trainers.py
@@ -850,6 +850,7 @@ def __call__(self, batch: _TensorDict) -> Dict:
             policy=self.policy_exploration,
             n_steps=self.record_frames,
             auto_reset=True,
+            auto_cast_to_device=True,
         )
         if isinstance(self.policy_exploration, torch.nn.Module):
             self.policy_exploration.train()
