[BUG] TensorDictPrimer overwrites nested specs in the environment #2331
Closed
Description
part of #2327
The primer overwrites any nested spec.
Consider an env with nested specs:
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(
    scenario="balance",
    num_envs=5,
)
Add to it a primer for a nested hidden state:
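from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import TensorDictPrimer, TransformedEnv

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            "agents": CompositeSpec(
                {
                    "h": UnboundedContinuousTensorSpec(
                        shape=(*env.shape, env.n_agents, 2, 128)
                    )
                },
                shape=(*env.shape, env.n_agents),
            )
        }
    ),
)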
The primer code in
rl/torchrl/envs/transforms/transforms.py
(line 4649 in 0063741) overwrites the nested spec already present under "agents" instead of updating it.
The same result is obtained with a tuple key:
env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            ("agents", "h"): UnboundedContinuousTensorSpec(
                shape=(*env.shape, env.n_agents, 2, 128)
            )
        }
    ),
)
Here, updating the spec instead of overwriting it should do the job.
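To make the difference concrete, here is a minimal sketch that uses plain nested dicts as stand-ins for the specs. It is not the TorchRL code: `overwrite`, `nested_update`, and the string placeholders are hypothetical names introduced only for illustration, and the contents of the "agents" group are assumed.

```python
from collections.abc import Mapping


def overwrite(specs: dict, primers: Mapping) -> dict:
    """Assign each top-level primer entry, replacing whatever was there."""
    out = dict(specs)
    for key, value in primers.items():
        out[key] = value  # an existing nested group under `key` is lost
    return out


def nested_update(specs: dict, primers: Mapping) -> dict:
    """Merge primer entries into existing nested groups, keeping siblings."""
    out = dict(specs)
    for key, value in primers.items():
        if isinstance(value, Mapping) and isinstance(out.get(key), Mapping):
            # Both sides are nested groups: recurse instead of replacing.
            out[key] = nested_update(dict(out[key]), value)
        else:
            out[key] = value
    return out


# Hypothetical stand-ins: strings take the place of real tensor specs.
env_specs = {"agents": {"observation": "obs_spec", "info": "info_spec"}}
primers = {"agents": {"h": "hidden_state_spec"}}

overwritten = overwrite(env_specs, primers)
merged = nested_update(env_specs, primers)

print(sorted(overwritten["agents"]))  # ['h']
print(sorted(merged["agents"]))  # ['h', 'info', 'observation']
```

With plain assignment the "agents" group ends up holding only "h", which mirrors the behaviour reported here. The recursive update keeps the entries that were already in the group and adds "h" next to them.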