
[BugFix] Fix LSTM use with padded/masked segments #1399

Merged: 3 commits into pytorch:main on Jul 26, 2023

Conversation

@smorad (Contributor) commented on Jul 19, 2023

Description

LSTMModule does not work when using fixed-size padded/masked segments. This is likely because view, unlike reshape, cannot handle non-contiguous tensors: view raises an error when the strides are incompatible, whereas reshape falls back to copying the data.
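
A minimal illustration of the underlying PyTorch behavior with plain tensors (not torchrl code): view fails on non-contiguous data, while reshape only copies when it has to.

import torch

x = torch.arange(12).reshape(3, 4)   # contiguous
y = x.t()                            # transposed view: same storage, non-contiguous

try:
    y.view(-1)                       # fails: strides are incompatible with a flat view
except RuntimeError as err:
    print(err)                       # "view size is not compatible with input tensor's size and stride ..."

flat = y.reshape(-1)                 # succeeds: reshape copies the data when a view is impossible
print(flat.is_contiguous())          # True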

Motivation and Context

Example script that reproduces the error:

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
import tensordict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate


segment_length = 200
collection_length = segment_length * 2
utd = 16


device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu")
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=False, device=device),
    Compose(
        StepCounter(),
        InitTracker(),
    ),
)
td = env.reset()
feature = Mod(nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 64)), in_keys=["observation"], out_keys=["embed"])
n_cells = feature(env.reset())["embed"].shape[-1]

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="markov_state",
)
print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

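# the primer adds zero-initialized recurrent-state entries to the env output
# so the collector can carry the LSTM hidden state from one step to the next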
env.append_transform(lstm.make_tensordict_primer())
mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)
mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["markov_state"], out_keys=["action_value"])

qval = QValueModule(action_space=env.action_spec)
stoch_policy = Seq(feature, lstm, mlp, qval)
stoch_policy = EGreedyWrapper(
    stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)

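# training copy of the policy: recurrent mode unrolls the LSTM over the full time dimension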
policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
policy(env.reset())

loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)
updater = SoftUpdate(loss_fn, eps=0.95)
optim = torch.optim.Adam(policy.parameters(), lr=3e-4)
collector = SyncDataCollector(env, stoch_policy, frames_per_batch=collection_length, total_frames=1_000, split_trajs=True)
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100), batch_size=4, prefetch=10
)
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
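    # pad every trajectory up to segment_length so the stored segments have a fixed size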
    padded = tensordict.pad(data, [0, 0, 0, segment_length-data.shape[-1]])
    rb.extend(padded)
    for _ in range(utd):
        batch = rb.sample().to(device)
        feature(batch)
        lstm(batch)
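        # drop the padded steps before computing the loss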
        masked = batch.masked_select(batch[('collector', 'mask')])
        loss_vals = loss_fn(masked)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    stoch_policy.step(data.numel())
    updater.step()

    if i % 50 == 0:
        with set_exploration_type(ExplorationType.MODE), torch.no_grad():
            rollout = env.rollout(1000, stoch_policy)
            traj_lens.append(rollout.get(("next", "step_count")).max().item())

Running this script produces the following error:

Traceback (most recent call last):
  File "/Users/smorad/code/offline_rdqn/dqn.py", line 113, in <module>
    lstm(batch)
  File "/Users/smorad/miniforge3/envs/explore/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/nn/functional_modules.py", line 567, in new_fun
    return getattr(type(self), fun_name)(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/torchrl_src/torchrl/modules/tensordict_module/rnn.py", line 348, in forward
    tensordict.update(tensordict_shaped.reshape(shape))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 2724, in reshape
    return self.reshape(*shape[0])
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 2729, in reshape
    for key, item in self.items():
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 1790, in items
    yield k, self._get_str(k, NO_DEFAULT)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 7470, in _get_str
    tensor = self._source._get_str(key, default)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 7473, in _get_str
    return self._transform_value(tensor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/smorad/code/explore/tensordict_src/tensordict/tensordict.py", line 7482, in _transform_value
    return getattr(item, self.custom_op)(**self._update_custom_op_kwargs(item))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

With this commit, the example script no longer fails.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Jul 19, 2023 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@vmoens changed the title from "Fix LSTM use with padded/masked segments" to "[BugFix] Fix LSTM use with padded/masked segments" on Jul 20, 2023
@vmoens (Contributor) left a comment

LGTM thanks for taking care of this!

We'll take a small performance hit, and I wonder if someday we won't inadvertently change it back to view to improve the runtime. A good way to ensure that this does not happen is to write a test that would fail otherwise. Is that something you can do? Basically, I think we just need to provide a tensordict that is already a view to the LSTM module and check that an error is raised if we use view instead of reshape.
A comment in the code would also help!
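
Concretely, something along these lines could work, reusing the objects from the reproduction script above (a rough sketch of the idea only, not the test that was eventually added):

def test_lstm_on_padded_masked_segments():
    data = next(iter(collector))          # one batch, split into [trajectory, time]
    padded = tensordict.pad(data, [0, 0, 0, segment_length - data.shape[-1]])
    rb.extend(padded)
    batch = rb.sample().to(device)
    feature(batch)
    # before this fix, the call below raised
    # "view size is not compatible with input tensor's size and stride"
    lstm(batch)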

@smorad (Contributor, Author) commented on Jul 20, 2023

It was my understanding that reshape will only copy when necessary. So there should not be any performance decrease, right? Unless you are in my situation and it would've crashed otherwise.

@vmoens (Contributor) commented on Jul 20, 2023

> It was my understanding that reshape will only copy when necessary. So there should not be any performance decrease, right? Unless you are in my situation and it would've crashed otherwise.

Oh yes, you're right!
reshape is fine then!
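
For reference, a quick sanity check with plain tensors (a minimal sketch): reshape shares the original storage when the input is contiguous and only copies when it must.

import torch

a = torch.arange(12).reshape(3, 4)

b = a.reshape(12)                     # contiguous input: no copy, same storage
print(b.data_ptr() == a.data_ptr())   # True

c = a.t().reshape(12)                 # non-contiguous input: reshape has to copy
print(c.data_ptr() == a.data_ptr())   # False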

@vmoens merged commit dc8b7b5 into pytorch:main on Jul 26, 2023
vmoens pushed a commit to hyerra/rl that referenced this pull request on Oct 10, 2023