Skip to content

Commit

Permalink
[BugFix] Fix flaky gym penv test (#1853)
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 31, 2024
1 parent 2754200 commit 69453a6
Show file tree
Hide file tree
Showing 26 changed files with 78 additions and 73 deletions.
24 changes: 11 additions & 13 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ def rollout_consistency_assertion(
):
"""Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""

done = rollout[:, :-1]["next", done_key].squeeze(-1)
done = rollout[..., :-1]["next", done_key].squeeze(-1)
# data resulting from step, when it's not done
r_not_done = rollout[:, :-1]["next"][~done]
r_not_done = rollout[..., :-1]["next"][~done]
# data resulting from step, when it's not done, after step_mdp
r_not_done_tp1 = rollout[:, 1:][~done]
torch.testing.assert_close(
Expand All @@ -343,17 +343,15 @@ def rollout_consistency_assertion(

if done_strict and not done.any():
raise RuntimeError("No done detected, test could not complete.")

# data resulting from step, when it's done
r_done = rollout[:, :-1]["next"][done]
# data resulting from step, when it's done, after step_mdp and reset
r_done_tp1 = rollout[:, 1:][done]
assert (
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
).all(), (
f"Entries in next tensordict do not match entries in root "
f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}"
)
if done.any():
# data resulting from step, when it's done
r_done = rollout[..., :-1]["next"][done]
# data resulting from step, when it's done, after step_mdp and reset
r_done_tp1 = rollout[..., 1:][done]
# check that at least one obs after reset does not match the version before reset
assert not torch.isclose(
r_done[observation_key], r_done_tp1[observation_key]
).all()


def rand_reset(env):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/datasets/minari_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def _proc_spec(spec):
)
return BoundedTensorSpec(
shape=spec["shape"],
low=torch.tensor(spec["low"]),
high=torch.tensor(spec["high"]),
low=torch.as_tensor(spec["low"]),
high=torch.as_tensor(spec["high"]),
dtype=_DTYPE_DIR[spec["dtype"]],
)
elif spec["type"] == "Discrete":
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/datasets/openx.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def _slice_data(data: TensorDict, slice_len, pad_value):
truncated,
dim=data.ndim - 1,
value=True,
index=torch.tensor(-1, device=truncated.device),
index=torch.as_tensor(-1, device=truncated.device),
)
done = data.get(("next", "done"))
data.set(("next", "truncated"), truncated)
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def add(self, data: TensorDictBase) -> int:
device=data.device,
)
if data.batch_size:
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)
data_add["_rb_batch_size"] = torch.as_tensor(data.batch_size)

else:
data_add = data
Expand Down Expand Up @@ -1441,7 +1441,7 @@ def __getitem__(
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,10 @@ def dumps(self, path):
filename=path / "mintree.memmap",
)
mm_st.copy_(
torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)])
torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)])
)
mm_mt.copy_(
torch.tensor([self._min_tree[i] for i in range(self._max_capacity)])
torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)])
)
with open(path / "sampler_metadata.json", "w") as file:
json.dump(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def __getitem__(self, index):
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def _to_torch(
data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
) -> torch.Tensor:
if isinstance(data, np.generic):
return torch.tensor(data, device=device)
return torch.as_tensor(data, device=device)
elif isinstance(data, np.ndarray):
data = torch.from_numpy(data)
elif not isinstance(data, Tensor):
data = torch.tensor(data, device=device)
data = torch.as_tensor(data, device=device)

if pin_memory:
data = data.pin_memory()
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def __getstate__(self):
def dumps(self, path):
path = Path(path).absolute()
path.mkdir(exist_ok=True)
t = torch.tensor(self._current_top_values)
t = torch.as_tensor(self._current_top_values)
try:
MemoryMappedTensor.from_filename(
filename=path / "current_top_values.memmap",
Expand Down Expand Up @@ -453,7 +453,7 @@ def __getitem__(self, index):
if isinstance(index, slice) and index == slice(None):
return self
if isinstance(index, (list, range, np.ndarray)):
index = torch.tensor(index)
index = torch.as_tensor(index)
if isinstance(index, torch.Tensor):
if index.ndim > 1:
raise RuntimeError(
Expand Down
6 changes: 3 additions & 3 deletions torchrl/data/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def update(self, kl_values: Sequence[float]):
)
n_steps = len(kl_values)
# renormalize kls
kl_value = -torch.tensor(kl_values).mean() / self.coef
kl_value = -torch.as_tensor(kl_values).mean() / self.coef
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.coef *= mult # βₜ₊₁
Expand Down Expand Up @@ -314,10 +314,10 @@ def _get_done_status(self, generated, batch):
# of generated tokens
done_idx = torch.minimum(
(generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
torch.tensor(self.max_new_tokens) - 1,
torch.as_tensor(self.max_new_tokens) - 1,
)
truncated_idx = (
torch.tensor(self.max_new_tokens, device=generated.device).expand_as(
torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as(
done_idx
)
- 1
Expand Down
12 changes: 6 additions & 6 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,9 +1374,9 @@ def encode(
) -> torch.Tensor:
if not isinstance(val, torch.Tensor):
if ignore_device:
val = torch.tensor(val)
val = torch.as_tensor(val)
else:
val = torch.tensor(val, device=self.device)
val = torch.as_tensor(val, device=self.device)

if space is None:
space = self.space
Expand Down Expand Up @@ -1555,9 +1555,9 @@ def __init__(
dtype = torch.get_default_dtype()

if not isinstance(low, torch.Tensor):
low = torch.tensor(low, dtype=dtype, device=device)
low = torch.as_tensor(low, dtype=dtype, device=device)
if not isinstance(high, torch.Tensor):
high = torch.tensor(high, dtype=dtype, device=device)
high = torch.as_tensor(high, dtype=dtype, device=device)
if high.device != device:
high = high.to(device)
if low.device != device:
Expand Down Expand Up @@ -1857,8 +1857,8 @@ def __init__(
dtype, device = _default_dtype_and_device(dtype, device)
box = (
ContinuousBox(
torch.tensor(-np.inf, device=device).expand(shape),
torch.tensor(np.inf, device=device).expand(shape),
torch.as_tensor(-np.inf, device=device).expand(shape),
torch.as_tensor(np.inf, device=device).expand(shape),
)
if shape == _DEFAULT_SHAPE
else None
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/libs/dm_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def _get_envs(to_dict: bool = True) -> Dict[str, Any]:

def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor:
if isinstance(array, np.ndarray):
return torch.tensor(array.copy())
return torch.as_tensor(array.copy())
else:
return torch.tensor(array)
return torch.as_tensor(array)


class DMControlWrapper(GymLikeEnv):
Expand Down
6 changes: 3 additions & 3 deletions torchrl/envs/libs/envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _transform_step_output(
f"The output of step was had {len(out)} elements, but only 4 or 5 are supported."
)
obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
reward_and_done = {self.reward_key: torch.tensor(reward)}
reward_and_done = {self.reward_key: torch.as_tensor(reward)}
reward_and_done["done"] = done
reward_and_done["terminated"] = terminated
reward_and_done["truncated"] = truncated
Expand All @@ -290,7 +290,7 @@ def _treevalue_or_numpy_to_tensor_or_dict(
if isinstance(x, treevalue.TreeValue):
ret = self._treevalue_to_dict(x)
elif not isinstance(x, dict):
ret = {"observation": torch.tensor(x)}
ret = {"observation": torch.as_tensor(x)}
else:
ret = x
return ret
Expand All @@ -304,7 +304,7 @@ def _treevalue_to_dict(
"""
import treevalue

return {k[0]: torch.tensor(v) for k, v in treevalue.flatten(tv)}
return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)}

