From 74f9eefa26c4377a5be21b5eea35f6c1aa00f62a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Mar 2024 08:01:42 +0000 Subject: [PATCH 1/3] init --- .github/workflows/test-linux.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 84770c8ddf0..d2e13eddd63 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -51,7 +51,7 @@ jobs: tests-gpu: strategy: matrix: - python_version: ["3.8"] + python_version: ["3.10"] cuda_arch_version: ["12.1"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main @@ -119,7 +119,7 @@ jobs: tests-optdeps: strategy: matrix: - python_version: ["3.9"] # "3.8", "3.9", "3.10", "3.11" + python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11" cuda_arch_version: ["12.1"] # "11.6", "11.7" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main @@ -156,7 +156,7 @@ jobs: tests-stable-gpu: strategy: matrix: - python_version: ["3.8"] # "3.8", "3.9", "3.10", "3.11" + python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11" cuda_arch_version: ["11.8"] # "11.6", "11.7" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main From 7ceb91d76dd461f85f312bc048d3c531e9ae9cde Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 18 Mar 2024 11:43:21 +0100 Subject: [PATCH 2/3] add getitem method --- torchrl/collectors/distributed/ray.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index faf4d4a6cce..6ede4ef5a88 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -345,6 +345,9 @@ def check_consistency_with_num_collectors(param, param_name, num_collectors): remote_configs, "remote_config", num_collectors ) + def __class_getitem__(key): + return + def check_list_length_consistency(*lists): """Checks that all input lists have the same length. From 2147514e62e0970df3674ad354ce6a4d4dffe1aa Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 18 Mar 2024 12:17:54 +0100 Subject: [PATCH 3/3] add getitem method to parent --- torchrl/collectors/collectors.py | 3 +++ torchrl/collectors/distributed/ray.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index f4b92c87d9d..5cbfe5c62ae 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -235,6 +235,9 @@ def __repr__(self) -> str: string = f"{self.__class__.__name__}()" return string + def __class_getitem__(self, index): + raise NotImplementedError + @accept_remote_rref_udf_invocation class SyncDataCollector(DataCollectorBase): diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 6ede4ef5a88..faf4d4a6cce 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -345,9 +345,6 @@ def check_consistency_with_num_collectors(param, param_name, num_collectors): remote_configs, "remote_config", num_collectors ) - def __class_getitem__(key): - return - def check_list_length_consistency(*lists): """Checks that all input lists have the same length.