"""Example of a custom training workflow. Run this for a demo.
This example shows:
- using Tune trainable functions to implement custom training workflows
You can visualize experiment results in ~/ray_results using TensorBoard.
"""
import argparse
import os

import ray
from ray import train, tune
from ray.rllib.algorithms.ppo import PPO, PPOConfig

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)


def my_train_fn(config):
    iterations = config.pop("train-iterations", 10)
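    # The remaining keys are regular algorithm settings; turn them into a
    # PPOConfig and fix the environment used in both phases.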
    config = PPOConfig().update_from_dict(config).environment("CartPole-v1")

    # Train for n iterations with high LR.
    config.lr = 0.01
    agent1 = config.build()
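    # Report results back to Tune after every iteration; the extra `phase` key
    # makes the two phases easy to tell apart in the logged metrics.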
    for _ in range(iterations):
        result = agent1.train()
        result["phase"] = 1
        train.report(result)
        phase1_time = result["timesteps_total"]
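    # Checkpoint the phase-1 weights so the low-LR phase can pick up from them.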
    state = agent1.save()
    agent1.stop()

    # Train for n iterations with low LR.
    config.lr = 0.0001
    agent2 = config.build()
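    # Restore the phase-1 checkpoint so phase 2 continues from those weights.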
    agent2.restore(state)
    for _ in range(iterations):
        result = agent2.train()
        result["phase"] = 2
        result["timesteps_total"] += phase1_time  # keep time moving forward
        train.report(result)
    agent2.stop()


if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    config = {
        # Special flag signalling `my_train_fn` how many iters to do.
        "train-iterations": 2,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "framework": args.framework,
    }
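    # Ask PPO how many CPUs/GPUs this config would need, then reserve the same
    # resources for the custom function trainable run by Tune.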
    resources = PPO.default_resource_request(config)
    tuner = tune.Tuner(
        tune.with_resources(my_train_fn, resources=resources), param_space=config
    )
    tuner.fit()