
Bug in recorder rollout when using lstm #170

Closed
@idanshen

Description

Hi,

First of all, thank you for opening this library to the public; personally, I find it a very convenient framework.
When I run PPO with the lstm option enabled, I get the following error, regardless of which env I'm running:
  File "/mnt/data/usr/code/TSRL/ppo.py", line 174, in <module>
    main(args)
  File "/mnt/data/usr/code/TSRL/ppo.py", line 168, in main
    trainer.train()
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/trainers/trainers.py", line 363, in train
    self._post_steps_log_hook(batch)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/trainers/trainers.py", line 340, in _post_steps_log_hook
    result = op(batch, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/trainers/trainers.py", line 849, in __call__
    td_record = self.recorder.rollout(
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/envs/common.py", line 451, in rollout
    tensordict = policy(tensordict)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/td_module/probabilistic.py", line 253, in forward
    dist, tensordict_out = self.get_dist(
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/td_module/probabilistic.py", line 221, in get_dist
    tensordict_out = self._call_module(
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/td_module/probabilistic.py", line 193, in _call_module
    return self.module(tensordict, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/td_module/common.py", line 346, in forward
    tensors = self._call_module(tensors, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/td_module/common.py", line 336, in _call_module
    out = self.module(*tensors, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/distributions/continuous.py", line 143, in forward
    net_output = self.operator(*tensors)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/models/models.py", line 967, in forward
    return self._lstm(input, hidden0_in, hidden1_in)
  File "/home/usr/anaconda3/envs/torch_rl/lib/python3.9/site-packages/torchrl-0.1-py3.9-linux-x86_64.egg/torchrl/modules/models/models.py", line 910, in _lstm
    batch, steps = input.shape[:2]
ValueError: not enough values to unpack (expected 2, got 1)
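The last frame shows that the LSTM model unpacks `input.shape[:2]` into `(batch, steps)`. The error message means the tensor it received during the recorder rollout had only one dimension, i.e. a single unbatched observation with no batch or time axis. Below is a minimal sketch that reproduces the unpacking failure using plain shape tuples instead of tensors, so it runs without torch. The helper `add_batch_and_time_dims` is hypothetical, not part of torchrl. It only illustrates one possible workaround: prepending singleton batch and time dimensions. With a real tensor, this would correspond to something like `input.unsqueeze(0).unsqueeze(0)`. The feature size of 17 is illustrative.

```python
def split_batch_steps(shape):
    """Mirror of the failing line in the traceback: batch, steps = input.shape[:2]."""
    batch, steps = shape[:2]
    return batch, steps


def add_batch_and_time_dims(shape):
    """Hypothetical workaround: turn a single 1-D observation shape (features,)
    into (1, 1, features). Other shapes are returned unchanged."""
    shape = tuple(shape)
    if len(shape) == 1:
        return (1, 1) + shape
    return shape


obs_shape = (17,)  # one unbatched observation, as the error implies

try:
    split_batch_steps(obs_shape)
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 1)

print(split_batch_steps(add_batch_and_time_dims(obs_shape)))  # (1, 1)
```

The helper deliberately handles only the 1-D case seen here. A 2-D `(batch, features)` input is ambiguous about which axis is time, so padding it blindly could silently treat the batch axis as time steps.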
