[BugFix] Fix LSTM - VecEnv compatibility #1427

Merged · 31 commits · Sep 2, 2023
Changes from 1 commit
amend
vmoens committed Aug 31, 2023
commit c7ed82e86a11303cc0d802b21324526f898f6e8a
38 changes: 29 additions & 9 deletions examples/decision_transformer/utils.py
@@ -74,22 +74,25 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
     transformed_env = TransformedEnv(base_env)
     transformed_env.append_transform(
         RewardScaling(
-            loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False
+            loc=0,
+            scale=env_cfg.reward_scaling,
+            in_keys=["reward"],
+            standard_normal=False,
         )
     )
     if train:
         transformed_env.append_transform(
             TargetReturn(
                 env_cfg.collect_target_return * env_cfg.reward_scaling,
-                out_keys=["return_to_go"],
+                out_keys=["return_to_go_single"],
                 mode=env_cfg.target_return_mode,
             )
         )
     else:
         transformed_env.append_transform(
             TargetReturn(
                 env_cfg.eval_target_return * env_cfg.reward_scaling,
-                out_keys=["return_to_go"],
+                out_keys=["return_to_go_single"],
                 mode=env_cfg.target_return_mode,
             )
         )
@@ -107,7 +110,11 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
     )
     transformed_env.append_transform(obsnorm)
     transformed_env.append_transform(
-        UnsqueezeTransform(-2, in_keys=["observation", "action", "return_to_go"])
+        UnsqueezeTransform(
+            -2,
+            in_keys=["observation", "action", "return_to_go_single"],
+            out_keys=["observation", "action", "return_to_go"],
+        )
     )
     transformed_env.append_transform(
         CatFrames(
@@ -158,6 +165,8 @@ def make_collector(cfg, policy):
     exclude_target_return = ExcludeTransform(
         "return_to_go",
         ("next", "return_to_go"),
+        "return_to_go_single",
+        ("next", "return_to_go_single"),
         ("next", "action"),
         ("next", "observation"),
         "scale",
@@ -183,9 +192,15 @@ def make_collector(cfg, policy):
 
 
 def make_offline_replay_buffer(rb_cfg, reward_scaling):
-    r2g = Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go"])
+    r2g = Reward2GoTransform(
+        gamma=1.0, in_keys=["reward"], out_keys=["return_to_go_single"]
+    )
     reward_scale = RewardScaling(
-        loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False
+        loc=0,
+        scale=reward_scaling,
+        in_keys="return_to_go_single",
+        out_keys=["return_to_go"],
+        standard_normal=False,
     )
     crop_seq = RandomCropTensorDict(sub_seq_len=rb_cfg.stacked_frames, sample_dim=-1)
 
@@ -230,12 +245,17 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling):
 
 
 def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001):
-    r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"])
+    r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go_single"])
     reward_scale = RewardScaling(
-        loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False
+        loc=0,
+        scale=reward_scaling,
+        in_keys=["return_to_go_single"],
+        out_keys=["return_to_go"],
+        standard_normal=False,
     )
     catframes = CatFrames(
-        in_keys=["return_to_go"],
+        in_keys=["return_to_go_single"],
+        out_keys=["return_to_go"],
         N=rb_cfg.stacked_frames,
         dim=-2,
         padding="zeros",
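Taken together, the utils.py changes route the target return through an intermediate "return_to_go_single" key before it is unsqueezed and frame-stacked into "return_to_go". The sketch below illustrates the resulting transform chain; the helper name, default values and the exact CatFrames arguments are illustrative assumptions rather than this PR's configuration:

from torchrl.envs.transforms import (
    CatFrames,
    RewardScaling,
    TargetReturn,
    TransformedEnv,
    UnsqueezeTransform,
)


def build_return_to_go_env(base_env, target=400.0, scaling=0.001, stacked_frames=20):
    # Hypothetical helper mirroring make_transformed_env after this commit.
    env = TransformedEnv(base_env)
    # Scale the raw reward in place.
    env.append_transform(
        RewardScaling(loc=0, scale=scaling, in_keys=["reward"], standard_normal=False)
    )
    # Track the running target return under an intermediate key ...
    env.append_transform(
        TargetReturn(target * scaling, out_keys=["return_to_go_single"], mode="reduce")
    )
    # ... give it a time dimension and rename it to the key the model consumes ...
    env.append_transform(
        UnsqueezeTransform(
            -2,
            in_keys=["observation", "action", "return_to_go_single"],
            out_keys=["observation", "action", "return_to_go"],
        )
    )
    # ... then stack the last N frames along that time dimension.
    env.append_transform(
        CatFrames(in_keys=["return_to_go"], N=stacked_frames, dim=-2, padding="zeros")
    )
    return env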
14 changes: 6 additions & 8 deletions torchrl/envs/gym_like.py
@@ -146,17 +146,14 @@ def read_done(self, done):
         """
         return done, done
 
-    def read_reward(self, total_reward, step_reward):
-        """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
+    def read_reward(self, reward):
+        """Reads the reward and maps it to the reward space.
 
         Args:
-            total_reward (torch.Tensor or TensorDict): total reward so far in the step
-            step_reward (reward in the format provided by the inner env): reward of this particular step
+            reward (torch.Tensor or TensorDict): reward to be mapped.
 
         """
-        return (
-            total_reward + step_reward
-        )  # self.reward_spec.encode(step_reward, ignore_device=True)
+        return self.reward_spec.encode(reward)
 
     def read_obs(
         self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
@@ -214,7 +211,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
             if _reward is None:
                 _reward = self.reward_spec.zero()
 
-            reward = self.read_reward(reward, _reward)
+            reward = reward + _reward
 
             if isinstance(done, bool) or (
                 isinstance(done, np.ndarray) and not len(done)
@@ -224,6 +221,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
             if do_break:
                 break
 
+        reward = self.read_reward(reward)
         obs_dict = self.read_obs(obs)
 
         if reward is None:
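The net effect in gym_like.py is that frame-skip accumulation and reward mapping are now decoupled: per-step rewards are summed with plain addition inside the skip loop, and read_reward is called once afterwards to encode the total into the reward spec. Below is a minimal sketch of how a subclass could use the new single-argument hook; the ClippedRewardEnv class and its clipping logic are hypothetical, not part of this PR:

import torch

from torchrl.envs.gym_like import GymLikeEnv


class ClippedRewardEnv(GymLikeEnv):
    """Hypothetical subclass illustrating the post-commit read_reward contract."""

    def read_reward(self, reward):
        # `reward` is already the sum accumulated over the frame-skip loop;
        # clip it (illustrative) and encode it into the reward spec.
        reward = torch.as_tensor(reward).clamp(-1.0, 1.0)
        return self.reward_spec.encode(reward)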
3 changes: 1 addition & 2 deletions torchrl/envs/libs/jumanji.py
@@ -252,7 +252,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # prepare inputs
         state = _tensordict_to_object(tensordict.get("state"), self._state_example)
         action = self.read_action(tensordict.get("action"))
-        reward = self.reward_spec.zero()
 
         # flatten batch size into vector
         state = _tree_flatten(state, self.batch_size)
@@ -268,7 +267,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # collect outputs
         state_dict = self.read_state(state)
         obs_dict = self.read_obs(timestep.observation)
-        reward = self.read_reward(reward, np.asarray(timestep.reward))
+        reward = self.read_reward(np.asarray(timestep.reward))
         done = timestep.step_type == self.lib.types.StepType.LAST
         done = _ndarray_to_tensor(done).view(torch.bool).to(self.device)
 
44 changes: 25 additions & 19 deletions torchrl/envs/transforms/transforms.py
@@ -151,8 +151,10 @@ def __init__(
         out_keys_inv: Optional[Sequence[NestedKey]] = None,
     ):
         super().__init__()
-        if isinstance(in_keys, str):
+        if isinstance(in_keys, (str, tuple)):
             in_keys = [in_keys]
+        if isinstance(out_keys, (str, tuple)):
+            out_keys = [out_keys]
 
         self.in_keys = in_keys
         if out_keys is None:
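This guard matters for nested keys: a single nested key such as ("next", "reward") is a tuple, and without the new check it would be iterated as two separate keys. A standalone illustration of the normalization rule (the helper name is hypothetical):

def _normalize_keys(keys):
    # A single key can be a string ("reward") or a nested tuple (("next", "reward")).
    # Wrap either one in a list; a list of keys passes through unchanged.
    if isinstance(keys, (str, tuple)):
        keys = [keys]
    return keys


assert _normalize_keys("reward") == ["reward"]
assert _normalize_keys(("next", "reward")) == [("next", "reward")]
assert _normalize_keys(["reward", ("next", "reward")]) == ["reward", ("next", "reward")]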
@@ -1132,17 +1134,17 @@ def __init__(
         self.mode = mode
 
     def reset(self, tensordict: TensorDict):
-        init_target_return = torch.full(
-            size=(*tensordict.batch_size, 1),
-            fill_value=self.target_return,
-            dtype=torch.float32,
-            device=tensordict.device,
-        )
 
         for out_key in self.out_keys:
             target_return = tensordict.get(out_key, default=None)
 
             if target_return is None:
+                init_target_return = torch.full(
+                    size=(*tensordict.batch_size, 1),
+                    fill_value=self.target_return,
+                    dtype=torch.float32,
+                    device=tensordict.device,
+                )
                 target_return = init_target_return
 
             tensordict.set(
@@ -1173,18 +1175,18 @@ def _apply_transform(
         self, reward: torch.Tensor, target_return: torch.Tensor
     ) -> torch.Tensor:
         if self.mode == "reduce":
-            if reward.ndim == 1 and target_return.ndim == 2:
-                # if target is stacked
-                target_return = target_return[-1] - reward
-            else:
-                target_return = target_return - reward
+            # if reward.ndim == 1 and target_return.ndim == 2:
+            #     # if target is stacked
+            #     target_return = target_return[-1] - reward
+            # else:
+            target_return = target_return - reward
             return target_return
         elif self.mode == "constant":
-            if reward.ndim == 1 and target_return.ndim == 2:
-                # if target is stacked
-                target_return = target_return[-1]
-            else:
-                target_return = target_return
+            # if reward.ndim == 1 and target_return.ndim == 2:
+            #     # if target is stacked
+            #     target_return = target_return[-1]
+            # else:
+            target_return = target_return
             return target_return
         else:
             raise ValueError("Unknown mode: {}".format(self.mode))
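With the batch-shape special-casing commented out, the two modes reduce to an elementwise rule: "reduce" subtracts each incoming reward from the running target, while "constant" keeps the target fixed. A small arithmetic illustration with plain tensors (outside any environment):

import torch

target = torch.tensor([[400.0]])  # initial return-to-go
rewards = [torch.tensor([[1.0]]), torch.tensor([[2.0]])]

# mode="reduce": the remaining return shrinks as reward is collected.
for r in rewards:
    target = target - r
print(target)  # tensor([[397.]])

# mode="constant": the target would instead stay at 400.0 at every step.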
@@ -2127,7 +2129,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
         for in_key, out_key in zip(self.in_keys, self.out_keys):
             # Lazy init of buffers
             buffer_name = f"_cat_buffers_{in_key}"
-            data = tensordict[in_key]
+            data = tensordict.get(in_key)
             d = data.size(self.dim)
             buffer = getattr(self, buffer_name)
             if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
@@ -2297,11 +2299,15 @@ def __init__(
         loc: Union[float, torch.Tensor],
         scale: Union[float, torch.Tensor],
         in_keys: Optional[Sequence[NestedKey]] = None,
+        out_keys: Optional[Sequence[NestedKey]] = None,
         standard_normal: bool = False,
     ):
         if in_keys is None:
             in_keys = ["reward"]
-        super().__init__(in_keys=in_keys)
+        if out_keys is None:
+            out_keys = in_keys
+
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
         if not isinstance(standard_normal, torch.Tensor):
             standard_normal = torch.tensor(standard_normal)
         self.register_buffer("standard_normal", standard_normal)
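The new out_keys argument lets RewardScaling write the scaled value under a different key instead of overwriting its input, which is what the decision-transformer utilities above rely on. A minimal usage sketch on a standalone TensorDict, assuming the transform can be applied as a regular module outside an environment; keys and values are illustrative:

import torch
from tensordict import TensorDict

from torchrl.envs.transforms import RewardScaling

scale = RewardScaling(
    loc=0,
    scale=0.001,
    in_keys=["return_to_go_single"],
    out_keys=["return_to_go"],
    standard_normal=False,
)
td = TensorDict({"return_to_go_single": torch.tensor([400.0])}, batch_size=[])
td = scale(td)
# "return_to_go_single" is preserved; "return_to_go" holds the scaled copy (0.4).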
1 change: 0 additions & 1 deletion torchrl/envs/utils.py
@@ -27,7 +27,6 @@
 from tensordict.tensordict import (
     LazyStackedTensorDict,
     NestedKey,
-    TensorDict,
     TensorDictBase,
 )
 
7 changes: 2 additions & 5 deletions torchrl/envs/vec_env.py
@@ -20,7 +20,7 @@
 import torch
 
 from tensordict import TensorDict
-from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys
+from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
 from torch import multiprocessing as mp
 from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -35,7 +35,6 @@
 from torchrl.envs.env_creator import get_env_metadata
 
 from torchrl.envs.utils import (
-    _replace_last,
     _set_single_key,
     _sort_keys,
     clear_mpi_env_vars,
@@ -340,9 +339,7 @@ def _create_td(self) -> None:
             for key in self.output_spec["full_observation_spec"].keys(True, True):
                 self._env_output_keys.append(key)
                 self._env_obs_keys.append(key)
-            self._env_output_keys += [
-                unravel_keys(("next", key)) for key in self.reward_keys + self.done_keys
-            ]
+            self._env_output_keys += self.reward_keys + self.done_keys
         else:
             env_input_keys = set()
             for meta_data in self.meta_data: