From 19dfefc84ec9e8998b7ef6e97578fe186372d48f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 11 Dec 2024 09:15:52 -0800 Subject: [PATCH] [BugFix] Fix init_random_frames=0 ghstack-source-id: 38a544ea15631f9affb4c385c09e7c4df94af55d Pull Request resolved: https://github.com/pytorch/rl/pull/2645 --- test/test_collector.py | 2 +- torchrl/collectors/collectors.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 38191a46eaa..5c91cb83633 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1345,7 +1345,7 @@ def make_env(): functools.partial(MultiSyncDataCollector, cat_results="stack"), ], ) -@pytest.mark.parametrize("init_random_frames", [50]) # 1226: faster execution +@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution @pytest.mark.parametrize( "explicit_spec,split_trajs", [[True, True], [False, False]] ) # 1226: faster execution diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 16eb5904b84..14fbc7d5f22 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -712,10 +712,10 @@ def __init__( ) self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 + int(init_random_frames) if init_random_frames not in (None, -1) else 0 ) if ( - init_random_frames is not None + init_random_frames not in (-1, None, 0) and init_random_frames % frames_per_batch != 0 and RL_WARNINGS ):