[BUG] NonTensorSpec calls NonTensorData with unexpected shape argument #2168
Comments
Oh yeah, that should be batch_size!
This is already fixed on main: rl/torchrl/data/tensor_specs.py, lines 1952 to 1959 in eaa3dd8.
I can add the test, though.
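For context, here is a minimal, self-contained sketch of the intended fix, using hypothetical stand-ins (`FakeNonTensorData` and `FakeNonTensorSpec` are illustrative names, not the real torchrl/tensordict classes): the spec's `one()` must forward the requested shape as `batch_size`, since `NonTensorData` accepts no `shape` keyword.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeNonTensorData:
    """Stand-in for tensordict's NonTensorData: accepts batch_size, not shape."""
    data: Any = None
    batch_size: Tuple[int, ...] = ()


class FakeNonTensorSpec:
    """Stand-in spec whose one() forwards the shape as batch_size."""

    def __init__(self, shape=()):
        self.shape = tuple(shape)

    def one(self, shape=None):
        # The bug was constructing NonTensorData with a `shape=` keyword;
        # the fix forwards the requested shape as `batch_size` instead.
        batch = self.shape if shape is None else tuple(shape)
        return FakeNonTensorData(data=None, batch_size=batch)


print(FakeNonTensorSpec().one(shape=[1]).batch_size)  # prints (1,)
```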
Hi @vmoens, I am afraid the fix on main does not work (unless I am calling things wrong in the first place):

```python
torchrl.data.NonTensorSpec().one(shape=[1])
```

```
Traceback (most recent call last):
  File "/work/rleap1/michael.aichmueller/miniconda/envs/rgnn/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-37af7c334d8b>", line 3, in <module>
    torchrl.data.NonTensorSpec().one(shape=[1])
  File "/work/rleap1/michael.aichmueller/miniconda/envs/rgnn/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 1969, in one
    return NonTensorData(
           ^^^^^^^^^^^^^^
  File "/work/rleap1/michael.aichmueller/miniconda/envs/rgnn/lib/python3.11/site-packages/tensordict/tensorclass.py", line 382, in wrapper
    batch_size=torch.Size(batch_size),
               ^^^^^^^^^^^^^^^^^^^^^^
TypeError: torch.Size() takes an iterable of 'int' (item 1 is 'torch.Size')
```
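The TypeError above comes from `torch.Size` validating that every item in the iterable is an int; the failing call ended up with a `torch.Size` nested inside the iterable passed to `torch.Size`. A small stand-in class (not torch's actual implementation) reproduces this failure mode without importing torch:

```python
class Size(tuple):
    """Stand-in for torch.Size: the constructor accepts only an iterable of ints."""

    def __new__(cls, iterable):
        items = list(iterable)
        for i, item in enumerate(items):
            if not isinstance(item, int):
                raise TypeError(
                    f"Size() takes an iterable of 'int' (item {i} is '{type(item).__name__}')"
                )
        return super().__new__(cls, items)


# A flat shape of ints is accepted:
print(Size([1]))  # prints (1,)

# Nesting a Size inside the iterable fails, mirroring the traceback above:
try:
    Size([1, Size([1])])
except TypeError as exc:
    print(exc)
```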
@vmoens, thanks for your help on this^. I ran into the same issue, am on the main branch now, and am still hitting an error. For a bit of context: in my custom env, my `._step()` generates a td with a field "x". "x" is itself a tensordict with a nested key structure, where the keys are not fixed (they vary depending on the env). I was trying a super simple observation spec with a blank NonTensorSpec:
And when instantiating the SyncDataCollector I eventually run into: Just want to confirm this is a different issue.
Interesting, that seems to be an issue with NonTensorData, which receives something unexpected. I ran this code and it works fine on my end (on the main branch); can you help me figure out how your example differs?

```python
from torchrl.envs import GymEnv, Transform, check_env_specs
from torchrl.data import NonTensorSpec, TensorSpec
from torchrl.collectors import SyncDataCollector
from tensordict import TensorDictBase, TensorDict, NonTensorData


class AddNontTensorData(Transform):
    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        tensordict["nt"] = "a string!"
        return tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return tensordict_reset.set("nt", NonTensorData("reset!"))

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        observation_spec["nt"] = NonTensorSpec(shape=())
        return observation_spec


env = GymEnv("CartPole-v1").append_transform(AddNontTensorData())
print(env.rollout(3))
check_env_specs(env)

collector = SyncDataCollector(env, frames_per_batch=200)
for data in collector:
    print(data)
    break
```

EDIT: the data from the collector is garbage, but non-tensor data is not yet supported by collectors "officially".
Describe the bug
A NonTensorSpec apparently instantiates a NonTensorData object incorrectly on calls such as one, zero, and rand.
To Reproduce
shows
This currently comes up when I try to build an env with a NonTensorSpec and run check_env_specs(env) on it.
Expected behavior
Not throwing an error.
System info
Reason and Possible fixes
These lines seem to be at fault:
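Since the suspected lines themselves are not quoted here, the following is only an illustrative stand-in for the failure mode (`make_non_tensor_data`, `one_buggy`, and `one_fixed` are hypothetical names): passing the shape under an unexpected keyword raises a TypeError, while forwarding it as `batch_size` works.

```python
def make_non_tensor_data(data=None, batch_size=()):
    """Hypothetical stand-in for NonTensorData: batch_size is its only shape-like kwarg."""
    return {"data": data, "batch_size": tuple(batch_size)}


def one_buggy(shape):
    # Passing `shape=` mirrors the reported failure: an unexpected keyword argument.
    return make_non_tensor_data(data=None, shape=shape)


def one_fixed(shape):
    # The fix forwards the shape as batch_size instead.
    return make_non_tensor_data(data=None, batch_size=shape)


try:
    one_buggy([1])
except TypeError as exc:
    print(type(exc).__name__)  # prints TypeError, as in the bug report

print(one_fixed([1]))  # prints {'data': None, 'batch_size': (1,)}
```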
Checklist