
[BUG] TensorDictPrimer overwrites nested specs in the environment #2331

Closed
@matteobettini

Description

Part of #2327.

The primer overwrites any existing nested spec that shares a key with one of its primers, instead of updating it.

Consider an env with nested specs:

    from torchrl.envs.libs.vmas import VmasEnv

    env = VmasEnv(
        scenario="balance",
        num_envs=5,
    )

and add a primer for a nested hidden state to it:

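    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
    from torchrl.envs import TensorDictPrimer, TransformedEnv

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    )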

The following line in the primer code

    observation_spec[key] = self.primers[key] = spec.to(device)

overwrites the observation spec entry instead of updating it, so all the spec keys that were previously in the "agents" spec are lost.
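To make the failure mode concrete, here is a minimal, library-free sketch that uses plain nested dicts as stand-ins for `CompositeSpec` (an assumption for illustration only; the real specs are torchrl objects, and the entry names `"observation"` and `"info"` are hypothetical). Assigning the primer's sub-spec at the top-level key replaces the whole `"agents"` entry:

```python
# Plain dicts standing in for nested CompositeSpecs (illustration only).
observation_spec = {
    "agents": {
        "observation": "obs-spec",  # hypothetical pre-existing entries
        "info": "info-spec",
    }
}

# A primer for a nested hidden state, keyed by the top-level "agents" key.
primer = {"agents": {"h": "hidden-state-spec"}}

# Mirrors `observation_spec[key] = spec` for each top-level primer key:
# the existing "agents" sub-spec is replaced wholesale, not merged.
for key, spec in primer.items():
    observation_spec[key] = spec

print(sorted(observation_spec["agents"]))  # ['h']
```

Only `"h"` survives under `"agents"`; the previous entries are gone, which matches the behavior reported above.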

The same result is obtained when the nested key is passed to the primer directly:

    env = TransformedEnv(
        env,
        TensorDictPrimer(
            {
                ("agents", "h"): UnboundedContinuousTensorSpec(
                    shape=(*env.shape, env.n_agents, 2, 128)
                )
            }
        ),
    )

Here, updating the spec instead of overwriting it (i.e. merging the primer's entries into the existing "agents" spec) should fix the issue.
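The "update instead of overwrite" idea can be sketched as a recursive merge. This is a hypothetical, library-free sketch, not torchrl's implementation: plain dicts again stand in for `CompositeSpec`, the helpers `set_nested` and `apply_primers` are made-up names, and spec shapes and devices (which a real fix would have to respect) are ignored. It accepts both a top-level key holding a sub-spec and a tuple key such as `("agents", "h")`, and in both cases keeps the sibling entries:

```python
from typing import Any, Dict, Tuple, Union

# A key is either a plain string or a tuple path into the nested structure.
Key = Union[str, Tuple[str, ...]]


def set_nested(spec: Dict[str, Any], key: Key, value: Any) -> None:
    """Merge `value` into `spec` at `key`, keeping existing sibling entries.

    Assumes every intermediate level along the path is a dict.
    """
    path = (key,) if isinstance(key, str) else tuple(key)
    node = spec
    # Walk the intermediate levels of the path, creating them if missing.
    for part in path[:-1]:
        node = node.setdefault(part, {})
    leaf = path[-1]
    if isinstance(value, dict) and isinstance(node.get(leaf), dict):
        # Both sides are nested: recurse into them instead of overwriting.
        for sub_key, sub_value in value.items():
            set_nested(node[leaf], sub_key, sub_value)
    else:
        node[leaf] = value


def apply_primers(
    observation_spec: Dict[str, Any], primers: Dict[Key, Any]
) -> Dict[str, Any]:
    """Hypothetical primer step that updates the spec in place."""
    for key, spec in primers.items():
        set_nested(observation_spec, key, spec)
    return observation_spec


# Both primer forms from the issue now keep the pre-existing entry.
spec_a = {"agents": {"observation": "obs-spec"}}
apply_primers(spec_a, {"agents": {"h": "hidden-state-spec"}})

spec_b = {"agents": {"observation": "obs-spec"}}
apply_primers(spec_b, {("agents", "h"): "hidden-state-spec"})

print(spec_a == spec_b)  # True
```

Recursing only when both the existing entry and the primer entry are nested means leaf specs are still replaced, while composite entries are merged key by key.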

Labels: bug (Something isn't working)