diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d12e3d40069..6d5107fcc64 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -944,18 +944,22 @@ def forward(self, observation, action): return self.linear(torch.cat([observation, action], dim=-1)) -class CountingEnvCountPolicy(nn.Module): +class CountingEnvCountPolicy: def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - super().__init__() self.action_spec = action_spec self.action_key = action_key - def __call__(self, t): - action = self.action_spec.zero() + 1 - if isinstance(t, torch.Tensor): - return action - elif isinstance(t, TensorDictBase): - return t.set(self.action_key, action) + def __call__(self, td: TensorDictBase) -> TensorDictBase: + return td.set(self.action_key, self.action_spec.zero() + 1) + + +class CountingEnvCountModule(nn.Module): + def __init__(self, action_spec: TensorSpec): + super().__init__() + self.action_spec = action_spec + + def forward(self, t): + return self.action_spec.zero() + 1 class CountingEnv(EnvBase): diff --git a/test/test_exploration.py b/test/test_exploration.py index f8181406349..c823dbaf4f4 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -10,7 +10,7 @@ from _utils_internal import get_default_devices from mocking_classes import ( ContinuousActionVecMockEnv, - CountingEnvCountPolicy, + CountingEnvCountModule, NestedCountingEnv, ) from scipy.stats import ttest_1samp @@ -217,7 +217,7 @@ def test_nested( net = nn.LazyLinear(d_act).to(device) policy = TensorDictModule( - CountingEnvCountPolicy(action_spec=action_spec, action_key=env.action_key), + CountingEnvCountModule(action_spec=action_spec), in_keys=[("data", "states") if nested_obs_action else "observation"], out_keys=[env.action_key], ) diff --git a/test/test_loggers.py b/test/test_loggers.py index 090725f0fa0..a4937dd0fc3 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -27,85 +27,85 @@ import mlflow +@pytest.fixture +def tb_logger(tmp_path_factory): + tmpdir1 = tmp_path_factory.mktemp("tmpdir1") + exp_name = "ramala" + logger = TensorboardLogger(log_dir=tmpdir1, exp_name=exp_name) + yield logger + del logger + + @pytest.mark.skipif(not _has_tb, reason="TensorBoard not installed") class TestTensorboard: @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) - def test_log_scalar(self, steps): + def test_log_scalar(self, steps, tb_logger): torch.manual_seed(0) - with tempfile.TemporaryDirectory() as log_dir: - exp_name = "ramala" - logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name) - values = torch.rand(3) - for i in range(3): - scalar_name = "foo" - scalar_value = values[i].item() - logger.log_scalar( - value=scalar_value, - name=scalar_name, - step=steps[i] if steps else None, - ) + values = torch.rand(3) + for i in range(3): + scalar_name = "foo" + scalar_value = values[i].item() + tb_logger.log_scalar( + value=scalar_value, + name=scalar_name, + step=steps[i] if steps else None, + ) - sleep(0.01) # wait until events are registered + sleep(0.01) # wait until events are registered - event_acc = EventAccumulator(logger.experiment.get_logdir()) - event_acc.Reload() - assert len(event_acc.Scalars("foo")) == 3, str(event_acc.Scalars("foo")) - for i in range(3): - assert event_acc.Scalars("foo")[i].value == values[i] - if steps: - assert event_acc.Scalars("foo")[i].step == steps[i] + event_acc = EventAccumulator(tb_logger.experiment.get_logdir()) + event_acc.Reload() + assert len(event_acc.Scalars("foo")) == 3, str(event_acc.Scalars("foo")) + for i in range(3): + assert event_acc.Scalars("foo")[i].value == values[i] + if steps: + assert event_acc.Scalars("foo")[i].step == steps[i] @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) - def test_log_video(self, steps): + def test_log_video(self, steps, tb_logger): torch.manual_seed(0) - with tempfile.TemporaryDirectory() as log_dir: - exp_name = "ramala" - logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name) - # creating a sample video (T, C, H, W), where T - number of frames, - # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. - # the first 64 frames are black and the next 64 are white - video = torch.cat( - (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + # creating a sample video (T, C, H, W), where T - number of frames, + # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. + # the first 64 frames are black and the next 64 are white + video = torch.cat( + (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + ) + video = video[None, :] + for i in range(3): + tb_logger.log_video( + name="foo", + video=video, + step=steps[i] if steps else None, + fps=6, # we can't test for the difference between fps, because the result is an encoded_string ) - video = video[None, :] - for i in range(3): - logger.log_video( - name="foo", - video=video, - step=steps[i] if steps else None, - fps=6, # we can't test for the difference between fps, because the result is an encoded_string - ) - sleep(0.01) # wait until events are registered + sleep(0.01) # wait until events are registered - event_acc = EventAccumulator(logger.experiment.get_logdir()) - event_acc.Reload() - assert len(event_acc.Images("foo")) == 3, str(event_acc.Images("foo")) + event_acc = EventAccumulator(tb_logger.experiment.get_logdir()) + event_acc.Reload() + assert len(event_acc.Images("foo")) == 3, str(event_acc.Images("foo")) - # check that we catch the error in case the format of the tensor is wrong - # here the number of color channels is set to 2, which is not correct - video_wrong_format = torch.zeros(64, 2, 32, 32) - video_wrong_format = video_wrong_format[None, :] - with pytest.raises(Exception): - logger.log_video( - name="foo", - video=video_wrong_format, - step=steps[i] if steps else None, - ) + # check that we catch the error in case the format of the tensor is wrong + # here the number of color channels is set to 2, which is not correct + video_wrong_format = torch.zeros(64, 2, 32, 32) + video_wrong_format = video_wrong_format[None, :] + with pytest.raises(Exception): + tb_logger.log_video( + name="foo", + video=video_wrong_format, + step=steps[i] if steps else None, + ) - def test_log_histogram(self): + def test_log_histogram(self, tb_logger): torch.manual_seed(0) - with tempfile.TemporaryDirectory() as log_dir: - exp_name = "ramala" - logger = TensorboardLogger(log_dir=log_dir, exp_name=exp_name) - # test with torch - data = torch.randn(10) - logger.log_histogram("hist", data, step=0, bins=2) - # test with np - data = torch.randn(10).numpy() - logger.log_histogram("hist", data, step=1, bins=2) + # test with torch + data = torch.randn(10) + tb_logger.log_histogram("hist", data, step=0, bins=2) + # test with np + data = torch.randn(10).numpy() + tb_logger.log_histogram("hist", data, step=1, bins=2) class TestCSVLogger: