Skip to content

Commit

Permalink
[BugFix] Fix imports
Browse files Browse the repository at this point in the history
ghstack-source-id: db85f2611c1c0b22e9179b4fdd6c2dcea78ac8dd
Pull Request resolved: #2605
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent a1e21f5 commit d90b9e3
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
5 changes: 4 additions & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,4 +1152,7 @@ def loss_and_bw(td):

if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown)
pytest.main(
[__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+ unknown
)
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,18 @@ def _main(argv):
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy<2.0.0"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"tests": [
"pytest",
"pyyaml",
"pytest-instafail",
"scipy",
"pytest-mock",
"pytest-cov",
"pytest-benchmark",
"pytest-rerunfailures",
"pytest-error-for-skips",
"",
],
"utils": [
"tensorboard",
"wandb",
Expand Down
16 changes: 6 additions & 10 deletions test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,10 @@ def test_tensordict_tokenizer(
"Lettuce in, it's cold out here!",
]
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down Expand Up @@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer(
],
"label": ["right", "wrong", "right", "wrong", "right"],
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down
19 changes: 10 additions & 9 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from __future__ import annotations

import abc
import functools
import warnings
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import numpy as np
Expand All @@ -33,6 +33,7 @@
_StepMDP,
_terminated_or_truncated,
_update_during_reset,
check_env_specs as check_env_specs_func,
get_available_libraries,
)

Expand Down Expand Up @@ -2035,7 +2036,7 @@ def _register_gym(

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2084,7 +2085,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2138,7 +2139,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2195,7 +2196,7 @@ def _register_gym( # noqa: F811

if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2254,7 +2255,7 @@ def _register_gym( # noqa: F811
)
if entry_point is None:
entry_point = cls
entry_point = functools.partial(
entry_point = partial(
_TorchRLGymWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -2293,7 +2294,7 @@ def _register_gym( # noqa: F811
if entry_point is None:
entry_point = cls

entry_point = functools.partial(
entry_point = partial(
_TorchRLGymnasiumWrapper,
entry_point=entry_point,
info_keys=info_keys,
Expand Down Expand Up @@ -3422,11 +3423,11 @@ def _get_sync_func(policy_device, env_device):
if policy_device is not None and policy_device.type == "cuda":
if env_device is None or env_device.type == "cuda":
return torch.cuda.synchronize
return functools.partial(torch.cuda.synchronize, device=policy_device)
return partial(torch.cuda.synchronize, device=policy_device)
if env_device is not None and env_device.type == "cuda":
if policy_device is None:
return torch.cuda.synchronize
return functools.partial(torch.cuda.synchronize, device=env_device)
return partial(torch.cuda.synchronize, device=env_device)
return torch.cuda.synchronize
if torch.backends.mps.is_available():
return torch.mps.synchronize
Expand Down

0 comments on commit d90b9e3

Please sign in to comment.