Commit 56b0b9a
[BugFix] Fix old deps tests
ghstack-source-id: 134d129514f39538fd9c9f94ce07ad939831c8aa
Pull Request resolved: #2500
vmoens committed Oct 21, 2024
1 parent 9f6c21f commit 56b0b9a
Showing 8 changed files with 58 additions and 17 deletions.
test/test_collector.py (12 changes: 10 additions & 2 deletions)

```diff
@@ -43,6 +43,7 @@
     MultiKeyCountingEnvPolicy,
     NestedCountingEnv,
 )
+from packaging import version
 from tensordict import (
     assert_allclose_td,
     LazyStackedTensorDict,
@@ -106,6 +107,7 @@
 IS_OSX = sys.platform == "darwin"
 PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


 class WrappablePolicy(nn.Module):
```
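For readers unfamiliar with the double parse above: `packaging` treats pre-release tags, as found on PyTorch nightlies, as older than the final release, so comparing the raw `torch.__version__` against `2.5.0` would skip these tests on a 2.5 nightly. Re-parsing `base_version` strips the pre-release and local suffixes first. A minimal sketch, using a hypothetical nightly version string:

```python
from packaging import version

raw = "2.5.0a0+gitabcdef"  # hypothetical nightly build string
v = version.parse(raw)

print(v < version.parse("2.5.0"))  # True: pre-releases sort before the final release
print(v.base_version)              # "2.5.0": release segment only, suffixes stripped
print(version.parse(v.base_version) < version.parse("2.5.0"))  # False
```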
```diff
@@ -2654,6 +2656,9 @@ def test_dynamic_multiasync_collector(self):
         assert data.names[-1] == "time"


+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+)
 class TestCompile:
     @pytest.mark.parametrize(
         "collector_cls",
@@ -2996,8 +3001,9 @@ def __deepcopy_error__(*args, **kwargs):
     raise RuntimeError("deepcopy not allowed")


-@pytest.mark.filterwarnings("error")
-@pytest.mark.filterwarnings("ignore:Tensordict is registered in PyTree")
+@pytest.mark.filterwarnings(
+    "error::UserWarning", "ignore:Tensordict is registered in PyTree:UserWarning"
+)
 @pytest.mark.parametrize(
     "collector_type",
     [
@@ -3016,6 +3022,8 @@ def test_no_deepcopy_policy(collector_type):
     # If the policy is not a nn.Module or has no parameter, policy_device should warn (we don't know what to do but we
     # can trust that the user knows what to do).

+    # warnings.warn("Tensordict is registered in PyTree", category=UserWarning)
+
     shared_device = torch.device("cpu")
     if torch.cuda.is_available():
         original_device = torch.device("cuda:0")
```
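The rewritten mark uses pytest's `action:message:category` filter grammar: the first filter escalates any `UserWarning` to an error, and the second exempts the known "Tensordict is registered in PyTree" warning, instead of turning every warning whatsoever into an error as the old bare `"error"` filter did. The same grammar drives Python's own warning filters; a small standalone sketch:

```python
import warnings

# Escalate UserWarning to an error, then carve out one known exception.
# Filters added later are matched first, so the "ignore" takes precedence.
warnings.filterwarnings("error", category=UserWarning)
warnings.filterwarnings(
    "ignore", message="Tensordict is registered in PyTree", category=UserWarning
)

warnings.warn("Tensordict is registered in PyTree", UserWarning)  # silently ignored
# Any other UserWarning would now raise instead of printing.
```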
test/test_cost.py (7 changes: 5 additions & 2 deletions)

```diff
@@ -46,6 +46,7 @@
     get_default_devices,
 )
 from mocking_classes import ContinuousActionConvMockEnv
+from packaging import version

 # from torchrl.data.postprocs.utils import expand_as_right
 from tensordict import assert_allclose_td, TensorDict, TensorDictBase
@@ -146,7 +147,7 @@
     _split_and_pad_sequence,
 )

-TORCH_VERSION = torch.__version__
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

 # Capture all warnings
 pytestmark = [
@@ -15731,7 +15732,9 @@ def __init__(self):
         assert p.device == dest


-@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
 def test_exploration_compile():
     m = ProbabilisticTensorDictModule(
         in_keys=["loc", "scale"],
```
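The old `TORCH_VERSION < "2.5"` check compared plain strings, which orders versions lexicographically and misbehaves once a version component reaches double digits; parsing with `packaging` restores the correct ordering. A quick illustration:

```python
from packaging import version

print("2.10.0" < "2.5.0")  # True: lexicographic comparison, wrong for versions
print(version.parse("2.10.0") < version.parse("2.5.0"))  # False: correct ordering
```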
test/test_distributions.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -13,6 +13,7 @@
 from _utils_internal import get_default_devices
 from tensordict import TensorDictBase
 from torch import autograd, nn
+from torch.utils._pytree import tree_map
 from torchrl.modules import (
     NormalParamWrapper,
     OneHotCategorical,
@@ -182,7 +183,7 @@ class TestTruncatedNormal:
     @pytest.mark.parametrize("device", get_default_devices())
     def test_truncnormal(self, min, max, vecs, upscale, shape, device):
         torch.manual_seed(0)
-        *vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map(
+        *vecs, min, max, vecs, upscale = tree_map(
             lambda t: torch.as_tensor(t, device=device),
             (*vecs, min, max, vecs, upscale),
         )
```
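`tree_map` applies a function to every leaf of a nested Python structure and rebuilds the same structure, which is why the single call above can move a whole tuple of heterogeneous inputs to a device at once. A minimal sketch:

```python
import torch
from torch.utils._pytree import tree_map

nested = ([0.0, 1.0], {"scale": 2.0})
tensors = tree_map(lambda t: torch.as_tensor(t), nested)
# Same nesting, every leaf now a tensor:
# ([tensor(0.), tensor(1.)], {"scale": tensor(2.)})
```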
test/test_exploration.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -757,7 +757,7 @@ def test_consistent_dropout(self, dropout_p, parallel_spec, device):

         # NOTE: Please only put a module with one dropout layer.
         # That's how this test is constructed anyways.
-        @torch.no_grad
+        @torch.no_grad()
         def inner_verify_routine(module, env):
             # Perform transitions.
             collector = SyncDataCollector(
```
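Using the bare class as a decorator (`@torch.no_grad`) is only supported on newer PyTorch releases; instantiating it (`@torch.no_grad()`) also works on the older versions this commit targets. A sketch of the portable form:

```python
import torch

@torch.no_grad()  # portable; the bare @torch.no_grad spelling needs a recent torch
def double(x):
    return x * 2

y = double(torch.ones(2, requires_grad=True))
print(y.requires_grad)  # False: gradients were disabled inside the call
```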
test/test_helpers.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -50,7 +50,7 @@
     make_dqn_actor,
 )

-TORCH_VERSION = version.parse(torch.__version__)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 if TORCH_VERSION < version.parse("1.12.0"):
     UNSQUEEZE_SINGLETON = True
 else:
```
torchrl/collectors/collectors.py (14 changes: 8 additions & 6 deletions)

```diff
@@ -35,7 +35,9 @@
 )
 from tensordict.base import NO_DEFAULT
 from tensordict.nn import CudaGraphModule, TensorDictModule
+from tensordict.utils import Buffer
 from torch import multiprocessing as mp
+from torch.nn import Parameter
 from torch.utils.data import IterableDataset

 from torchrl._utils import (
@@ -202,17 +204,17 @@ def map_weight(
             policy_device=policy_device,
         ):

-            is_param = isinstance(weight, nn.Parameter)
-            is_buffer = isinstance(weight, nn.Buffer)
+            is_param = isinstance(weight, Parameter)
+            is_buffer = isinstance(weight, Buffer)
             weight = weight.data
             if weight.device != policy_device:
                 weight = weight.to(policy_device)
             elif weight.device.type in ("cpu", "mps"):
                 weight = weight.share_memory_()
             if is_param:
-                weight = nn.Parameter(weight, requires_grad=False)
+                weight = Parameter(weight, requires_grad=False)
             elif is_buffer:
-                weight = nn.Buffer(weight)
+                weight = Buffer(weight)
             return weight

         # Create a stateless policy, then populate this copy with params on device
@@ -3089,12 +3091,12 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):


 def _make_meta_params(param):
-    is_param = isinstance(param, nn.Parameter)
+    is_param = isinstance(param, Parameter)

     pd = param.detach().to("meta")

     if is_param:
-        pd = nn.Parameter(pd, requires_grad=False)
+        pd = Parameter(pd, requires_grad=False)
     return pd
```
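The switch from `nn.Buffer` to `tensordict.utils.Buffer` is the substance of the old-deps fix here: `nn.Buffer` only exists on recent PyTorch, so referencing it unconditionally raises `AttributeError` on older installs. A guarded alias of the general shape below keeps the `isinstance` checks working everywhere (a hypothetical sketch of the pattern, not tensordict's actual implementation):

```python
import torch
from torch import nn

if hasattr(nn, "Buffer"):  # recent torch exposes buffers as a real class
    Buffer = nn.Buffer
else:
    class Buffer(torch.Tensor):  # sentinel on old torch: isinstance() is simply False
        pass
```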
torchrl/data/replay_buffers/replay_buffers.py (5 changes: 2 additions & 3 deletions)

```diff
@@ -31,6 +31,7 @@
 from tensordict.nn.utils import _set_dispatch_td_nn_modules
 from tensordict.utils import expand_as_right, expand_right
 from torch import Tensor
+from torch.utils._pytree import tree_map

 from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
 from torchrl.data.replay_buffers.samplers import (
@@ -319,9 +320,7 @@ def dim_extend(self, value):
     def _transpose(self, data):
         if is_tensor_collection(data):
             return data.transpose(self.dim_extend, 0)
-        return torch.utils._pytree.tree_map(
-            lambda x: x.transpose(self.dim_extend, 0), data
-        )
+        return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)

     def _get_collate_fn(self, collate_fn):
         self._collate_fn = (
```
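The rewritten `_transpose` is behavior-preserving; with `tree_map` imported at module level, the call collapses onto one line. For plain pytrees (non-tensordict data) it transposes every leaf, e.g.:

```python
import torch
from torch.utils._pytree import tree_map

data = {"obs": torch.zeros(2, 5), "reward": torch.zeros(2, 1)}
out = tree_map(lambda x: x.transpose(1, 0), data)  # swap the first two dims per leaf
print(out["obs"].shape, out["reward"].shape)  # torch.Size([5, 2]) torch.Size([1, 2])
```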
torchrl/data/replay_buffers/storages.py (30 changes: 29 additions & 1 deletion)

```diff
@@ -1367,16 +1367,44 @@ def _collate_list_tensordict(x):
     return out


+@implement_for("torch", "2.4")
 def _stack_anything(data):
     if is_tensor_collection(data[0]):
         return LazyStackedTensorDict.maybe_dense_stack(data)
-    return torch.utils._pytree.tree_map(
+    return tree_map(
         lambda *x: torch.stack(x),
         *data,
         is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x),
     )


+@implement_for("torch", None, "2.4")
+def _stack_anything(data):  # noqa: F811
+    from tensordict import _pytree
+
+    if not _pytree.PYTREE_REGISTERED_TDS:
+        raise RuntimeError(
+            "TensorDict is not registered within PyTree. "
+            "If you see this error, it means tensordicts instances cannot be natively stacked using tree_map. "
+            "To solve this issue, (a) upgrade pytorch to a version > 2.4, or (b) make sure TensorDict is registered in PyTree. "
+            "If this error persists, open an issue on https://github.com/pytorch/rl/issues"
+        )
+    if is_tensor_collection(data[0]):
+        return LazyStackedTensorDict.maybe_dense_stack(data)
+    flat_trees = []
+    spec = None
+    for d in data:
+        flat_tree, spec = tree_flatten(d)
+        flat_trees.append(flat_tree)
+
+    leaves = []
+    for leaf in zip(*flat_trees):
+        leaf = torch.stack(leaf)
+        leaves.append(leaf)
+
+    return tree_unflatten(leaves, spec)
+
+
 def _collate_id(x):
     return x
```
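For context, `torchrl._utils.implement_for` registers several implementations of the same function and selects one based on the installed version of a dependency, which is how the two `_stack_anything` definitions above coexist. A minimal usage sketch (the function and return values are hypothetical):

```python
from torchrl._utils import implement_for

@implement_for("torch", "2.4")        # selected when torch >= 2.4
def stack_path():
    return "tree_map with is_leaf"

@implement_for("torch", None, "2.4")  # selected when torch < 2.4
def stack_path():  # noqa: F811
    return "manual flatten/stack/unflatten"
```

The pre-2.4 fallback itself is a standard pytree pattern: flatten each sample into leaves, stack the leaves position-wise, then rebuild the original structure. A condensed sketch of the same idea:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

data = [{"a": torch.zeros(3), "b": torch.ones(2)} for _ in range(4)]
flat_trees, specs = zip(*(tree_flatten(d) for d in data))    # leaves + structure specs
leaves = [torch.stack(group) for group in zip(*flat_trees)]  # stack leaf-wise
stacked = tree_unflatten(leaves, specs[0])                   # rebuild the dict
print(stacked["a"].shape, stacked["b"].shape)  # torch.Size([4, 3]) torch.Size([4, 2])
```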

1 comment on commit 56b0b9a

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous one by more than the alert threshold of 2.

| Benchmark suite | Current: 56b0b9a | Previous: 9f6c21f | Ratio |
| --- | --- | --- | --- |
| benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 36.51975070179481 iter/sec (stddev: 0.1643548163791813) | 222.7289513749851 iter/sec (stddev: 0.0009971111679747764) | 6.10 |

(The ratio is previous over current throughput: 222.73 / 36.52 ≈ 6.10, i.e. this benchmark ran about six times slower on this commit.)

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
