[BUG] Calculation of GAE fails with recurrent critic #2372

Closed · thomasbbrunner (Contributor) opened this issue Aug 6, 2024 · 1 comment · Fixed by #2376
Labels: bug (Something isn't working)

Describe the bug

Computing the GAE with a recurrent critic fails with the following error:

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

Setting the shifted flag to True seems to prevent this error.

Is this behavior expected? If so, should we document that the shifted flag is required for recurrent critics?

To Reproduce

Minimal snippet to reproduce the issue:

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import GymEnv, TransformedEnv, transforms
from torchrl.envs.utils import check_env_specs
from torchrl.modules import LSTMModule
from torchrl.objectives.value import GAE

env = GymEnv(env_name="HalfCheetah-v4", device="cpu")
env = TransformedEnv(env)
env.append_transform(transforms.DoubleToFloat(in_keys=["observation"]))
env.append_transform(transforms.InitTracker())
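# Primer adds the LSTM's recurrent-state entries (1 layer, hidden size 128) to the
# env specs so they are carried through resets and data collection.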
env.append_transform(
    transforms.TensorDictPrimer(
        {
            "recurrent_state_h": UnboundedContinuousTensorSpec(shape=(1, 128)),
            "recurrent_state_c": UnboundedContinuousTensorSpec(shape=(1, 128)),
        }
    )
)
check_env_specs(env)

observation_size = env.observation_spec["observation"].shape[-1]
action_size = env.action_spec.shape[-1]

rnn = LSTMModule(
    input_size=observation_size,
    hidden_size=128,
    num_layers=1,
    device="cpu",
    in_key="observation",
    out_key="features",
)

value_net = TensorDictModule(
    module=nn.Sequential(
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
    ),
    in_keys=["features"],
    out_keys=["state_value"],
)
critic_module = TensorDictSequential(rnn, value_net)

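# policy=None: the collector falls back to a random policy sampled from the env's action spec.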
collector = SyncDataCollector(
    env,
    None,
    frames_per_batch=512,
    device="cpu",
)

batch = collector.next()

# With shifted=True calculation of advantages works
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=True)
with torch.no_grad():
    advantage_module(batch)

# With shifted=False calculation of advantage fails!
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic_module, shifted=False)
with torch.no_grad():
    # NOTE: Should raise
    # RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow (...)
    advantage_module(batch)

System info

> pip list | grep torch
torch                          2.4.0
torchrl                        0.5.0

Checklist
  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@thomasbbrunner added the bug label on Aug 6, 2024
@vmoens linked a pull request (#2376) on Aug 7, 2024 that will close this issue

@vmoens (Contributor) commented on Aug 7, 2024

#2376 should fix it.
You'll still need to add python_based=True to your LSTMModule.
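
For reference, a minimal sketch of that change, reusing the constructor arguments from the reproduction snippet above (python_based is the only addition; the exact behavior of the fix in #2376 is not verified here):

# Same LSTMModule as in the reproduction snippet, with python_based=True so the
# LSTM uses a full Python implementation of the cell instead of the cuDNN-backed nn.LSTM.
rnn = LSTMModule(
    input_size=observation_size,
    hidden_size=128,
    num_layers=1,
    device="cpu",
    in_key="observation",
    out_key="features",
    python_based=True,
)
critic_module = TensorDictSequential(rnn, value_net)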
