From 2332909acb7c56393f6b080020b753d7f18aa9b7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 18:05:12 -0700 Subject: [PATCH] [Feature] Make benchmarked losses compatible with torch.compile ghstack-source-id: 825ded593dcffcecf626a705d8b7c0c5e0839719 Pull Request resolved: https://github.com/pytorch/rl/pull/2405 --- .github/unittest/linux/scripts/run_all.sh | 6 +- .../linux_examples/scripts/run_all.sh | 8 +- .../linux_libs/scripts_brax/install.sh | 2 +- .../linux_libs/scripts_openx/install.sh | 4 +- .../linux_libs/scripts_rlhf/install.sh | 8 +- .../linux_libs/scripts_vd4rl/install.sh | 2 +- .../linux_olddeps/scripts_gym_0_13/install.sh | 2 +- .../unittest/linux_optdeps/scripts/install.sh | 2 +- .github/workflows/benchmarks.yml | 2 + .github/workflows/benchmarks_pr.yml | 2 + benchmarks/test_objectives_benchmarks.py | 423 ++++++++++++++++-- test/test_cost.py | 65 ++- torchrl/__init__.py | 44 ++ torchrl/data/tensor_specs.py | 112 ++--- torchrl/envs/transforms/transforms.py | 8 +- torchrl/modules/distributions/continuous.py | 108 ++++- torchrl/modules/distributions/utils.py | 194 +++++--- torchrl/objectives/a2c.py | 30 +- torchrl/objectives/common.py | 33 +- torchrl/objectives/cql.py | 40 +- torchrl/objectives/iql.py | 15 +- torchrl/objectives/ppo.py | 36 +- torchrl/objectives/redq.py | 45 +- torchrl/objectives/td3.py | 4 +- torchrl/objectives/td3_bc.py | 4 +- torchrl/objectives/utils.py | 18 +- 26 files changed, 940 insertions(+), 277 deletions(-) diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 3257adf8c63..07de5e33099 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -127,13 +127,13 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh index 37719e51074..1a713ce6870 100755 --- a/.github/unittest/linux_examples/scripts/run_all.sh +++ b/.github/unittest/linux_examples/scripts/run_all.sh @@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION + pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_brax/install.sh b/.github/unittest/linux_libs/scripts_brax/install.sh index 80efdc536ab..20a2643dac8 100755 --- a/.github/unittest/linux_libs/scripts_brax/install.sh +++ b/.github/unittest/linux_libs/scripts_brax/install.sh @@ -34,7 +34,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U else pip3 install torch --index-url https://download.pytorch.org/whl/cu121 fi diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh index 1be73fc1de0..c657fd48b46 100755 --- a/.github/unittest/linux_libs/scripts_openx/install.sh +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -37,9 +37,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -U fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh index d0363186c1a..9a5cf82074b 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/install.sh +++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh @@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu121" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cu121 fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index 1be73fc1de0..256f8d065f6 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -37,7 +37,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U else pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 fi diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh index 7b7c857c37a..c1dde8bb7d0 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y else - conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y + conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y fi # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has diff --git a/.github/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh index 8ccbfbb8e19..be9fd8df5aa 100755 --- a/.github/unittest/linux_optdeps/scripts/install.sh +++ b/.github/unittest/linux_optdeps/scripts/install.sh @@ -20,7 +20,7 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U # install tensordict if [[ "$RELEASE" == 0 ]]; then diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 5591ab5787b..7d8b714ad4d 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -40,6 +40,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ + export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest --benchmark-json output.json - name: Store benchmark results if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} @@ -107,6 +108,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ + export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python3 -m pytest --benchmark-json output.json - name: Store benchmark results uses: benchmark-action/github-action-benchmark@v1 diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 5aeb09406ea..fa1b8037ecb 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -46,6 +46,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ + export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 RUN_BENCHMARK="pytest --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} @@ -125,6 +126,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ + export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 RUN_BENCHMARK="pytest --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d2f0d11643a..d07b40595bc 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -6,9 +6,11 @@ import pytest import torch +from packaging import version from tensordict import TensorDict from tensordict.nn import ( + InteractionType, NormalParamExtractor, ProbabilisticTensorDictModule as ProbMod, ProbabilisticTensorDictSequential as ProbSeq, @@ -42,6 +44,20 @@ vec_td_lambda_return_estimate, ) +TORCH_VERSION = torch.__version__ +FULLGRAPH = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse( + "2.5.0" +) # Anything from 2.5, incl. nightlies, allows for fullgraph + + +@pytest.fixture(scope="module") +def set_default_device(): + cur_device = torch.get_default_device() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + torch.set_default_device(device) + yield + torch.set_default_device(cur_device) + class setup_value_fn: def __init__(self, has_lmbda, has_state_value): @@ -137,7 +153,26 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps): ) -def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128): +def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3): + if compile: + if isinstance(compile, str): + fn = torch.compile(fn, mode=compile, fullgraph=fullgraph) + else: + fn = torch.compile(fn, fullgraph=fullgraph) + + for _ in range(warmup): + fn(td) + + return fn + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_dqn_speed( + benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128 +): + if compile: + torch._dynamo.reset_code_caches() net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells) action_space = "one-hot" mod = QValueActor(net, in_keys=["obs"], action_space=action_space) @@ -155,10 +190,36 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128): [batch], ) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() -def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_ddpg_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -200,10 +261,36 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden loss = DDPGLoss(actor, value) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() -def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_sac_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -245,21 +332,48 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.clone())) loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: -def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_redq_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -302,23 +416,50 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + if backward: + def loss_and_bw(td): + losses = loss(td) + totalloss = sum( + [val for key, val in losses.items() if key.startswith("loss")] + ) + totalloss.backward() + + loss_and_bw(td) + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_redq_deprec_speed( - benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 ): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -361,21 +502,48 @@ def test_redq_deprec_speed( out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) -def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_td3_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -417,14 +585,16 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, return_log_prob=True, + default_interaction_type=InteractionType.DETERMINISTIC, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.clone())) loss = TD3Loss( actor, @@ -433,10 +603,36 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= ) loss(td) - benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10) -def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_cql_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -475,24 +671,59 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_a2c_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -533,7 +764,10 @@ def test_a2c_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -544,12 +778,44 @@ def test_a2c_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_ppo_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -590,7 +856,10 @@ def test_ppo_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -601,12 +870,44 @@ def test_ppo_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_reinforce_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -647,7 +948,10 @@ def test_reinforce_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -658,12 +962,44 @@ def test_reinforce_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_iql_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -710,7 +1046,10 @@ def test_iql_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -723,7 +1062,27 @@ def test_iql_speed( loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue) loss(td) - benchmark(loss, td) + + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) if __name__ == "__main__": diff --git a/test/test_cost.py b/test/test_cost.py index b11cec924e3..1c00d4d965f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -146,6 +146,7 @@ _split_and_pad_sequence, ) +TORCH_VERSION = torch.__version__ # Capture all warnings pytestmark = [ @@ -15282,7 +15283,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: class MyLoss3(MyLoss2): @dataclass class _AcceptedKeys: - some_key = "some_value" + some_key: str = "some_value" loss_module = MyLoss3() assert loss_module.tensor_keys.some_key == "some_value" @@ -15644,6 +15645,68 @@ def __init__(self): assert p.device == dest +@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5") +def test_exploration_compile(): + m = ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + out_keys=["sample"], + distribution_class=torch.distributions.Normal, + ) + + # class set_exploration_type_random(set_exploration_type): + # __init__ = object.__init__ + # type = ExplorationType.RANDOM + it = exploration_type() + + @torch.compile(fullgraph=True) + def func(t): + with set_exploration_type(ExplorationType.RANDOM): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] != t1["sample"]).any() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + def func(t): + with set_exploration_type(ExplorationType.MEAN): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] == t1["sample"]).all() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + @set_exploration_type(ExplorationType.RANDOM) + def func(t): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] != t1["sample"]).any() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + @set_exploration_type(ExplorationType.MEAN) + def func(t): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] == t1["sample"]).all() + assert it == exploration_type() + + def test_loss_exploration(): class DummyLoss(LossModule): def forward(self, td, mode): diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 25103423cac..cbd7b66a65e 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import os +import weakref from warnings import warn import torch @@ -10,6 +11,7 @@ from tensordict import set_lazy_legacy from torch import multiprocessing as mp +from torch.distributions.transforms import _InverseTransform, ComposeTransform set_lazy_legacy(False).set() @@ -51,3 +53,45 @@ filter_warnings_subprocess = True _THREAD_POOL_INIT = torch.get_num_threads() + + +# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home +@property +def _inv(self): + """Patched version of Transform.inv. + + Returns the inverse :class:`Transform` of this transform. + + This should satisfy ``t.inv.inv is t``. + """ + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not torch.compiler.is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv + + +torch.distributions.transforms.Transform.inv = _inv + + +@property +def _inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([p.inv for p in reversed(self.parts)]) + if not torch.compiler.is_dynamo_compiling(): + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + else: + # We need inv.inv to be equal to self, but weakref can cause a graph break + inv._inv = lambda out=self: out + + return inv + + +ComposeTransform.inv = _inv diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 60c1009990e..98a32de5715 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -83,6 +83,12 @@ ) +def _size(list_of_ints): + # ensures that np int64 elements don't slip through Size + # see https://github.com/pytorch/pytorch/issues/127194 + return torch.Size([int(i) for i in list_of_ints]) + + # Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default class _NoDefault(enum.IntEnum): ZERO = 0 @@ -640,7 +646,7 @@ def __ne__(self, other): def __setattr__(self, key, value): if key == "shape": - value = torch.Size(value) + value = _size(value) super().__setattr__(key, value) def to_numpy( @@ -686,7 +692,7 @@ def ndimension(self) -> int: @property def _safe_shape(self) -> torch.Size: """Returns a shape where all heterogeneous values are replaced by one (to be expandable).""" - return torch.Size([int(v) if v >= 0 else 1 for v in self.shape]) + return _size([int(v) if v >= 0 else 1 for v in self.shape]) @abc.abstractmethod def index( @@ -752,9 +758,7 @@ def make_neg_dim(self, dim: int) -> T: dim = self.ndim + dim if dim < 0 or dim > self.ndim - 1: raise ValueError(f"dim={dim} is out of bound for ndim={self.ndim}") - self.shape = torch.Size( - [s if i != dim else -1 for i, s in enumerate(self.shape)] - ) + self.shape = _size([s if i != dim else -1 for i, s in enumerate(self.shape)]) @overload def reshape(self, shape) -> T: @@ -914,7 +918,7 @@ def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """ if shape is None: - shape = torch.Size([]) + shape = _size([]) return torch.zeros( (*shape, *self._safe_shape), dtype=self.dtype, device=self.device ) @@ -1318,7 +1322,7 @@ def shape(self): if dim < 0: dim = len(shape) + dim + 1 shape.insert(dim, len(self._specs)) - return torch.Size(shape) + return _size(shape) @shape.setter def shape(self, shape): @@ -1330,7 +1334,7 @@ def shape(self, shape): raise RuntimeError( f"The shape attribute mismatches between the input {shape} and self.shape={self.shape}." ) - shape_strip = torch.Size([s for i, s in enumerate(self.shape) if i != self.dim]) + shape_strip = _size([s for i, s in enumerate(self.shape) if i != self.dim]) for spec in self._specs: spec.shape = shape_strip @@ -1479,9 +1483,9 @@ def __init__( self.use_register = use_register space = CategoricalBox(n) if shape is None: - shape = torch.Size((space.n,)) + shape = _size((space.n,)) else: - shape = torch.Size(shape) + shape = _size(shape) if not len(shape) or shape[-1] != space.n: raise ValueError( f"The last value of the shape must match n for transform of type {self.__class__}. " @@ -1667,7 +1671,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: n = self.space.n @@ -1746,7 +1750,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.space.n, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, use_register=self.use_register, @@ -1997,9 +2001,9 @@ def __init__( ) if shape is not None and not isinstance(shape, torch.Size): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) else: - shape = torch.Size(list(shape)) + shape = _size(list(shape)) if shape is not None: shape_corr = _remove_neg_shapes(shape) else: @@ -2032,9 +2036,9 @@ def __init__( shape = low.shape else: if isinstance(shape_corr, float): - shape_corr = torch.Size([shape_corr]) + shape_corr = _size([shape_corr]) elif not isinstance(shape_corr, torch.Size): - shape_corr = torch.Size(shape_corr) + shape_corr = _size(shape_corr) shape_corr_err_msg = ( f"low and shape_corr mismatch, got {low.shape} and {shape_corr}" ) @@ -2167,7 +2171,7 @@ def unbind(self, dim: int = 0): def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) a, b = self.space if self.dtype in (torch.float, torch.double, torch.half): shape = [*shape, *self._safe_shape] @@ -2191,9 +2195,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mini = self.space.low interval = maxi - mini - r = torch.rand( - torch.Size([*shape, *self._safe_shape]), device=interval.device - ) + r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device) r = interval * r r = self.space.low + r r = r.to(self.dtype).to(self.device) @@ -2284,7 +2286,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): "Pending resolution of https://github.com/pytorch/pytorch/issues/100080." ) - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) # Expand is required as pytorch.tensor indexing return self.__class__( low=self.space.low[idx].clone().expand(indexed_shape), @@ -2365,7 +2367,7 @@ def __init__( **kwargs, ): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) _, device = _default_dtype_and_device(None, device) domain = None @@ -2424,7 +2426,7 @@ def is_in(self, val: torch.Tensor) -> bool: def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] - shape = torch.Size(shape) + shape = _size(shape) if not all( (old == 1) or (old == new) for old, new in zip(self.shape, shape[-len(self.shape) :]) @@ -2447,7 +2449,7 @@ def _unflatten(self, dim, sizes): def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) def unbind(self, dim: int = 0): @@ -2548,7 +2550,7 @@ def __init__( **kwargs, ): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) dtype, device = _default_dtype_and_device(dtype, device) if dtype == torch.bool: @@ -2596,7 +2598,7 @@ def clone(self) -> Unbounded: def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) shape = [*shape, *self.shape] if self.dtype.is_floating_point: return torch.randn(shape, device=self.device, dtype=self.dtype) @@ -2637,7 +2639,7 @@ def _unflatten(self, dim, sizes): def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) def unbind(self, dim: int = 0): @@ -2754,9 +2756,9 @@ def __init__( self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) if shape is None: - shape = torch.Size((sum(nvec),)) + shape = _size((sum(nvec),)) else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != sum(nvec): raise ValueError( f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. " @@ -2857,7 +2859,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: @@ -3133,7 +3135,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( nvec=self.nvec, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) @@ -3198,7 +3200,7 @@ def __init__( mask: torch.Tensor | None = None, ): if shape is None: - shape = torch.Size([]) + shape = _size([]) dtype, device = _default_dtype_and_device(dtype, device) space = CategoricalBox(n) super().__init__( @@ -3241,12 +3243,12 @@ def update_mask(self, mask): def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) if self.mask is None: return torch.randint( 0, self.space.n, - torch.Size([*shape, *self.shape]), + _size([*shape, *self.shape]), device=self.device, dtype=self.dtype, ) @@ -3266,7 +3268,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: if self.mask is None: return val.clamp_(min=0, max=self.space.n - 1) shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) oob = ~gathered.all(-1) @@ -3285,14 +3287,14 @@ def is_in(self, val: torch.Tensor) -> bool: return False return (0 <= val).all() and (val < self.space.n).all() shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__( n=self.space.n, shape=indexed_shape, @@ -3535,9 +3537,9 @@ def __init__( if n is None: n = shape[-1] if shape is None or not len(shape): - shape = torch.Size((n,)) + shape = _size((n,)) else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != n: raise ValueError( f"The last value of the shape must match n for spec {self.__class__}. " @@ -3636,7 +3638,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.shape[-1], - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) @@ -3697,7 +3699,7 @@ def __init__( if shape is None: shape = nvec.shape else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != nvec.shape[-1]: raise ValueError( f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. " @@ -3827,7 +3829,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: *self.shape[:-1], ) x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim) - if self.remove_singleton and self.shape == torch.Size([1]): + if self.remove_singleton and self.shape == _size([1]): x = x.squeeze(-1) return x @@ -4174,7 +4176,7 @@ def shape(self, value: torch.Size): f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and " f"Composite.shape={self.shape}." ) - self._shape = torch.Size(value) + self._shape = _size(value) def is_empty(self): """Whether the composite spec contains specs or not.""" @@ -4211,8 +4213,8 @@ def __init__( shape = batch_size if shape is None: - shape = torch.Size(()) - self._shape = torch.Size(shape) + shape = _size(()) + self._shape = _size(shape) self._specs = {} for key, value in kwargs.items(): self.set(key, value) @@ -4384,7 +4386,7 @@ def encode( if isinstance(vals, TensorDict): out = vals.empty() # create and empty tensordict similar to vals else: - out = TensorDict._new_unsafe({}, torch.Size([])) + out = TensorDict._new_unsafe({}, _size([])) for key, item in vals.items(): if item is None: raise RuntimeError( @@ -4444,7 +4446,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) _dict = {} for key, item in self.items(): if item is not None: @@ -4453,7 +4455,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase: # TensorDict requirements return TensorDict._new_unsafe( _dict, - batch_size=torch.Size([*shape, *self.shape]), + batch_size=_size([*shape, *self.shape]), device=self._device, ) @@ -4621,7 +4623,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) try: device = self.device except RuntimeError: @@ -4632,7 +4634,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: for key in self.keys(True) if isinstance(key, str) and self[key] is not None }, - torch.Size([*shape, *self._safe_shape]), + _size([*shape, *self._safe_shape]), device=device, ) @@ -5078,7 +5080,7 @@ def shape(self): if dim < 0: dim = len(shape) + dim + 1 shape.insert(dim, len(self._specs)) - return torch.Size(shape) + return _size(shape) def expand(self, *shape): if len(shape) == 1 and not isinstance(shape[0], (int,)): @@ -5279,7 +5281,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if dim is None: if len(shape) == 1 or shape.count(1) == 0: return None - new_shape = torch.Size([s for s in shape if s != 1]) + new_shape = _size([s for s in shape if s != 1]) else: if dim < 0: dim += len(shape) @@ -5287,7 +5289,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if shape[dim] != 1: return None - new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim]) + new_shape = _size([s for i, s in enumerate(shape) if i != dim]) return new_shape @@ -5303,7 +5305,7 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: new_shape = list(shape) new_shape.insert(dim, 1) - return torch.Size(new_shape) + return _size(new_shape) class _CompositeSpecItemsView: @@ -5451,7 +5453,7 @@ def _remove_neg_shapes(*shape): if isinstance(shape, np.integer): shape = (int(shape),) return _remove_neg_shapes(*shape) - return torch.Size([int(d) if d >= 0 else 1 for d in shape]) + return _size([int(d) if d >= 0 else 1 for d in shape]) ############## diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 34a1d61bfc5..efa2fcfb270 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -39,9 +39,13 @@ unravel_key, unravel_key_list, ) -from tensordict._C import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase -from tensordict.utils import expand_as_right, expand_right, NestedKey +from tensordict.utils import ( + _unravel_key_to_tuple, + expand_as_right, + expand_right, + NestedKey, +) from torch import nn, Tensor from torch.utils._pytree import tree_map diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 944e51f0b9e..71fee70d5b8 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -5,13 +5,21 @@ from __future__ import annotations import warnings +import weakref from numbers import Number from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch from torch import distributions as D, nn + +try: + from torch.compiler import assume_constant_result +except ImportError: + from torch._dynamo import assume_constant_result + from torch.distributions import constraints +from torch.distributions.transforms import _InverseTransform from torchrl.modules.distributions.truncated_normal import ( TruncatedNormal as _TruncatedNormal, @@ -20,8 +28,8 @@ from torchrl.modules.distributions.utils import ( _cast_device, FasterTransformedDistribution, - safeatanh, - safetanh, + safeatanh_noeps, + safetanh_noeps, ) from torchrl.modules.utils import mappings @@ -92,19 +100,21 @@ class SafeTanhTransform(D.TanhTransform): """TanhTransform subclass that ensured that the transformation is numerically invertible.""" def _call(self, x: torch.Tensor) -> torch.Tensor: - if x.dtype.is_floating_point: - eps = torch.finfo(x.dtype).resolution - else: - raise NotImplementedError(f"No tanh transform for {x.dtype} inputs.") - return safetanh(x, eps) + return safetanh_noeps(x) def _inverse(self, y: torch.Tensor) -> torch.Tensor: - if y.dtype.is_floating_point: - eps = torch.finfo(y.dtype).resolution - else: - raise NotImplementedError(f"No inverse tanh for {y.dtype} inputs.") - x = safeatanh(y, eps) - return x + return safeatanh_noeps(y) + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not torch.compiler.is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv class NormalParamWrapper(nn.Module): @@ -316,6 +326,33 @@ def log_prob(self, value, **kwargs): return lp +class _PatchedComposeTransform(D.ComposeTransform): + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)]) + if not torch.compiler.is_dynamo_compiling(): + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + +class _PatchedAffineTransform(D.AffineTransform): + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not torch.compiler.is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv + + class TanhNormal(FasterTransformedDistribution): """Implements a TanhNormal distribution with location scaling. @@ -344,6 +381,8 @@ class TanhNormal(FasterTransformedDistribution): as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc. tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``; + safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows. + This will currently break with :func:`torch.compile`. """ arg_constraints = { @@ -369,6 +408,7 @@ def __init__( high: Union[torch.Tensor, Number] = 1.0, event_dims: int | None = None, tanh_loc: bool = False, + safe_tanh: bool = True, **kwargs, ): if "max" in kwargs: @@ -419,13 +459,22 @@ def __init__( self.low = low self.high = high - t = SafeTanhTransform() + if safe_tanh: + if torch.compiler.is_dynamo_compiling(): + _err_compile_safetanh() + t = SafeTanhTransform() + else: + t = D.TanhTransform() # t = D.TanhTransform() - if self.non_trivial_max or self.non_trivial_min: - t = D.ComposeTransform( + if torch.compiler.is_dynamo_compiling() or ( + self.non_trivial_max or self.non_trivial_min + ): + t = _PatchedComposeTransform( [ t, - D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2), + _PatchedAffineTransform( + loc=(high + low) / 2, scale=(high - low) / 2 + ), ] ) self._t = t @@ -446,7 +495,9 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: if self.tanh_loc: loc = (loc / self.upscale).tanh() * self.upscale # loc must be rescaled if tanh_loc - if self.non_trivial_max or self.non_trivial_min: + if torch.compiler.is_dynamo_compiling() or ( + self.non_trivial_max or self.non_trivial_min + ): loc = loc + (self.high - self.low) / 2 + self.low self.loc = loc self.scale = scale @@ -466,6 +517,10 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: base = D.Normal(self.loc, self.scale) super().__init__(base, self._t) + @property + def support(self): + return D.constraints.real() + @property def root_dist(self): bd = self @@ -696,10 +751,10 @@ def __init__( loc = self.update(param) if self.non_trivial: - t = D.ComposeTransform( + t = _PatchedComposeTransform( [ t, - D.AffineTransform( + _PatchedAffineTransform( loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2 ), ] @@ -761,3 +816,16 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: uniform_sample_delta = _uniform_sample_delta + + +def _err_compile_safetanh(): + raise RuntimeError( + "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass" + "safe_tanh=False. " + "If you are using a ProbabilisticTensorDictModule, this can be done via " + "`distribution_kwargs={'safe_tanh': False}`. " + "See https://github.com/pytorch/pytorch/issues/133529 for more details." + ) + + +_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh) diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 267632c4fd9..546d93cb228 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -6,7 +6,6 @@ from typing import Union import torch -from packaging import version from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution @@ -92,72 +91,133 @@ def __init__(self, base_distribution, transforms, validate_args=None): ) -if version.parse(torch.__version__) >= version.parse("2.0.0"): - - class _SafeTanh(autograd.Function): - generate_vmap_rule = True - - @staticmethod - def forward(input, eps): - output = input.tanh() - lim = 1.0 - eps - output = output.clamp(-lim, lim) - # ctx.save_for_backward(output) - return output - - @staticmethod - def setup_context(ctx, inputs, output): - # input, eps = inputs - # ctx.mark_non_differentiable(ind, ind_inv) - # # Tensors must be saved via ctx.save_for_backward. Please do not - # # assign them directly onto the ctx object. - ctx.save_for_backward(output) - - @staticmethod - def backward(ctx, *grad): - grad = grad[0] - (output,) = ctx.saved_tensors - return (grad * (1 - output.pow(2)), None) - - class _SafeaTanh(autograd.Function): - generate_vmap_rule = True - - @staticmethod - def setup_context(ctx, inputs, output): - tanh_val, eps = inputs - # ctx.mark_non_differentiable(ind, ind_inv) - # # Tensors must be saved via ctx.save_for_backward. Please do not - # # assign them directly onto the ctx object. - ctx.save_for_backward(tanh_val) - ctx.eps = eps - - @staticmethod - def forward(tanh_val, eps): - lim = 1.0 - eps - output = tanh_val.clamp(-lim, lim) - # ctx.save_for_backward(output) - output = output.atanh() - return output - - @staticmethod - def backward(ctx, *grad): - grad = grad[0] - (tanh_val,) = ctx.saved_tensors - eps = ctx.eps - lim = 1.0 - eps - output = tanh_val.clamp(-lim, lim) - return (grad / (1 - output.pow(2)), None) - - safetanh = _SafeTanh.apply - safeatanh = _SafeaTanh.apply - -else: - - def safetanh(x, eps): # noqa: D103 +def _safetanh(x, eps): # noqa: D103 + lim = 1.0 - eps + y = x.tanh() + return y.clamp(-lim, lim) + + +def _safeatanh(y, eps): # noqa: D103 + lim = 1.0 - eps + return y.clamp(-lim, lim).atanh() + + +class _SafeTanh(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(input, eps): + output = input.tanh() lim = 1.0 - eps - y = x.tanh() - return y.clamp(-lim, lim) + output = output.clamp(-lim, lim) + # ctx.save_for_backward(output) + return output + + @staticmethod + def setup_context(ctx, inputs, output): + # input, eps = inputs + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. + ctx.save_for_backward(output) + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (output,) = ctx.saved_tensors + return (grad * (1 - output.pow(2)), None) + + +class _SafeTanhNoEps(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(input): + output = input.tanh() + eps = torch.finfo(input.dtype).resolution + lim = 1.0 - eps + output = output.clamp(-lim, lim) + return output + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(output) + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (output,) = ctx.saved_tensors + return (grad * (1 - output.pow(2)),) + + +class _SafeaTanh(autograd.Function): + generate_vmap_rule = True - def safeatanh(y, eps): # noqa: D103 + @staticmethod + def forward(tanh_val, eps): + if eps is None: + eps = torch.finfo(tanh_val.dtype).resolution lim = 1.0 - eps - return y.clamp(-lim, lim).atanh() + output = tanh_val.clamp(-lim, lim) + # ctx.save_for_backward(output) + output = output.atanh() + return output + + @staticmethod + def setup_context(ctx, inputs, output): + tanh_val, eps = inputs + + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. + ctx.save_for_backward(tanh_val) + ctx.eps = eps + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (tanh_val,) = ctx.saved_tensors + eps = ctx.eps + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + return (grad / (1 - output.pow(2)), None) + + +class _SafeaTanhNoEps(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(tanh_val): + eps = torch.finfo(tanh_val.dtype).resolution + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + # ctx.save_for_backward(output) + output = output.atanh() + return output + + @staticmethod + def setup_context(ctx, inputs, output): + tanh_val = inputs[0] + eps = torch.finfo(tanh_val.dtype).resolution + + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. + ctx.save_for_backward(tanh_val) + ctx.eps = eps + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (tanh_val,) = ctx.saved_tensors + eps = ctx.eps + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + return (grad / (1 - output.pow(2)),) + + +safetanh = _SafeTanh.apply +safeatanh = _SafeaTanh.apply + +safetanh_noeps = _SafeTanhNoEps.apply +safeatanh_noeps = _SafeaTanhNoEps.apply diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 34c62bc3260..c823788b4c2 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -61,8 +61,9 @@ class A2CLoss(LossModule): ``samples_mc_entropy`` will control how many samples will be used to compute this estimate. Defaults to ``1``. - entropy_coef (float): the weight of the entropy loss. - critic_coef (float): the weight of the critic loss. + entropy_coef (float): the weight of the entropy loss. Defaults to `0.01``. + critic_coef (float): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic + loss won't be included and the in-keys will miss the critic inputs. loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. separate_losses (bool, optional): if ``True``, shared parameters between @@ -323,7 +324,13 @@ def __init__( self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) ) - self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) + if critic_coef is not None: + self.register_buffer( + "critic_coef", torch.as_tensor(critic_coef, device=device) + ) + else: + self.critic_coef = None + if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.loss_critic_type = loss_critic_type @@ -356,7 +363,7 @@ def in_keys(self): *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], ] - if self.critic_coef: + if self.critic_coef is not None: keys.extend(self.critic_network.in_keys) return list(set(keys)) @@ -364,7 +371,7 @@ def in_keys(self): def out_keys(self): if self._out_keys is None: outs = ["loss_objective"] - if self.critic_coef: + if self.critic_coef is not None: outs.append("loss_critic") if self.entropy_bonus: outs.append("entropy") @@ -430,7 +437,12 @@ def _log_probs( log_prob = log_prob.unsqueeze(-1) return log_prob, dist - def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: + def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: + """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``. + + Returns the loss and the clip-fraction. + + """ if self.clip_value: old_state_value = tensordict.get( self.tensor_keys.value, None @@ -480,7 +492,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: loss_value, self.loss_critic_type, ) - return self.critic_coef * loss_value, clip_fraction + if self.critic_coef is not None: + return self.critic_coef * loss_value, clip_fraction + return loss_value, clip_fraction @property @_cache_values @@ -507,7 +521,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 5ceec84e36a..cd4e47ef336 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -15,6 +15,7 @@ from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams +from tensordict.utils import Buffer from torch import nn from torch.nn import Parameter from torchrl._utils import RL_WARNINGS @@ -23,9 +24,18 @@ from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase +try: + from torch.compiler import is_dynamo_compiling +except ModuleNotFoundError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _updater_check_forward_prehook(module, *args, **kwargs): - if not all(module._has_update_associated.values()) and RL_WARNINGS: + if ( + not all(module._has_update_associated.values()) + and RL_WARNINGS + and not is_dynamo_compiling() + ): warnings.warn( module.TARGET_NET_WARNING, category=UserWarning, @@ -217,8 +227,10 @@ def set_keys(self, **kwargs) -> None: >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value") """ for key, value in kwargs.items(): - if key not in self._AcceptedKeys.__dict__: - raise ValueError(f"{key} is not an accepted tensordict key") + if key not in self._AcceptedKeys.__dataclass_fields__: + raise ValueError( + f"{key} is not an accepted tensordict key. Accepted keys are: {self._AcceptedKeys.__dataclass_fields__}." + ) if value is not None: setattr(self.tensor_keys, key, value) else: @@ -415,7 +427,11 @@ def __getattr__(self, item): # no target param, take detached data params = getattr(self, item[7:]) params = params.data - elif not self._has_update_associated[item[7:-7]] and RL_WARNINGS: + elif ( + not self._has_update_associated[item[7:-7]] + and RL_WARNINGS + and not is_dynamo_compiling() + ): # no updater associated warnings.warn( self.TARGET_NET_WARNING, @@ -433,7 +449,7 @@ def _apply(self, fn): def _erase_cache(self): for key in list(self.__dict__): if key.startswith("_cache"): - del self.__dict__[key] + delattr(self, key) def _networks(self) -> Iterator[nn.Module]: for item in self.__dir__(): @@ -603,11 +619,10 @@ def __init__(self, clone): self.clone = clone def __call__(self, x): + x = x.data.clone() if self.clone else x.data if isinstance(x, nn.Parameter): - return nn.Parameter( - x.data.clone() if self.clone else x.data, requires_grad=False - ) - return x.data.clone() if self.clone else x.data + return Buffer(x) + return x def add_ramdom_module(module): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f7582fb5892..fb8fbff2ccf 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -521,16 +521,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - td_device = tensordict_reshape.to(tensordict.device) - - q_loss, metadata = self.q_loss(td_device) - cql_loss, cql_metadata = self.cql_loss(td_device) + q_loss, metadata = self.q_loss(tensordict_reshape) + cql_loss, cql_metadata = self.cql_loss(tensordict_reshape) if self.with_lagrange: - alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(td_device) + alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss( + tensordict_reshape + ) metadata.update(alpha_prime_metadata) - loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device) - loss_actor, actor_metadata = self.actor_loss(td_device) - loss_alpha, alpha_metadata = self.alpha_loss(td_device) + loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape) + loss_actor, actor_metadata = self.actor_loss(tensordict_reshape) + loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata) metadata.update(bc_metadata) metadata.update(cql_metadata) metadata.update(actor_metadata) @@ -547,7 +547,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_cql": cql_loss, "loss_alpha": loss_alpha, "alpha": self._alpha, - "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), + "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(), } if self.with_lagrange: out["loss_alpha_prime"] = alpha_prime_loss.mean() @@ -574,7 +574,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor: metadata = {"bc_log_prob": bc_log_prob.mean().detach()} return bc_actor_loss, metadata - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -585,6 +585,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: log_prob = dist.log_prob(a_reparm) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + if td_q is tensordict: + raise RuntimeError td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qvalue_networkN0( td_q, @@ -599,12 +601,12 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" ) - # write log_prob in tensordict for alpha loss - tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) + metadata = {} + metadata[self.tensor_keys.log_prob] = log_prob.detach() actor_loss = self._alpha * log_prob - min_q_logprob actor_loss = _reduce(actor_loss, reduction=self.reduction) - return actor_loss, {} + return actor_loss, metadata def _get_policy_actions(self, data, actor_params, num_actions=10): batch_size = data.batch_size @@ -667,7 +669,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): if self.max_q_backup: next_tensordict, _ = self._get_policy_actions( - tensordict.get("next"), + tensordict.get("next").copy(), actor_params, num_actions=self.num_random, ) @@ -691,10 +693,10 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def q_loss(self, tensordict: TensorDictBase) -> Tensor: + def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._get_value_v( - tensordict, + tensordict.copy(), self._alpha, self.actor_network_params, self.target_qvalue_network_params, @@ -722,7 +724,7 @@ def q_loss(self, tensordict: TensorDictBase) -> Tensor: metadata = {"td_error": td_error.detach()} return loss_qval, metadata - def cql_loss(self, tensordict: TensorDictBase) -> Tensor: + def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: pred_q1 = tensordict.get(self.tensor_keys.pred_q1) pred_q2 = tensordict.get(self.tensor_keys.pred_q2) @@ -746,12 +748,12 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor: .to(tensordict.device) ) curr_actions_td, curr_log_pis = self._get_policy_actions( - tensordict, + tensordict.copy(), self.actor_network_params, num_actions=self.num_random, ) new_curr_actions_td, new_log_pis = self._get_policy_actions( - tensordict.get("next"), + tensordict.get("next").copy(), self.actor_network_params, num_actions=self.num_random, ) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index a4e241347e2..c4639b70bdd 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -383,7 +383,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, metadata = self.actor_loss(tensordict_reshape) loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape) loss_value, metadata_value = self.value_loss(tensordict_reshape) - metadata.update(**metadata_qvalue, **metadata_value) + metadata.update(metadata_qvalue) + metadata.update(metadata_value) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape @@ -410,7 +411,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: [], ) - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -446,7 +447,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) @@ -460,7 +461,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: value_loss = _reduce(value_loss, reduction=self.reduction) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys tensordict = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False @@ -781,7 +782,7 @@ def __init__( self.action_space = _find_action_space(action_space) self.reduction = reduction - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -828,7 +829,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value with torch.no_grad(): # Min Q value @@ -856,7 +857,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: value_loss = _reduce(value_loss, reduction=self.reduction) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys next_td = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 9d9790ab294..efc951b3999 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -6,7 +6,6 @@ import contextlib -import math from copy import deepcopy from dataclasses import dataclass from typing import Tuple @@ -80,7 +79,8 @@ class PPOLoss(LossModule): entropy_coef (scalar, optional): entropy multiplier when computing the total loss. Defaults to ``0.01``. critic_coef (scalar, optional): critic loss multiplier when computing the total - loss. Defaults to ``1.0``. + loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value + loss from the forward outputs. loss_critic_type (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized @@ -371,7 +371,12 @@ def __init__( device = torch.device("cpu") self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) - self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) + if critic_coef is not None: + self.register_buffer( + "critic_coef", torch.tensor(critic_coef, device=device) + ) + else: + self.critic_coef = None self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage if gamma is not None: @@ -504,6 +509,7 @@ def _log_weight( return log_weight, dist, kl_approx def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: + """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``.""" # TODO: if the advantage is gathered by forward, this introduces an # overhead that we could easily reduce. if self.separate_losses: @@ -562,7 +568,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: self.loss_critic_type, ) - return self.critic_coef * loss_value, clip_fraction + if self.critic_coef is not None: + return self.critic_coef * loss_value, clip_fraction + return loss_value, clip_fraction @property @_cache_values @@ -595,7 +603,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: @@ -679,7 +687,8 @@ class ClipPPOLoss(PPOLoss): entropy_coef (scalar, optional): entropy multiplier when computing the total loss. Defaults to ``0.01``. critic_coef (scalar, optional): critic loss multiplier when computing the total - loss. Defaults to ``1.0``. + loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value + loss from the forward outputs. loss_critic_type (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized @@ -800,13 +809,18 @@ def __init__( clip_value=clip_value, **kwargs, ) - self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon)) + for p in self.parameters(): + device = p.device + break + else: + device = None + self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device)) @property def _clip_bounds(self): return ( - math.log1p(-self.clip_epsilon), - math.log1p(self.clip_epsilon), + (-self.clip_epsilon).log1p(), + self.clip_epsilon.log1p(), ) @property @@ -869,7 +883,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: @@ -1163,7 +1177,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index cda2c62894e..271f233bae8 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor @@ -326,7 +326,11 @@ def __init__( else: self.register_parameter( "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + torch.nn.Parameter( + torch.tensor( + math.log(alpha_init), device=device, requires_grad=True + ) + ), ) self._target_entropy = target_entropy @@ -401,10 +405,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: @property def alpha(self): - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) with torch.no_grad(): - alpha = self.log_alpha.exp() - return alpha + return self.log_alpha.clamp(self.min_log_alpha, self.max_log_alpha).exp() def _set_in_keys(self): keys = [ @@ -448,9 +450,12 @@ def _qvalue_params_cat(self, selected_q_params): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys - tensordict_select = tensordict.clone(False).select( + tensordict_select = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False ) + # We need to copy bc select does not copy sub-tds + tensordict_select = tensordict_select.copy() + selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len ].sort()[0] @@ -467,7 +472,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: *self.actor_network.in_keys, strict=False ) # next_observation -> tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - # tensordict_actor = tensordict_actor.contiguous() with set_exploration_type(ExplorationType.RANDOM): if self.gSDE: @@ -480,19 +484,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_actor, actor_params, ) - if isinstance(self.actor_network, TensorDictSequential): - sample_key = self.tensor_keys.action - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - else: - sample_key = self.tensor_keys.action - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) + sample_key = self.tensor_keys.action + sample_key_lp = self.tensor_keys.sample_log_prob + tensordict_actor_dist = self.actor_network.build_dist_from_params(td_params) tensordict_actor.set(sample_key, tensordict_actor_dist.rsample()) tensordict_actor.set( - self.tensor_keys.sample_log_prob, + sample_key_lp, tensordict_actor_dist.log_prob(tensordict_actor.get(sample_key)), ) @@ -603,12 +600,22 @@ def _loss_alpha(self, log_pi: Tensor) -> Tensor: ) if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + alpha_loss = -self._safe_log_alpha.exp() * ( + log_pi.detach() + self.target_entropy + ) else: # placeholder alpha_loss = torch.zeros_like(log_pi) return alpha_loss + @property + def _safe_log_alpha(self): + log_alpha = self.log_alpha + with torch.no_grad(): + log_alpha_clamp = log_alpha.clamp(self.min_log_alpha, self.max_log_alpha) + log_alpha_det = log_alpha.detach() + return log_alpha - log_alpha_det + log_alpha_clamp + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 922d6df7a74..89ff581991f 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -372,7 +372,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict): + def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -398,7 +398,7 @@ def actor_loss(self, tensordict): loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, metadata - def value_loss(self, tensordict): + def value_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index cd40ac1e029..8b394137480 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -386,7 +386,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict): + def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. @@ -433,7 +433,7 @@ def actor_loss(self, tensordict): loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, metadata - def qvalue_loss(self, tensordict): + def qvalue_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 3031763c50f..66eae215e54 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -454,11 +454,17 @@ def next_state_value( return target_value -def _cache_values(fun): +def _cache_values(func): """Caches the tensordict returned by a property.""" - name = fun.__name__ + name = func.__name__ - def new_fun(self, netname=None): + @functools.wraps(func) + def new_func(self, netname=None): + if torch.compiler.is_dynamo_compiling(): + if netname is not None: + return func(self, netname) + else: + return func(self) __dict__ = self.__dict__ _cache = __dict__.setdefault("_cache", {}) attr_name = name @@ -468,16 +474,16 @@ def new_fun(self, netname=None): out = _cache[attr_name] return out if netname is not None: - out = fun(self, netname) + out = func(self, netname) else: - out = fun(self) + out = func(self) # TODO: decide what to do with locked tds in functional calls # if is_tensor_collection(out): # out.lock_() _cache[attr_name] = out return out - return new_fun + return new_func def _vmap_func(module, *args, func=None, **kwargs):