forked from facebookresearch/BenchMARL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_benchmark.py
90 lines (76 loc) · 2.89 KB
/
plot_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
from pathlib import Path
from typing import List
from benchmarl.eval_results import load_and_merge_json_dicts, Plotting
from matplotlib import pyplot as plt
def run_benchmark() -> List[str]:
    """Run the example VMAS navigation benchmark and collect its JSON outputs.

    The benchmark pairs MAPPO and QMIX with MLP model and critic-model
    configs on the VMAS navigation task, over seeds 0 and 1. Every experiment
    saves under this script's directory, uses no loggers, and is capped at
    100 iterations. Experiments are executed one after another.

    Returns:
        For each experiment, in execution order, the string path
        ``<experiment folder>/<experiment name>.json`` of its output file.
    """
    # Importing inside the function means these BenchMARL modules are only
    # loaded when the benchmark is actually (re)run.
    from benchmarl.algorithms import MappoConfig, QmixConfig
    from benchmarl.benchmark import Benchmark
    from benchmarl.environments import VmasTask
    from benchmarl.experiment import ExperimentConfig
    from benchmarl.models.mlp import MlpConfig

    # Shared experiment settings, starting from the YAML defaults.
    exp_cfg = ExperimentConfig.get_from_yaml()
    here = os.path.dirname(os.path.realpath(__file__))
    exp_cfg.save_folder = Path(here)
    exp_cfg.loggers = []
    exp_cfg.max_n_iters = 100

    # What to benchmark: one task, two algorithms, MLP networks.
    vmas_tasks = [VmasTask.NAVIGATION.get_from_yaml()]
    algorithms = [
        MappoConfig.get_from_yaml(),
        QmixConfig.get_from_yaml(),
    ]
    model_cfg = MlpConfig.get_from_yaml()
    critic_cfg = MlpConfig.get_from_yaml()

    benchmark = Benchmark(
        algorithm_configs=algorithms,
        tasks=vmas_tasks,
        seeds={0, 1},
        experiment_config=exp_cfg,
        model_config=model_cfg,
        critic_model_config=critic_cfg,
    )

    json_files: List[str] = []
    for exp in benchmark.get_experiments():
        # Record this experiment's output-file path, then run it.
        json_path = Path(exp.folder_name) / (exp.name + ".json")
        json_files.append(str(json_path))
        exp.run()
    return json_files
def main() -> None:
    """Run the example benchmark, then plot its results.

    Calls ``run_benchmark()`` to run every experiment and obtain the JSON
    output files, merges and processes them, and draws the BenchMARL summary
    figures for the ``vmas`` environment and its ``navigation`` task. The
    figures are then shown with matplotlib.
    """
    # Run the benchmark; this (re)generates the per-experiment JSON files.
    # To re-plot an earlier run without retraining, replace this call with a
    # list of the existing JSON file paths.
    experiments_json_files = run_benchmark()

    # Load and process experiment outputs
    raw_dict = load_and_merge_json_dicts(experiments_json_files)
    processed_data = Plotting.process_data(raw_dict)
    (
        environment_comparison_matrix,
        sample_efficiency_matrix,
    ) = Plotting.create_matrices(processed_data, env_name="vmas")

    # Plotting
    Plotting.performance_profile_figure(
        environment_comparison_matrix=environment_comparison_matrix
    )
    Plotting.aggregate_scores(
        environment_comparison_matrix=environment_comparison_matrix
    )
    # NOTE(review): "environemnt" / "effeciency" appear to be the spellings
    # used by the BenchMARL Plotting API itself — verify against the library
    # before renaming, or this call will break.
    Plotting.environemnt_sample_efficiency_curves(
        sample_effeciency_matrix=sample_efficiency_matrix
    )
    Plotting.task_sample_efficiency_curves(
        processed_data=processed_data, env="vmas", task="navigation"
    )
    # Compare QMIX against MAPPO.
    Plotting.probability_of_improvement(
        environment_comparison_matrix,
        algorithms_to_compare=[["qmix", "mappo"]],
    )
    plt.show()


if __name__ == "__main__":
    main()