Skip to content

Commit

Permalink
Fix CI (pytorch#1368)
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
  • Loading branch information
matteobettini and vmoens authored Jul 7, 2023
1 parent eb05c7b commit 98d2ca2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 72 deletions.
20 changes: 12 additions & 8 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,18 +944,22 @@ def forward(self, observation, action):
return self.linear(torch.cat([observation, action], dim=-1))


class CountingEnvCountPolicy(nn.Module):
class CountingEnvCountPolicy:
def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
super().__init__()
self.action_spec = action_spec
self.action_key = action_key

def __call__(self, t):
action = self.action_spec.zero() + 1
if isinstance(t, torch.Tensor):
return action
elif isinstance(t, TensorDictBase):
return t.set(self.action_key, action)
def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.set(self.action_key, self.action_spec.zero() + 1)


class CountingEnvCountModule(nn.Module):
def __init__(self, action_spec: TensorSpec):
super().__init__()
self.action_spec = action_spec

def forward(self, t):
return self.action_spec.zero() + 1


class CountingEnv(EnvBase):
Expand Down
4 changes: 2 additions & 2 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from _utils_internal import get_default_devices
from mocking_classes import (
ContinuousActionVecMockEnv,
CountingEnvCountPolicy,
CountingEnvCountModule,
NestedCountingEnv,
)
from scipy.stats import ttest_1samp
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_nested(

net = nn.LazyLinear(d_act).to(device)
policy = TensorDictModule(
CountingEnvCountPolicy(action_spec=action_spec, action_key=env.action_key),
CountingEnvCountModule(action_spec=action_spec),
in_keys=[("data", "states") if nested_obs_action else "observation"],
out_keys=[env.action_key],
)
Expand Down
124 changes: 62 additions & 62 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,85 +27,85 @@
import mlflow


@pytest.fixture
def tb_logger(tmp_path_factory):
tmpdir1 = tmp_path_factory.mktemp("tmpdir1")
exp_name = "ramala"
logger = TensorboardLogger(log_dir=tmpdir1, exp_name=exp_name)
yield logger
del logger


@pytest.mark.skipif(not _has_tb, reason="TensorBoard not installed")
class TestTensorboard:
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_scalar(self, steps):
def test_log_scalar(self, steps, tb_logger):
torch.manual_seed(0)
with tempfile.TemporaryDirectory() as log_dir:
exp_name = "ramala"
logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name)

values = torch.rand(3)
for i in range(3):
scalar_name = "foo"
scalar_value = values[i].item()
logger.log_scalar(
value=scalar_value,
name=scalar_name,
step=steps[i] if steps else None,
)
values = torch.rand(3)
for i in range(3):
scalar_name = "foo"
scalar_value = values[i].item()
tb_logger.log_scalar(
value=scalar_value,
name=scalar_name,
step=steps[i] if steps else None,
)

sleep(0.01) # wait until events are registered
sleep(0.01) # wait until events are registered

event_acc = EventAccumulator(logger.experiment.get_logdir())
event_acc.Reload()
assert len(event_acc.Scalars("foo")) == 3, str(event_acc.Scalars("foo"))
for i in range(3):
assert event_acc.Scalars("foo")[i].value == values[i]
if steps:
assert event_acc.Scalars("foo")[i].step == steps[i]
event_acc = EventAccumulator(tb_logger.experiment.get_logdir())
event_acc.Reload()
assert len(event_acc.Scalars("foo")) == 3, str(event_acc.Scalars("foo"))
for i in range(3):
assert event_acc.Scalars("foo")[i].value == values[i]
if steps:
assert event_acc.Scalars("foo")[i].step == steps[i]

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
def test_log_video(self, steps):
def test_log_video(self, steps, tb_logger):
torch.manual_seed(0)
with tempfile.TemporaryDirectory() as log_dir:
exp_name = "ramala"
logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name)

# creating a sample video (T, C, H, W), where T - number of frames,
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
# the first 64 frames are black and the next 64 are white
video = torch.cat(
(torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255))
# creating a sample video (T, C, H, W), where T - number of frames,
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
# the first 64 frames are black and the next 64 are white
video = torch.cat(
(torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255))
)
video = video[None, :]
for i in range(3):
tb_logger.log_video(
name="foo",
video=video,
step=steps[i] if steps else None,
fps=6, # we can't test for the difference between fps, because the result is an encoded_string
)
video = video[None, :]
for i in range(3):
logger.log_video(
name="foo",
video=video,
step=steps[i] if steps else None,
fps=6, # we can't test for the difference between fps, because the result is an encoded_string
)

sleep(0.01) # wait until events are registered
sleep(0.01) # wait until events are registered

event_acc = EventAccumulator(logger.experiment.get_logdir())
event_acc.Reload()
assert len(event_acc.Images("foo")) == 3, str(event_acc.Images("foo"))
event_acc = EventAccumulator(tb_logger.experiment.get_logdir())
event_acc.Reload()
assert len(event_acc.Images("foo")) == 3, str(event_acc.Images("foo"))

# check that we catch the error in case the format of the tensor is wrong
# here the number of color channels is set to 2, which is not correct
video_wrong_format = torch.zeros(64, 2, 32, 32)
video_wrong_format = video_wrong_format[None, :]
with pytest.raises(Exception):
logger.log_video(
name="foo",
video=video_wrong_format,
step=steps[i] if steps else None,
)
# check that we catch the error in case the format of the tensor is wrong
# here the number of color channels is set to 2, which is not correct
video_wrong_format = torch.zeros(64, 2, 32, 32)
video_wrong_format = video_wrong_format[None, :]
with pytest.raises(Exception):
tb_logger.log_video(
name="foo",
video=video_wrong_format,
step=steps[i] if steps else None,
)

def test_log_histogram(self):
def test_log_histogram(self, tb_logger):
torch.manual_seed(0)
with tempfile.TemporaryDirectory() as log_dir:
exp_name = "ramala"
logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name)
# test with torch
data = torch.randn(10)
logger.log_histogram("hist", data, step=0, bins=2)
# test with np
data = torch.randn(10).numpy()
logger.log_histogram("hist", data, step=1, bins=2)
# test with torch
data = torch.randn(10)
tb_logger.log_histogram("hist", data, step=0, bins=2)
# test with np
data = torch.randn(10).numpy()
tb_logger.log_histogram("hist", data, step=1, bins=2)


class TestCSVLogger:
Expand Down

0 comments on commit 98d2ca2

Please sign in to comment.