Skip to content

Commit

Permalink
Efficiency improvement in VecNorm, TensorDict and env tensor casting (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 26, 2022
1 parent 476ca4e commit f0ab441
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 193 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ channels:
dependencies:
- pip
- cmake >= 3.18
- protobuf
- pip:
- hypothesis
- protobuf
- future
- cloudpickle
- gym_retro
Expand Down
1 change: 0 additions & 1 deletion .circleci/unittest/linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ dependencies:
- cmake >= 3.18
- pip:
- hypothesis
- protobuf
- future
- cloudpickle
- pytest
Expand Down
2 changes: 1 addition & 1 deletion .circleci/unittest/linux_stable/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ dependencies:
- pip
- ninja
- cmake >= 3.18
- protobuf
- pip:
- hypothesis
- protobuf
- future
- cloudpickle
- gym_retro
Expand Down
9 changes: 9 additions & 0 deletions test/smoke_test_deps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import tempfile

from torch.utils.tensorboard import SummaryWriter
from torchrl.envs import DMControlEnv, GymEnv


Expand All @@ -20,3 +23,9 @@ def test_gym():
def test_gym_pixels():
env = GymEnv("ALE/Pong-v5", from_pixels=True)
env.reset()


def test_tb():
with tempfile.TemporaryDirectory() as directory:
writer = SummaryWriter(log_dir=directory)
writer.add_scalar("a", 1, 1)
4 changes: 2 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,14 +600,14 @@ def test_masking(self, td_name):
@pytest.mark.skipif(
torch.cuda.device_count() == 0, reason="No cuda device detected"
)
@pytest.mark.parametrize("device", [0, "cuda:0", "cuda", torch.device("cuda:0")])
@pytest.mark.parametrize("device", [0, "cuda:0", torch.device("cuda:0")])
def test_pin_memory(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)
if td_name != "saved_td":
td.pin_memory()
td_device = td.to(device)
_device = torch.device("cuda:0")
_device = torch.device(device)
assert td_device.device == _device
assert td_device.clone().device == _device
assert td_device is not td
Expand Down
12 changes: 10 additions & 2 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@

import pytest
import torch
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter

try:
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter

_has_tb = True
except ImportError:
_has_tb = False

from torchrl.data import (
TensorDict,
TensorDictPrioritizedReplayBuffer,
Expand Down Expand Up @@ -214,6 +221,7 @@ def test_subsampler():


@pytest.mark.skipif(not _has_gym, reason="No gym library")
@pytest.mark.skipif(not _has_tb, reason="No tensorboard library")
def test_recorder():
with tempfile.TemporaryDirectory() as folder:
writer = SummaryWriter(log_dir=folder)
Expand Down
8 changes: 6 additions & 2 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._REG[self.name][2] = count + 1

@staticmethod
def print():
def print(prefix=None):
keys = list(timeit._REG)
keys.sort()
for name in keys:
print(
strings = []
if prefix:
strings.append(prefix)
strings.append(
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
)
print(" -- ".join(strings))
35 changes: 20 additions & 15 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from textwrap import indent
from typing import (
Expand Down Expand Up @@ -210,24 +209,28 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
"""
if not isinstance(val, torch.Tensor):
try:
val = torch.tensor(val, dtype=self.dtype)
except ValueError:
val = torch.tensor(deepcopy(val), dtype=self.dtype)
if isinstance(val, np.ndarray) and not all(
stride > 0 for stride in val.strides
):
val = val.copy()
val = torch.as_tensor(val, dtype=self.dtype, device=self.device)
self.assert_is_in(val)
return val

def to_numpy(self, val: torch.Tensor) -> np.ndarray:
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
"""Returns the np.ndarray correspondent of an input tensor.
Args:
val (torch.Tensor): tensor to be transformed_in to numpy
safe (bool): boolean value indicating whether a check should be
performed on the value against the domain of the spec.
Returns:
a np.ndarray
"""
self.assert_is_in(val)
if safe:
self.assert_is_in(val)
return val.detach().cpu().numpy()

def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -490,9 +493,8 @@ def encode(
space: Optional[DiscreteBox] = None,
) -> torch.Tensor:
if not isinstance(val, torch.Tensor):
val = torch.tensor(val)
val = torch.as_tensor(val, dtype=self.dtype, device=self.device)

val = torch.tensor(val, dtype=torch.long)
if space is None:
space = self.space

Expand All @@ -504,10 +506,11 @@ def encode(
val = torch.nn.functional.one_hot(val, space.n).to(torch.long)
return val

def to_numpy(self, val: torch.Tensor) -> np.ndarray:
if not isinstance(val, torch.Tensor):
raise NotImplementedError
self.assert_is_in(val)
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
if safe:
if not isinstance(val, torch.Tensor):
raise NotImplementedError
self.assert_is_in(val)
val = val.argmax(-1).cpu().numpy()
if self.use_register:
inv_reg = self.space.register.inverse()
Expand Down Expand Up @@ -794,7 +797,7 @@ def rand(self, shape: torch.Size = torch.Size([])) -> torch.Tensor:

def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if not isinstance(val, torch.Tensor):
val = torch.tensor(val)
val = torch.tensor(val, device=self.device)

x = []
for v, space in zip(val.unbind(-1), self.space):
Expand All @@ -809,7 +812,9 @@ def _split(self, val: torch.Tensor) -> torch.Tensor:
vals = val.split([space.n for space in self.space], dim=-1)
return vals

def to_numpy(self, val: torch.Tensor) -> np.ndarray:
def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray:
if safe:
self.assert_is_in(val)
vals = self._split(val)
out = torch.stack([val.argmax(-1) for val in vals], -1).numpy()
return out
Expand Down
Loading

0 comments on commit f0ab441

Please sign in to comment.