Skip to content

Commit

Permalink
test: enhance the code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
OmaymaMahjoub committed Dec 16, 2022
1 parent 8b7edef commit 102fe04
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
8 changes: 7 additions & 1 deletion tests/systems/systems_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def network_factory(*args: Any, **kwargs: Any) -> Any:
executor_parameter_update_period=1,
multi_process=True,
run_evaluator=True,
normalise_observations=False,
normalise_target_values=False,
num_executors=1,
max_queue_size=500,
use_next_extras=False,
Expand All @@ -313,7 +315,11 @@ def network_factory(*args: Any, **kwargs: Any) -> Any:
evaluation_interval={"executor_steps": 1000},
evaluation_duration={"evaluator_episodes": 5},
checkpoint_best_perf=True,
termination_condition={"executor_steps": 5000},
# Flag to activate the calculation of the absolute metric
absolute_metric=True,
# How many episodes the evaluator will run for
absolute_metric_duration=32,
termination_condition={"executor_steps": 10000},
wait=True,
)
return test_system
38 changes: 37 additions & 1 deletion tests/test_utils/checkpointing_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@

import pytest

from mava.utils.checkpointing_utils import update_best_checkpoint, update_to_best_net
from mava.utils.checkpointing_utils import (
update_best_checkpoint,
update_evaluator_net,
update_to_best_net,
)


def fake_networks(k: int = 0) -> Tuple:
Expand Down Expand Up @@ -286,3 +290,35 @@ def test_update_to_best_net(mock_parameter_server: MockParameterServer) -> None:
del mock_parameter_server.store.parameters["best_checkpoint"]
with pytest.raises(Exception):
update_to_best_net(mock_parameter_server, "reward") # type:ignore


def test_update_evaluator_net(mock_executor: MockExecutor) -> None:
"""Test update_evaluator_net function"""
update_evaluator_net(mock_executor, "win_rate") # type:ignore

# Check that the networks got updated by the one belong to the win rate
for agent_net_key in mock_executor.store.networks.keys():
assert (
mock_executor.store.best_checkpoint["win_rate"][
f"policy_network-{agent_net_key}"
]
== mock_executor.store.networks[agent_net_key].policy_params
)
assert (
mock_executor.store.best_checkpoint["win_rate"][
f"critic_network-{agent_net_key}"
]
== mock_executor.store.networks[agent_net_key].critic_params
)
assert (
mock_executor.store.best_checkpoint["win_rate"][
f"policy_opt_state-{agent_net_key}"
]
== mock_executor.store.policy_opt_states[agent_net_key]
)
assert (
mock_executor.store.best_checkpoint["win_rate"][
f"critic_opt_state-{agent_net_key}"
]
== mock_executor.store.critic_opt_states[agent_net_key]
)

0 comments on commit 102fe04

Please sign in to comment.