From 0f29c7e933662da5989ba7f9d9fd9ba9729a24f2 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 24 Oct 2024 10:34:43 -0700 Subject: [PATCH] [Feature] Avoid some recompiles of `ReplayBuffer.extend/sample` This change avoids recompiles for back-to-back calls to `ReplayBuffer.extend` and `.sample` in cases where `LazyTensorStorage`, `RoundRobinWriter`, and `RandomSampler` are used and the data type is either tensor or pytree. ghstack-source-id: d306cb9f47bdbfb81988589f4f4d923c8427eaa0 Pull Request resolved: https://github.com/pytorch/rl/pull/2504 --- test/_utils_internal.py | 27 +++++++++ test/test_rb.py | 77 ++++++++++++++++++++++++- test/test_utils.py | 33 ++++++++++- torchrl/data/replay_buffers/storages.py | 19 +++++- 4 files changed, 153 insertions(+), 3 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 51535afa606..48492459315 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -5,10 +5,12 @@ from __future__ import annotations import contextlib +import logging import os import os.path import time +import unittest from functools import wraps # Get relative file path @@ -204,6 +206,31 @@ def f_retry(*args, **kwargs): return deco_retry +# After calling this function, any log record whose name contains 'record_name' +# and is emitted from the logger that has qualified name 'logger_qname' is +# appended to the 'records' list. +# NOTE: This function is based on testing utilities for 'torch._logging' +def capture_log_records(records, logger_qname, record_name): + assert isinstance(records, list) + logger = logging.getLogger(logger_qname) + + class EmitWrapper: + def __init__(self, old_emit): + self.old_emit = old_emit + + def __call__(self, record): + nonlocal records + self.old_emit(record) + if record_name in record.name: + records.append(record) + + for handler in logger.handlers: + new_emit = EmitWrapper(handler.emit) + contextlib.ExitStack().enter_context( + unittest.mock.patch.object(handler, "emit", new_emit) + ) + + @pytest.fixture def dtype_fixture(): dtype = torch.get_default_dtype() diff --git a/test/test_rb.py b/test/test_rb.py index 0e10f534728..490dd1f56df 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -17,7 +17,12 @@ import pytest import torch -from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc +from _utils_internal import ( + capture_log_records, + CARTPOLE_VERSIONED, + get_default_devices, + make_tc, +) from mocking_classes import CountingEnv from packaging import version @@ -111,6 +116,7 @@ _has_gym = importlib.util.find_spec("gym") is not None _has_snapshot = importlib.util.find_spec("torchsnapshot") is not None _os_is_windows = sys.platform == "win32" +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) torch_2_3 = version.parse( ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) @@ -399,6 +405,75 @@ def data_iter(): ) if cond else contextlib.nullcontext(): rb.extend(data2) + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + # Compiling on Windows requires "cl" compiler to be installed. + # + # Our Windows CI jobs do not have "cl", so skip this test. + @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") + def test_extend_sample_recompile( + self, rb_type, sampler, writer, storage, size, datatype + ): + if rb_type is not ReplayBuffer: + pytest.skip( + "Only replay buffer of type 'ReplayBuffer' is currently supported." + ) + if sampler is not RandomSampler: + pytest.skip("Only sampler of type 'RandomSampler' is currently supported.") + if storage is not LazyTensorStorage: + pytest.skip( + "Only storage of type 'LazyTensorStorage' is currently supported." + ) + if writer is not RoundRobinWriter: + pytest.skip( + "Only writer of type 'RoundRobinWriter' is currently supported." + ) + if datatype == "tensordict": + pytest.skip("'tensordict' datatype is not currently supported.") + + torch._dynamo.reset_code_caches() + + storage_size = 10 * size + rb = self._get_rb( + rb_type=rb_type, + sampler=sampler, + writer=writer, + storage=storage, + size=storage_size, + ) + data_size = size + data = self._get_data(datatype, size=data_size) + + @torch.compile + def extend_and_sample(data): + rb.extend(data) + return rb.sample() + + # Number of times to extend the replay buffer + num_extend = 30 + + # NOTE: The first two calls to 'extend' and 'sample' currently cause + # recompilations, so avoid capturing those for now. + num_extend_before_capture = 2 + + for _ in range(num_extend_before_capture): + extend_and_sample(data) + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + + for _ in range(num_extend - num_extend_before_capture): + extend_and_sample(data) + + assert len(rb) == storage_size + assert len(records) == 0 + + finally: + torch._logging.set_logs() + def test_sample(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( diff --git a/test/test_utils.py b/test/test_utils.py index 4224a36b54f..6537c19ff54 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,11 +14,14 @@ import torch -from _utils_internal import get_default_devices +from _utils_internal import capture_log_records, get_default_devices +from packaging import version from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) + @pytest.mark.parametrize("value", ["True", "1", "true"]) def test_get_binary_env_var_positive(value): @@ -380,6 +383,34 @@ def test_rng_decorator(device): torch.testing.assert_close(s0b, s1b) +# Check that 'capture_log_records' captures records emitted when torch +# recompiles a function. +@pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" +) +def test_capture_log_records_recompile(): + torch.compiler.reset() + + # This function recompiles each time it is called with a different string + # input. + @torch.compile + def str_to_tensor(s): + return bytes(s, "utf8") + + str_to_tensor("a") + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + str_to_tensor("b") + + finally: + torch._logging.set_logs() + + assert len(records) == 1 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 4d47fd5265d..beab68971b5 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -144,12 +144,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): ... + # NOTE: This property is used to enable compiled Storages. Calling + # `len(self)` on a TensorStorage should normally cause a graph break since + # it uses a `mp.Value`, and it does cause a break when the `len(self)` call + # happens within a method of TensorStorage itself. However, when the + # `len(self)` call happens in the Storage base class, for an unknown reason + # the compiler doesn't seem to recognize that there should be a graph break, + # and the lack of a break causes a recompile each time `len(self)` is called + # in this context. Also for an unknown reason, we can force the graph break + # to happen if we wrap the `len(self)` call with a `property`-decorated + # function. For another unknown reason, if we change + # `TensorStorage._len_value` from `mp.Value` to int, it seems like there + # should no longer be any need to recompile, but recompiles happen anyway. + # Ideally, this should all be investigated and understood in the future. + @property + def len(self): + return len(self) + def _rand_given_ndim(self, batch_size): # a method to return random indices given the storage ndim if self.ndim == 1: return torch.randint( 0, - len(self), + self.len, (batch_size,), generator=self._rng, device=getattr(self, "device", None),