forked from uoe-agents/epymarl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
116 lines (91 loc) · 3.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import collections
import collections.abc
import os
import random
import sys
from copy import deepcopy
from os.path import dirname, abspath

import numpy as np
import torch as th
import yaml
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver, MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds

from run import run
from utils.logging import get_logger
# Module-level experiment setup: configures sacred, creates the shared logger
# and the Experiment object, and computes where results are written.
SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
logger = get_logger()
ex = Experiment("pymarl")
# Make the experiment log through the same logger this script uses.
ex.logger = logger
# Post-process sacred's captured output with its backspace/linefeed filter
# (NOTE(review): exact filtering semantics are defined by sacred — see its docs).
ex.captured_out_filter = apply_backspaces_and_linefeeds
# "results" directory located in the parent of the directory containing this file.
results_path = os.path.join(dirname(dirname(abspath(__file__))), "results")
# Alternative hard-coded location, kept for reference:
# results_path = "/home/ubuntu/data"
@ex.main
def my_main(_run, _config, _log):
    """Experiment entry point: seed NumPy and PyTorch, then call ``run``.

    Args:
        _run: Passed through unchanged to ``run``.
        _config: Experiment configuration; it is copied with ``config_copy``
            so the seed can be written into ``env_args`` without touching
            the object that was passed in.
        _log: Passed through unchanged to ``run``.
    """
    config = config_copy(_config)

    # Use one seed value for NumPy, PyTorch and the environment arguments.
    seed = config["seed"]
    np.random.seed(seed)
    th.manual_seed(seed)
    config['env_args']['seed'] = seed

    # Hand control to the framework.
    run(_run, config, _log)
def _get_config(params, arg_name, subfolder):
config_name = None
for _i, _v in enumerate(params):
if _v.split("=")[0] == arg_name:
config_name = _v.split("=")[1]
del params[_i]
break
if config_name is not None:
with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f:
try:
config_dict = yaml.load(f)
except yaml.YAMLError as exc:
assert False, "{}.yaml error: {}".format(config_name, exc)
return config_dict
def recursive_dict_update(d, u):
    """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

    For every key in ``u``: a mapping value is merged into the mapping
    already stored under that key in ``d``; any other value overwrites
    the entry in ``d``. When ``d`` has no mapping under the key (missing,
    or a non-mapping such as ``None``), the override is merged into a
    fresh dict, so nested mappings of ``u`` are never aliased into ``d``.

    Args:
        d: Dict to update; mutated in place.
        u: Mapping holding the overriding values.

    Returns:
        ``d`` itself, after the merge.
    """
    for k, v in u.items():
        # collections.Mapping was removed in Python 3.10; the ABC lives in
        # collections.abc.
        if isinstance(v, collections.abc.Mapping):
            current = d.get(k)
            if not isinstance(current, collections.abc.Mapping):
                # Nothing mergeable stored yet: start from an empty dict.
                current = {}
            d[k] = recursive_dict_update(current, v)
        else:
            d[k] = v
    return d
def config_copy(config):
    """Return a copy of ``config`` built from plain dicts and lists.

    Dict instances (including subclasses) are rebuilt as plain ``dict``
    objects and list instances as plain ``list`` objects, recursing into
    their values; every other value is copied with ``deepcopy``.

    Args:
        config: Arbitrarily nested configuration value.

    Returns:
        A copy that shares no dict or list containers with ``config``.
    """
    if isinstance(config, dict):
        duplicate = {}
        for key, value in config.items():
            duplicate[key] = config_copy(value)
        return duplicate

    if isinstance(config, list):
        return list(map(config_copy, config))

    # Leaf value (or any other container type): defer to deepcopy.
    return deepcopy(config)
if __name__ == '__main__':
    # Script entry point: build the configuration from default.yaml plus the
    # env/alg YAML files named on the command line, register it with the
    # sacred experiment, attach a file observer and run the experiment.
    params = deepcopy(sys.argv)
    th.set_num_threads(1)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            # yaml.load() without an explicit Loader raises TypeError on
            # PyYAML >= 6 and can construct arbitrary objects on older
            # versions; safe_load handles plain YAML config files.
            config_dict = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type is unchanged.
            raise AssertionError("default.yaml error: {}".format(exc)) from exc

    # Load algorithm and env base configs; each call also removes its option
    # from params so it is not passed on to sacred's command line.
    env_config = _get_config(params, "--env-config", "envs")
    alg_config = _get_config(params, "--config", "algs")

    # Merge order matters: env settings override defaults, algorithm settings
    # override both.
    # config_dict = {**config_dict, **env_config, **alg_config}
    config_dict = recursive_dict_update(config_dict, env_config)
    config_dict = recursive_dict_update(config_dict, alg_config)

    # The results sub-directory is named after env_args.map_name when present,
    # otherwise after env_args.key.
    try:
        map_name = config_dict["env_args"]["map_name"]
    except KeyError:
        map_name = config_dict["env_args"]["key"]

    # now add all the config to sacred
    ex.add_config(config_dict)

    # A command-line override of map_name/key takes precedence for the
    # directory name; the last matching parameter wins.
    for param in params:
        if param.startswith(("env_args.map_name", "env_args.key")):
            map_name = param.split("=")[1]

    # Save to disk by default for sacred
    logger.info("Saving to FileStorageObserver in results/sacred.")
    file_obs_path = os.path.join(results_path, f"sacred/{config_dict['name']}/{map_name}")

    # ex.observers.append(MongoObserver(db_name="marlbench")) #url='172.31.5.187:27017'))
    ex.observers.append(FileStorageObserver.create(file_obs_path))
    # ex.observers.append(MongoObserver())

    ex.run_commandline(params)