Skip to content

[BUG] NonTensorData disappears from environment after step_mdp()Β #2171

@Michael-C-Strobel

Description

Describe the bug

When calling step_mdp(), NonTensorData is removed from the next dictionary.

To Reproduce

from typing import Optional

import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec, NonTensorSpec
from torchrl.envs import (
    EnvBase,
)
from torchrl.envs.utils import check_env_specs, step_mdp

def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
  """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
  if batch_size is None:
      batch_size = []
  td = TensorDict(
      {
          "params": TensorDict(
              {
                  "max_speed": 8,
                  "max_torque": 2.0,
                  "dt": 0.05,
                  "g": g,
                  "m": 1.0,
                  "l": 1.0,
                  "my_non_tensor_spec": None
              },
              [],
          )
      },
      [],
  )
  if batch_size:
      td = td.expand(batch_size).contiguous()
  return td

def angle_normalize(x):
  return ((x + torch.pi) % (2 * torch.pi)) - torch.pi

def _make_spec(self, td_params):
  # Under the hood, this will populate self.output_spec["observation"]
  self.observation_spec = CompositeSpec(
      th=BoundedTensorSpec(
          low=-torch.pi,
          high=torch.pi,
          shape=(),
          dtype=torch.float32,
      ),
      thdot=BoundedTensorSpec(
          low=-td_params["params", "max_speed"],
          high=td_params["params", "max_speed"],
          shape=(),
          dtype=torch.float32,
      ),
      my_non_tensor_spec = NonTensorSpec(shape=()),
      # we need to add the ``params`` to the observation specs, as we want
      # to pass it at each step during a rollout
      params=make_composite_from_td(td_params["params"]),
      shape=(),
  )
  # since the environment is stateless, we expect the previous output as input.
  # For this, ``EnvBase`` expects some state_spec to be available
  self.state_spec = self.observation_spec.clone()
  # action-spec will be automatically wrapped in input_spec when
  # `self.action_spec = spec` will be called supported
  self.action_spec = BoundedTensorSpec(
      low=-td_params["params", "max_torque"],
      high=td_params["params", "max_torque"],
      shape=(1,),
      dtype=torch.float32,
  )
  self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))


def make_composite_from_td(td):
  # custom function to convert a ``tensordict`` in a similar spec structure
  # of unbounded values.
  composite = CompositeSpec(
      {
          key: make_composite_from_td(tensor)
          if isinstance(tensor, TensorDictBase)
          else UnboundedContinuousTensorSpec(
              dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
          )
          for key, tensor in td.items()
      },
      shape=td.shape,
  )
  return composite

class PendulumEnv(EnvBase):
  metadata = {
      "render_modes": ["human", "rgb_array"],
      "render_fps": 30,
  }
  batch_locked = False

  def __init__(self, td_params=None, seed=None, device="cpu"):
      if td_params is None:
          td_params = self.gen_params(batch_size=None)

      super().__init__(device=device, batch_size=[])
      self._make_spec(td_params)
      if seed is None:
          seed = torch.empty((), dtype=torch.int64).random_().item()
      self.set_seed(seed)

  # Helpers: _make_step and gen_params
  gen_params = staticmethod(gen_params)
  _make_spec = _make_spec

  def _set_seed(self, seed: Optional[int]):
      rng = torch.manual_seed(seed)
      self.rng = rng
  
  def _step(self, tensordict:TensorDict):
      th, thdot = tensordict["th"], tensordict["thdot"] # Angle, Angular Velocity
      tensordict['my_non_tensor_spec']
      g_force = tensordict["params", "g"]
      mass = tensordict["params", "m"]
      length = tensordict["params", "l"]
      dt = tensordict["params", "dt"]
      u = tensordict["action"].squeeze(-1)
      u = u.clamp(-tensordict["params", "max_torque"], tensordict["params", "max_torque"])
      costs = angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)
      
      new_thdot = (
          thdot
          + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) * dt
      )
      new_thdot = new_thdot.clamp(
          -tensordict["params", "max_speed"], tensordict["params", "max_speed"]
      )
      new_th = th + new_thdot * dt
      reward = -costs.view(*tensordict.shape, 1)
      done = torch.zeros_like(reward, dtype=torch.bool)
      out = TensorDict(
          {
              "th": new_th,
              "thdot": new_thdot,
              "params": tensordict["params"],
              "reward": reward,
              "done": done,
          },
          tensordict.shape,
      )
      return out
  
  def _reset(self, tensordict):
      if tensordict is None or tensordict.is_empty():
          # if no ``tensordict`` is passed, we generate a single set of hyperparameters
          # Otherwise, we assume that the input ``tensordict`` contains all the relevant
          # parameters to get started.
          tensordict = self.gen_params(batch_size=self.batch_size)

      high_th = torch.tensor(np.pi, device=self.device)
      high_thdot = torch.tensor(1.0, device=self.device)
      low_th = -high_th
      low_thdot = -high_thdot

      # for non batch-locked environments, the input ``tensordict`` shape dictates the number
      # of simulators run simultaneously. In other contexts, the initial
      # random state's shape will depend upon the environment batch-size instead.
      th = (
          torch.rand(tensordict.shape, generator=self.rng, device=self.device)
          * (high_th - low_th)
          + low_th
      )
      thdot = (
          torch.rand(tensordict.shape, generator=self.rng, device=self.device)
          * (high_thdot - low_thdot)
          + low_thdot
      )
      out = TensorDict(
          {
              "th": th,
              "thdot": thdot,
              "params": tensordict["params"],
              "my_non_tensor_spec": "Hello World"
          },
          batch_size=tensordict.shape,
      )
      return out
env = PendulumEnv()

data = env.reset()
print("Initial state")
print(data)
data['action'] = env.action_spec.rand()
data = env.step(data)
print("After step")
print(data)
data = step_mdp(data, keep_other=True)
print("After step_mdp")
print(data)

Output:

Initial state
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        my_non_tensor_spec: NonTensorData(data=Hello World, batch_size=torch.Size([]), device=None),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                my_non_tensor_spec: NonTensorData(data=None, batch_size=torch.Size([]), device=None)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

After step

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        my_non_tensor_spec: NonTensorData(data=Hello World, batch_size=torch.Size([]), device=None),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                params: TensorDict(
                    fields={
                        dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        l: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        m: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        max_speed: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                        max_torque: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                        my_non_tensor_spec: NonTensorData(data=None, batch_size=torch.Size([]), device=None)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                th: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                thdot: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                my_non_tensor_spec: NonTensorData(data=None, batch_size=torch.Size([]), device=None)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

After step_mdp

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        params: TensorDict(
            fields={
                dt: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                g: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                l: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                m: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                max_speed: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                max_torque: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        th: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        thdot: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Expected behavior

NonTensorData should be copied from next to the new TensorDict after step_mdp().

System info

Describe the characteristic of your environment:

  • Describe how the library was installed: pip
  • Python version 3.8.19
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.24.4 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:47:35) 
[GCC 12.3.0] linux

Reason and Possible fixes

It seems that in _set() NonTensorData is considered a tensor collection and the non_empty_local flag is never set to true. Therefore the data is never copied.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions