multi_agent_independent_learning.py (forked from ray-project/ray)

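"""Independent multi-agent learning with PPO on the PettingZoo Waterworld env.

Each Waterworld agent is trained with its own PPO policy (no parameter
sharing): the policy mapping function maps every agent ID to a policy of the
same name.

Example invocation (CPU-only smoke test that stops after one training
iteration):
    python multi_agent_independent_learning.py --num-gpus=0 --as-test
"""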
import argparse
from ray import air, tune
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v4

# Based on code from github.com/parametersharingmadrl/parametersharingmadrl

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-gpus",
    type=int,
    default=1,
    help="Number of GPUs to use for training.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: only one training "
    "iteration will be run.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    def env_creator(args):
        return PettingZooEnv(waterworld_v4.env())

    env = env_creator({})
    register_env("waterworld", env_creator)

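    # One PPO policy per Waterworld agent (independent learning, no parameter
    # sharing): every agent ID is mapped to a policy of the same name.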
    config = (
        PPOConfig()
        .environment("waterworld")
        .resources(num_gpus=args.num_gpus)
        .rollouts(num_rollout_workers=2)
        .multi_agent(
            policies=env.get_agent_ids(),
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
    )

    if args.as_test:
        # Only a compilation test of running waterworld / independent learning.
        stop = {"training_iteration": 1}
    else:
        stop = {"episodes_total": 60000}

    tune.Tuner(
        "PPO",
        run_config=air.RunConfig(
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=10,
            ),
        ),
        param_space=config,
    ).fit()