complex_struct_space.py (forked from ray-project/ray)
"""Example of using variable-length Repeated / struct observation spaces.
This example shows:
- using a custom environment with Repeated / struct observations
- using a custom model to view the batched list observations
For PyTorch / TF eager mode, use the `--framework=[torch|tf2]` flag.
"""
import argparse
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG
from ray.rllib.examples.models.simple_rpg_model import (
    CustomTorchRPGModel,
    CustomTFRPGModel,
)
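
# For context, a minimal sketch of how a Repeated / struct observation space
# is composed. This is an illustration only: the actual space used here is
# defined in `ray.rllib.examples.env.simple_rpg`, and the keys, shapes, and
# bounds below are assumed, not copied from that env. Note that newer RLlib
# versions import spaces from `gymnasium` rather than `gym`.
from gym.spaces import Box, Dict, Discrete
from ray.rllib.utils.spaces.repeated import Repeated

# Each "player" is a Dict of fixed-shape attributes plus a variable-length
# list of item IDs; the observation is a variable-length list of players.
_example_player_space = Dict(
    {
        "location": Box(-100.0, 100.0, shape=(2,)),
        "items": Repeated(Discrete(5), max_len=7),
    }
)
_example_observation_space = Repeated(_example_player_space, max_len=4)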

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="tf2",
    help="The DL framework specifier.",
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    # Register the custom model under the name referenced by the PPO config
    # below, picking the implementation that matches the chosen framework.
    if args.framework == "torch":
        ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
    else:
        ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)
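
    # How the custom model sees these observations: in its forward() call,
    # the batched list obs arrive as a RepeatedValues struct (see
    # `ray.rllib.models.repeated_values`). A hedged sketch, assuming the
    # player Dict contains a "location" key as in SimpleRPG:
    #
    #     def forward(self, input_dict, state, seq_lens):
    #         players = input_dict["obs"]        # RepeatedValues over players
    #         locs = players.values["location"]  # padded tensor [B, max_len, 2]
    #         n = players.lengths                # valid players per batch row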

    config = (
        PPOConfig()
        .environment(SimpleRPG)
        .framework(args.framework)
        .rollouts(rollout_fragment_length=1, num_rollout_workers=0)
        .training(train_batch_size=2, model={"custom_model": "my_model"})
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    # Stop right away: this example only demonstrates how the structured
    # observations flow through the model, not actual learning.
    stop = {
        "timesteps_total": 1,
    }

    tuner = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    # Run the single, 1-timestep trial; without this call the Tuner is only
    # constructed, never executed.
    tuner.fit()
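
# Example invocation (requires an RLlib installation that still ships the
# `ray.rllib.examples.env.simple_rpg` modules imported above):
#     python complex_struct_space.py --framework=torch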