evaluate.py
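
"""Evaluate a trained MADDPG red team against a scripted instinct blue team.

Loads the saved hyperparameters and environment config from
models/<model_name>/, plays 10,000 headless games to tally wins, then records
10 rendered games to models/<model_name>/eval_videos/.
"""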
import json
import os
import sys

import envs.battle_env as battle_env
import instinct.team as instinct
import maddpg.team as maddpg


def merge_dicts(dict1, dict2):
    """Copy the entries of dict1 into dict2 and return the merged dict."""
    dict2.update(dict1)
    return dict2


def main():
    model_name = input('Enter model name: ')
    FOLDER = f'models/{model_name}'
    if not os.path.exists(FOLDER):
        print('Model does not exist')
        sys.exit()

    # Load the saved hyperparameters and environment config
    with open(f'{FOLDER}/params.json', 'r') as f:
        params = json.load(f)
    with open(f'{FOLDER}/cf.json', 'r') as f:
        env_config = json.load(f)

    # Evaluate headless; rendering is switched on for the recorded games below
    env_config['show'] = False
    env = battle_env.parallel_env(**env_config)

    red_agent_list = env.possible_red
    blue_agent_list = env.possible_blue
    obs_len = env.observation_space(red_agent_list[0]).shape[0]
    # The centralized critic sees all agents' observations concatenated
    critic_dims = obs_len * env.n_agents

    # Red team: the trained MADDPG agents under evaluation
    red_team = maddpg.Team(red_agent_list, obs_len, env.n_actions, critic_dims,
                           params['fc1_dims'], params['fc2_dims'],
                           params['buffer_size'], params['batch_size'],
                           params['gamma'], params['lr'], FOLDER)
    red_team.load_models()

    # Blue team: the scripted instinct opponents
    blue_team = instinct.Team(blue_agent_list, red_agent_list, env)

    wins = {
        "red": 0,
        "blue": 0,
        "tie": 0
    }

    # Play 10,000 headless games and tally the winner of each
    for _ in range(10000):
        observations = env.reset()
        red_obs = {agent: observations[agent] for agent in red_agent_list}
        blue_obs = {agent: observations[agent] for agent in blue_agent_list}

        while not env.env_done:
            actions = merge_dicts(red_team.choose_actions(red_obs),
                                  blue_team.choose_actions(blue_obs))
            observations_, _, _, _ = env.step(actions)
            red_obs = {agent: observations_[agent] for agent in red_agent_list}
            blue_obs = {agent: observations_[agent] for agent in blue_agent_list}

        wins[env.winner] += 1
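
    # Report the aggregate record from the headless games
    print(f'Record over 10000 games: {wins}')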

    # Record 10 rendered games for qualitative inspection
    env.show = True
    env.start_recording(path=f'{FOLDER}/eval_videos/')
    for _ in range(10):
        observations = env.reset()
        red_obs = {agent: observations[agent] for agent in red_agent_list}
        blue_obs = {agent: observations[agent] for agent in blue_agent_list}

        while not env.env_done:
            actions = merge_dicts(red_team.choose_actions(red_obs),
                                  blue_team.choose_actions(blue_obs))
            observations_, _, _, _ = env.step(actions)
            red_obs = {agent: observations_[agent] for agent in red_agent_list}
            blue_obs = {agent: observations_[agent] for agent in blue_agent_list}
    env.stop_recording()
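    print(f'Saved evaluation videos to {FOLDER}/eval_videos/')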


if __name__ == '__main__':
    main()