
Commit

Merge branch 'develop' into feat/integ-tests
sash-a authored Jul 23, 2024
2 parents 6ef1f5e + 7267aa3 commit 92855b2
Showing 5 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_linters.yaml
@@ -6,7 +6,7 @@ jobs:
   tests-and-linters:
     name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
     runs-on: "${{ matrix.os }}"
-    timeout-minutes: 20
+    timeout-minutes: 10

     strategy:
       matrix:
2 changes: 1 addition & 1 deletion Dockerfile
@@ -36,7 +36,7 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
 # Need to use specific cuda versions for jax
 ARG USE_CUDA=true
 RUN if [ "$USE_CUDA" = true ] ; \
-    then pip install "jax[cuda12]==0.4.26" ; \
+    then pip install "jax[cuda12]==0.4.30" ; \
     fi

 # Copy all code
2 changes: 1 addition & 1 deletion docs/DETAILED_INSTALL.md
@@ -22,7 +22,7 @@ pip install -e .

 4. Install jax on your accelerator. The example below is for an NVIDIA GPU, please see the [official install guide](https://github.com/google/jax#installation) for other accelerators
 ```bash
-pip install "jax[cuda12]==0.4.26"
+pip install "jax[cuda12]==0.4.30"
 ```

 5. Run a system!
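As an optional follow-up to the install step above (not part of this commit), a quick check that the newly pinned jax build can actually see the accelerator might look like the following sketch.

```python
# Hypothetical post-install check; assumes "jax[cuda12]==0.4.30" was installed as above.
import jax

# Lists the devices jax can see; on a correctly configured NVIDIA machine this
# should include one or more CUDA/GPU devices rather than only the CPU.
print(jax.devices())
```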
22 changes: 22 additions & 0 deletions mava/wrappers/jumanji.py
@@ -143,6 +143,17 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         discount = jnp.repeat(timestep.discount, self.num_agents)
         return timestep.replace(observation=observation, reward=reward, discount=discount)

+    def observation_spec(
+        self,
+    ) -> specs.Spec[Union[Observation, ObservationGlobalState]]:
+        # need to cast the agents view and global state to floats as we do in modify timestep
+        inner_spec = super().observation_spec()
+        spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float))
+        if self.add_global_state:
+            spec = inner_spec.replace(global_state=inner_spec.global_state.replace(dtype=float))
+
+        return spec
+

 class LbfWrapper(JumanjiMarlWrapper):
     """Multi-agent wrapper for the Level-Based Foraging environment.
@@ -192,6 +203,17 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         # Aggregate the list of individual rewards and use a single team_reward.
         return self.aggregate_rewards(timestep, modified_observation)

+    def observation_spec(
+        self,
+    ) -> specs.Spec[Union[Observation, ObservationGlobalState]]:
+        # need to cast the agents view and global state to floats as we do in modify timestep
+        inner_spec = super().observation_spec()
+        spec = inner_spec.replace(agents_view=inner_spec.agents_view.replace(dtype=float))
+        if self.add_global_state:
+            spec = inner_spec.replace(global_state=inner_spec.global_state.replace(dtype=float))
+
+        return spec
+

 class ConnectorWrapper(JumanjiMarlWrapper):
     """Multi-agent wrapper for the MA Connector environment.
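Both observation_spec overrides above exist because modify_timestep casts the agents' view (and the global state, when add_global_state is set) to floats, so the spec the wrapper advertises has to declare the same dtype. The snippet below is a standalone illustration of that dtype shift, not Mava code; the array name and shape are made up.

```python
# Illustrative only: the names and shapes here are invented for the example.
import jax.numpy as jnp

# A Jumanji-style environment can emit integer observations.
raw_agents_view = jnp.zeros((4, 10), dtype=jnp.int32)

# A modify_timestep-style cast to float: downstream code now sees float arrays,
# so a spec still declaring int32 would no longer match what is returned.
agents_view = raw_agents_view.astype(float)

assert raw_agents_view.dtype == jnp.int32
assert agents_view.dtype == jnp.float32  # default float dtype when x64 is disabled
```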
4 changes: 2 additions & 2 deletions requirements/requirements.txt
@@ -6,8 +6,8 @@ flax
 gigastep @ git+https://github.com/mlech26l/gigastep
 hydra-core==1.3.2
 id-marl-eval @ git+https://github.com/instadeepai/marl-eval
-jax==0.4.26 # Fixed due to Mabrax dependency
-jaxlib==0.4.26
+jax==0.4.30
+jaxlib==0.4.30
 jaxmarl
 jumanji @ git+https://github.com/sash-a/jumanji # Includes a few extra MARL envs
 matrax @ git+https://github.com/instadeepai/matrax
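Since the requirements now pin jax and jaxlib to 0.4.30 (and drop the old Mabrax-related comment), a local environment can be checked against the new pins with a short snippet like the one below; this is only a suggested sanity check, not part of the diff.

```python
# Suggested check that the installed versions match the new 0.4.30 pins.
import jax
import jaxlib

print("jax:", jax.__version__)        # expected: 0.4.30
print("jaxlib:", jaxlib.__version__)  # expected: 0.4.30
```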
