-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts.ml
163 lines (129 loc) · 4.83 KB
/
mcts.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
open Ucb1
type 'a gen = Random.State.t -> 'a
(** ['a gen] is a random generator of ['a] values: a function that draws a
    value using the given pseudo-random state. *)

module type S = sig
  (** Problem description consumed by the [MCTS] functor: states split into
      terminal and nonterminal ones, the actions available in nonterminal
      states, a transition function, rewards at terminal states, and the
      parameters of the random rollouts used to evaluate newly reached
      states. *)

  (** States at which the search stops and a reward is collected. *)
  type terminal

  (** States from which the search continues by choosing an action. *)
  type nonterminal

  (** Every state is distinguished as either terminal or nonterminal. *)
  type state = Terminal of terminal | Nonterminal of nonterminal

  (** Actions that can be taken in a nonterminal state. *)
  type action

  (** [actions s] lists the actions available in [s]. The array should be
      non-empty: with the [`Uniform] kernel, an empty array makes the
      rollout's call to [Random.State.int] raise [Invalid_argument]. *)
  val actions : nonterminal -> action array

  (** [next s a] is the state reached by taking action [a] in [s]. The search
      tree evaluates it lazily and caches the result per branch, so
      transitions are treated as deterministic. *)
  val next : nonterminal -> action -> state

  (** [reward t] is the reward collected upon reaching terminal state [t]. *)
  val reward : terminal -> float

  (** Length limit of the Monte-Carlo rollouts. [`Unbounded] rolls out until a
      terminal state is reached. [`Bounded d] takes at most [d + 1] random
      steps; a rollout cut short this way yields a reward of [0.0]. *)
  val exploration_depth : [ `Unbounded | `Bounded of int ]

  (** Transition kernel used during rollouts. Setting it to [`Uniform] picks
      one of [actions s] uniformly at random and applies [next];
      [`Kernel k] samples the successor state directly with [k s]. *)
  val exploration_kernel : [ `Uniform | `Kernel of nonterminal -> state gen ]

  (** Pretty-printer for actions. *)
  val pp_action : Format.formatter -> action -> unit

  (** Pretty-printer for terminal states. *)
  val pp_terminal : Format.formatter -> terminal -> unit

  (** Pretty-printer for nonterminal states. *)
  val pp_nonterminal : Format.formatter -> nonterminal -> unit
end

(** Monte-Carlo Tree Search yields a policy, i.e. a way to decide which
    action to take at each nonterminal state. *)
module type Policy = sig
  (** States at which a decision is requested. *)
  type t

  (** Decisions returned by the policy. *)
  type action

  (** [policy ~playouts s] is a random generator of the action to take in
      [s]. [playouts] is the search budget: in the [MCTS] implementation it
      is the number of simulations run before deciding. *)
  val policy : playouts:int -> t -> action gen
end
module MCTS (X : S) :
  Policy with type t = X.nonterminal and type action = X.action = struct
  (* Tree search over the problem [X]: every nonterminal node owns a UCB1
     bandit whose arms index the node's actions; unexplored states are
     valued by random rollouts driven by [X.exploration_kernel]. *)

  type t = X.nonterminal

  type action = X.action

  (* Bandits whose arms are indices into a node's [actions] array. *)
  module Bandit = Ucb1.Make (struct
    type t = int

    let compare (a : int) (b : int) =
      if a > b then 1 else if a < b then -1 else 0

    let pp = Format.pp_print_int
  end)

  (* The search tree. A branch starts as a suspended call to [X.next]; once
     forced it is either a terminal leaf (reward precomputed) or an
     [Unexplored] state, which is replaced by a [Nonterminal] node the first
     time a playout reaches it. *)
  type tree =
    | Terminal of terminal_node
    | Nonterminal of nonterminal_node
    | Unexplored of X.nonterminal

  and bandit = ready_to_move Bandit.t

  and nonterminal_node =
    { state : X.nonterminal;
      mutable bandit : bandit;  (* bandit over this node's arms *)
      actions : action array;  (* arm [i] plays [actions.(i)] *)
      branches : tree lazy_t array  (* branch [i] follows arm [i] *)
    }

  and terminal_node = { final : X.terminal; reward : float }

  (* Default rollout kernel: apply an action drawn uniformly among the
     state's actions. Raises [Invalid_argument] (from [Random.State.int])
     when the state has no actions. *)
  let uniform_exploration : t -> X.state gen =
   fun state ->
    let choices = X.actions state in
    let count = Array.length choices in
    fun rng -> X.next state choices.(Random.State.int rng count)

  let exploration =
    match X.exploration_kernel with
    | `Kernel kernel -> kernel
    | `Uniform -> uniform_exploration

  (* Roll out from [s] with the exploration kernel until a terminal state is
     reached, and return its reward. *)
  let rec rollout s rng =
    match exploration s rng with
    | X.Terminal leaf -> X.reward leaf
    | X.Nonterminal s' -> rollout s' rng

  (* Like [rollout], but only while [fuel >= 0], decrementing it at every
     step: starting from [fuel = d] at most [d + 1] steps are taken, and a
     rollout that runs out of fuel scores [0.0]. *)
  let rec bounded_rollout fuel s rng =
    if fuel >= 0 then begin
      match exploration s rng with
      | X.Terminal leaf -> X.reward leaf
      | X.Nonterminal s' -> bounded_rollout (fuel - 1) s' rng
    end
    else 0.0

  let exploration_loop =
    match X.exploration_depth with
    | `Bounded fuel -> bounded_rollout fuel
    | `Unbounded -> rollout

  (* Back-propagation: settle every pending bandit recorded on [path] with
     [reward] and store the result back into its node. [path] lists the
     deepest node first. *)
  let assign_reward path reward =
    List.iter
      (fun (owner, pending) -> owner.bandit <- Bandit.set_reward pending reward)
      path

  (* One simulation: descend from [node] following the bandits' choices,
     recording each (node, pending bandit) pair on [path]. A terminal leaf
     propagates its reward; an unexplored state is expanded into a node,
     valued with a rollout, and that value is propagated. *)
  let rec playout (node : nonterminal_node) path rng =
    let (arm, pending) = Bandit.next_action node.bandit in
    let path = (node, pending) :: path in
    match Lazy.force node.branches.(arm) with
    | Nonterminal child -> playout child path rng
    | Terminal leaf -> assign_reward path leaf.reward
    | Unexplored fresh ->
        let child = expand_node fresh in
        node.branches.(arm) <- Lazy.from_val (Nonterminal child) ;
        let estimate = exploration_loop fresh rng in
        assign_reward path estimate

  (* Build a node for [state] with one bandit arm and one lazy branch per
     available action. *)
  and expand_node state =
    let actions = X.actions state in
    let bandit = Bandit.create (Array.mapi (fun i _ -> i) actions) in
    let branch act =
      lazy
        (match X.next state act with
         | X.Terminal final -> Terminal { final; reward = X.reward final }
         | X.Nonterminal next_state -> Unexplored next_state)
    in
    { state; bandit; actions; branches = Array.map branch actions }

  (* Run [playouts] simulations from a freshly expanded root, then let the
     root bandit pick the action to play.
     NOTE(review): the final pick is another [Bandit.next_action] call, i.e.
     the same selection rule used during the search rather than, e.g., the
     most-visited arm; confirm this is intended. *)
  let policy ~playouts initial_state rng_state =
    let root = expand_node initial_state in
    for _round = 1 to playouts do
      playout root [] rng_state
    done ;
    let (chosen, _) = Bandit.next_action root.bandit in
    root.actions.(chosen)
end