diff --git a/.github/workflows/tests_linters.yaml b/.github/workflows/tests_linters.yaml index 5eeefb71f..12356fe95 100644 --- a/.github/workflows/tests_linters.yaml +++ b/.github/workflows/tests_linters.yaml @@ -6,7 +6,7 @@ jobs: tests-and-linters: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" runs-on: "${{ matrix.os }}" - timeout-minutes: 20 + timeout-minutes: 10 strategy: matrix: diff --git a/Dockerfile b/Dockerfile index 408834ec4..baa8c1e4c 100755 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,7 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \ # Need to use specific cuda versions for jax ARG USE_CUDA=true RUN if [ "$USE_CUDA" = true ] ; \ - then pip install "jax[cuda12]==0.4.26" ; \ + then pip install "jax[cuda12]==0.4.30" ; \ fi # Copy all code diff --git a/docs/DETAILED_INSTALL.md b/docs/DETAILED_INSTALL.md index 82fc33608..28547c8aa 100644 --- a/docs/DETAILED_INSTALL.md +++ b/docs/DETAILED_INSTALL.md @@ -22,7 +22,7 @@ pip install -e . 4. Install jax on your accelerator. The example below is for an NVIDIA GPU, please the [official install guide](https://github.com/google/jax#installation) for other accelerators ```bash -pip install "jax[cuda12]==0.4.26" +pip install "jax[cuda12]==0.4.30" ``` 5. Run a system! diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 120804b51..1393566c1 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -143,6 +143,17 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: discount = jnp.repeat(timestep.discount, self.num_agents) return timestep.replace(observation=observation, reward=reward, discount=discount) + def observation_spec( + self, + ) -> specs.Spec[Union[Observation, ObservationGlobalState]]: + # need to cast the agents view and global state to floats as we do in modify timestep + inner_spec = super().observation_spec() + spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float)) + if self.add_global_state: + spec = inner_spec.replace(global_state=inner_spec.global_state.replace(dtype=float)) + + return spec + class LbfWrapper(JumanjiMarlWrapper): """Multi-agent wrapper for the Level-Based Foraging environment. @@ -192,6 +203,17 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: # Aggregate the list of individual rewards and use a single team_reward. return self.aggregate_rewards(timestep, modified_observation) + def observation_spec( + self, + ) -> specs.Spec[Union[Observation, ObservationGlobalState]]: + # need to cast the agents view and global state to floats as we do in modify timestep + inner_spec = super().observation_spec() + spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float)) + if self.add_global_state: + spec = inner_spec.replace(global_state=inner_spec.global_state.replace(dtype=float)) + + return spec + class ConnectorWrapper(JumanjiMarlWrapper): """Multi-agent wrapper for the MA Connector environment. diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 4c6f4665a..8dbcaeca9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,8 +6,8 @@ flax gigastep @ git+https://github.com/mlech26l/gigastep hydra-core==1.3.2 id-marl-eval @ git+https://github.com/instadeepai/marl-eval -jax==0.4.26 # Fixed due to Mabrax dependency -jaxlib==0.4.26 +jax==0.4.30 +jaxlib==0.4.30 jaxmarl jumanji @ git+https://github.com/sash-a/jumanji # Includes a few extra MARL envs matrax @ git+https://github.com/instadeepai/matrax