def _set_seed(self, seed: Optional[int]):
if seed is not None:
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,7 @@ def _read_obs(self, obs, key, tensor, index):
def __call__(self, info_dict, tensordict):
terminal_obs = info_dict.get(self.backend_key[self.backend], None)
for key, item in self.info_spec.items(True, True):
key = (key,) if isinstance(key, str) else key
final_obs_buffer = item.zero()
if terminal_obs is not None:
for i, obs in enumerate(terminal_obs):
Expand Down
6 changes: 3 additions & 3 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def _init_env(self):
"info": CompositeSpec(
{
key: UnboundedContinuousTensorSpec(
shape=torch.tensor(value).shape,
shape=torch.as_tensor(value).shape,
device=self.device,
)
for key, value in info_dict[agent].items()
Expand Down Expand Up @@ -501,7 +501,7 @@ def _init_env(self):
device=self.device,
)
except AttributeError:
state_example = torch.tensor(self.state(), device=self.device)
state_example = torch.as_tensor(self.state(), device=self.device)
state_spec = UnboundedContinuousTensorSpec(
shape=state_example.shape,
dtype=state_example.dtype,
Expand Down Expand Up @@ -560,7 +560,7 @@ def _reset(
if group_info is not None:
agent_info_dict = info_dict[agent]
for agent_info, value in agent_info_dict.items():
group_info.get(agent_info)[index] = torch.tensor(
group_info.get(agent_info)[index] = torch.as_tensor(
value, device=self.device
)

Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/gym_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _get_lives(self):
if callable(lives):
lives = lives()
elif isinstance(lives, list) and all(callable(_lives) for _lives in lives):
lives = torch.tensor([_lives() for _lives in lives])
lives = torch.as_tensor([_lives() for _lives in lives])
return lives

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand Down Expand Up @@ -170,7 +170,7 @@ def _reset(self, tensordict, tensordict_reset):
end_of_life = False
tensordict_reset.set(
self.eol_key,
torch.tensor(end_of_life).expand(
torch.as_tensor(end_of_life).expand(
parent.full_done_spec[self.done_key].shape
),
)
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/r3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def _init(self):
std = [0.229, 0.224, 0.225]
normalize = ObservationNorm(
in_keys=in_keys,
loc=torch.tensor(mean).view(3, 1, 1),
scale=torch.tensor(std).view(3, 1, 1),
loc=torch.as_tensor(mean).view(3, 1, 1),
scale=torch.as_tensor(std).view(3, 1, 1),
standard_normal=True,
)
transforms.append(normalize)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def find_sample_log_prob(module):
self.functional_actor.apply(find_sample_log_prob)

if not isinstance(coef, torch.Tensor):
coef = torch.tensor(coef)
coef = torch.as_tensor(coef)
self.register_buffer("coef", coef)

def _reset(
Expand Down
8 changes: 4 additions & 4 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def check_val(val):
if val is None:
return None, None, torch.finfo(torch.get_default_dtype()).max
if not isinstance(val, torch.Tensor):
val = torch.tensor(val)
val = torch.as_tensor(val)
if not val.dtype.is_floating_point:
val = val.float()
eps = torch.finfo(val.dtype).resolution
Expand Down Expand Up @@ -1626,10 +1626,10 @@ def __init__(
out_keys = copy(in_keys)
super().__init__(in_keys=in_keys, out_keys=out_keys)
clamp_min_tensor = (
clamp_min if isinstance(clamp_min, Tensor) else torch.tensor(clamp_min)
clamp_min if isinstance(clamp_min, Tensor) else torch.as_tensor(clamp_min)
)
clamp_max_tensor = (
clamp_max if isinstance(clamp_max, Tensor) else torch.tensor(clamp_max)
clamp_max if isinstance(clamp_max, Tensor) else torch.as_tensor(clamp_max)
)
self.register_buffer("clamp_min", clamp_min_tensor)
self.register_buffer("clamp_max", clamp_max_tensor)
Expand Down Expand Up @@ -2396,7 +2396,7 @@ def __init__(
out_keys_inv=out_keys_inv,
)
if not isinstance(standard_normal, torch.Tensor):
standard_normal = torch.tensor(standard_normal)
standard_normal = torch.as_tensor(standard_normal)
self.register_buffer("standard_normal", standard_normal)
self.eps = 1e-6

Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/vc1.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _map_tv_to_torchrl(
elif isinstance(model_transforms, transforms.Normalize):
return ObservationNorm(
in_keys=in_keys,
loc=torch.tensor(model_transforms.mean).reshape(3, 1, 1),
scale=torch.tensor(model_transforms.std).reshape(3, 1, 1),
loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
standard_normal=True,
)
elif isinstance(model_transforms, transforms.ToTensor):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def _init(self):
std = [0.229, 0.224, 0.225]
normalize = ObservationNorm(
in_keys=in_keys,
loc=torch.tensor(mean).view(3, 1, 1),
scale=torch.tensor(std).view(3, 1, 1),
loc=torch.as_tensor(mean).view(3, 1, 1),
scale=torch.as_tensor(std).view(3, 1, 1),
standard_normal=True,
)
transforms.append(normalize)
Expand Down
4 changes: 2 additions & 2 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ def __init__(
if isinstance(max, torch.Tensor):
max = max.to(self.device)
else:
max = torch.tensor(max, device=self.device)
max = torch.as_tensor(max, device=self.device)
if isinstance(min, torch.Tensor):
min = min.to(self.device)
else:
min = torch.tensor(min, device=self.device)
min = torch.as_tensor(min, device=self.device)
self.min = min
self.max = max
self.update(loc, scale)
Expand Down
4 changes: 3 additions & 1 deletion torchrl/modules/models/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def __init__(
)

if sigma_init != 0.0:
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
self.register_buffer(
"sigma_init", torch.as_tensor(sigma_init, device=device)
)

@property
def sigma(self):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/planners/mppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
self.num_candidates = num_candidates
self.top_k = top_k
self.reward_key = reward_key
self.register_buffer("temperature", torch.tensor(temperature))
self.register_buffer("temperature", torch.as_tensor(temperature))

def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
batch_size = tensordict.batch_size
Expand Down
Loading

0 comments on commit 69453a6

Please sign in to comment.