-
Notifications
You must be signed in to change notification settings - Fork 1
/
ddpg.py
321 lines (276 loc) · 12.2 KB
/
ddpg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
from replay_mem import ReplayMem
from utils import discount_return, sample_rewards
import rllab.misc.logger as logger
import pyprind
import mxnet as mx
import numpy as np
class DDPG(object):
    """DDPG-style actor-critic trainer built on MXNet executors.

    The class wires a policy network and a Q-function network together,
    keeps a separately-bound target copy of each, and trains them from
    transitions stored in a replay memory.

    Collaborator requirements visible from this class:

    * ``env``: ``reset()``, ``step(act)`` returning a 4-tuple
      ``(next_obs, reward, end, info)``, and ``observation_space.flat_dim``
      / ``action_space.flat_dim``.
    * ``policy``: ``get_loss_symbols()`` (dict with key ``"act"``),
      ``define_exe(...)``, ``arg_dict``, ``get_actions(obss)`` and
      ``update_params(grad)``.
    * ``qfunc``: ``get_loss_symbols()`` (dict with keys ``"qval"`` and
      ``"yval"``), ``define_loss(...)``, ``define_exe(...)``, ``arg_dict``,
      ``loss``, ``exe`` and ``update_params(obss, acts, ys)``.
    * ``strategy``: ``get_action(obs, policy)`` and ``reset()``.
    """

    def __init__(
            self,
            env,
            policy,
            qfunc,
            strategy,
            ctx=mx.gpu(0),
            batch_size=32,
            n_epochs=1000,
            epoch_length=1000,
            memory_size=1000000,
            memory_start_size=1000,
            discount=0.99,
            max_path_length=1000,
            eval_samples=10000,
            qfunc_updater="adam",
            qfunc_lr=1e-4,
            policy_updater="adam",
            policy_lr=1e-4,
            soft_target_tau=1e-3,
            n_updates_per_sample=1,
            include_horizon_terminal=False,
            seed=12345):
        """Store the hyper-parameters, seed the RNGs and bind the networks.

        Args:
            env: Environment to interact with.
            policy: Policy (actor) network wrapper.
            qfunc: Q-function (critic) network wrapper.
            strategy: Exploration strategy used to pick training actions.
            ctx: MXNet context all executors are bound to.
            batch_size: Number of transitions per update; also fixes the
                leading dimension of every bound input.
            n_epochs: Number of training epochs.
            epoch_length: Environment steps per epoch.
            memory_size: Capacity passed to the replay memory.
            memory_start_size: Replay size required before updates start.
            discount: Discount factor used in the Q-learning target.
            max_path_length: Step count after which an episode is cut off.
            eval_samples: Sample budget passed to ``sample_rewards``.
            qfunc_updater: Optimizer name for the Q-function.
            qfunc_lr: Learning rate for the Q-function optimizer.
            policy_updater: Optimizer name for the policy.
            policy_lr: Learning rate for the policy optimizer.
            soft_target_tau: Interpolation weight of the target updates.
            n_updates_per_sample: Updates performed per environment step.
            include_horizon_terminal: Whether a transition that ends only
                because ``max_path_length`` was reached is stored.
            seed: Seed applied to both MXNet and NumPy.
        """
        mx.random.seed(seed)
        np.random.seed(seed)

        self.env = env
        self.ctx = ctx
        self.policy = policy
        self.qfunc = qfunc
        self.strategy = strategy
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.memory_size = memory_size
        self.memory_start_size = memory_start_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.eval_samples = eval_samples
        self.qfunc_updater = qfunc_updater
        self.qfunc_lr = qfunc_lr
        self.policy_updater = policy_updater
        self.policy_lr = policy_lr
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal = include_horizon_terminal

        self.init_net()

        # Per-epoch logging buffers; emptied at the end of evaluate().
        self.qfunc_loss_averages = []
        self.policy_loss_averages = []
        self.q_averages = []
        self.y_averages = []
        self.strategy_path_returns = []

    def init_net(self):
        """Define the losses and bind the four executors used in training.

        Creates, in order: the Q-function executor, its target copy, the
        policy executor, the combined policy-loss executor (Q-value symbol
        bound on the Q-function's own argument and gradient arrays), and
        the target policy.
        """
        obs_dim = self.env.observation_space.flat_dim
        act_dim = self.env.action_space.flat_dim

        # qfunc init
        qfunc_init = mx.initializer.Normal()
        loss_symbols = self.qfunc.get_loss_symbols()
        qval_sym = loss_symbols["qval"]
        yval_sym = loss_symbols["yval"]
        # mean squared error between the predicted and the target Q-values
        qfunc_loss = 1.0 / self.batch_size * mx.symbol.sum(
            mx.symbol.square(qval_sym - yval_sym))
        qfunc_updater = mx.optimizer.get_updater(
            mx.optimizer.create(self.qfunc_updater,
                                learning_rate=self.qfunc_lr))
        self.qfunc_input_shapes = {
            "obs": (self.batch_size, obs_dim),
            "act": (self.batch_size, act_dim),
            "yval": (self.batch_size, 1)}
        self.qfunc.define_loss(qfunc_loss)
        self.qfunc.define_exe(
            ctx=self.ctx,
            init=qfunc_init,
            updater=qfunc_updater,
            input_shapes=self.qfunc_input_shapes)

        # qfunc_target init
        qfunc_target_shapes = {
            "obs": (self.batch_size, obs_dim),
            "act": (self.batch_size, act_dim)
        }
        self.qfunc_target = qval_sym.simple_bind(ctx=self.ctx,
                                                 **qfunc_target_shapes)
        # parameters are not shared but initialized the same
        for name, arr in self.qfunc_target.arg_dict.items():
            if name not in self.qfunc_input_shapes:
                self.qfunc.arg_dict[name].copyto(arr)

        # policy init
        policy_init = mx.initializer.Normal()
        loss_symbols = self.policy.get_loss_symbols()
        act_sym = loss_symbols["act"]
        # note the negative one here: the loss maximizes the average return
        policy_loss = -1.0 / self.batch_size * mx.symbol.sum(qval_sym)
        policy_loss = mx.symbol.MakeLoss(policy_loss, name="policy_loss")
        policy_updater = mx.optimizer.get_updater(
            mx.optimizer.create(self.policy_updater,
                                learning_rate=self.policy_lr))
        self.policy_input_shapes = {
            "obs": (self.batch_size, obs_dim)}
        self.policy.define_exe(
            ctx=self.ctx,
            init=policy_init,
            updater=policy_updater,
            input_shapes=self.policy_input_shapes)

        # policy network and q-value network are combined to backpropagate
        # gradients from the policy loss
        # since the loss is different, yval is not needed
        args = {}
        for name, arr in self.qfunc.arg_dict.items():
            if name != "yval":
                args[name] = arr
        args_grad = {}
        policy_grad_dict = dict(zip(self.qfunc.loss.list_arguments(),
                                    self.qfunc.exe.grad_arrays))
        for name, arr in policy_grad_dict.items():
            if name != "yval":
                args_grad[name] = arr

        self.policy_executor = policy_loss.bind(
            ctx=self.ctx,
            args=args,
            args_grad=args_grad,
            grad_req="write")
        self.policy_executor_arg_dict = self.policy_executor.arg_dict
        self.policy_executor_grad_dict = dict(zip(
            policy_loss.list_arguments(),
            self.policy_executor.grad_arrays))

        # policy_target init
        # target policy only needs to produce actions, not loss
        # parameters are not shared but initialized the same
        self.policy_target = act_sym.simple_bind(ctx=self.ctx,
                                                 **self.policy_input_shapes)
        for name, arr in self.policy_target.arg_dict.items():
            if name not in self.policy_input_shapes:
                self.policy.arg_dict[name].copyto(arr)

    def train(self):
        """Run the full training loop.

        Each environment step samples an action from the exploration
        strategy, stores the transition in a fresh ``ReplayMem`` and, once
        the memory holds at least ``memory_start_size`` samples, performs
        ``n_updates_per_sample`` updates. At the end of every epoch the
        statistics are evaluated and dumped, provided the memory is warm.
        """
        memory = ReplayMem(
            obs_dim=self.env.observation_space.flat_dim,
            act_dim=self.env.action_space.flat_dim,
            memory_size=self.memory_size)

        itr = 0
        path_length = 0
        path_return = 0
        end = False
        obs = self.env.reset()

        # range (not xrange) keeps the loop runnable on Python 2 and 3.
        for epoch in range(self.n_epochs):
            logger.push_prefix("epoch #%d | " % epoch)
            logger.log("Training started")
            for _ in pyprind.prog_bar(range(self.epoch_length)):
                # run the policy
                if end:
                    # reset the environment and strategy when an episode
                    # ends
                    obs = self.env.reset()
                    self.strategy.reset()
                    # self.policy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                # note action is sampled from the policy not the target
                # policy
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _info = self.env.step(act)

                path_length += 1
                path_return += rwd

                if not end and path_length >= self.max_path_length:
                    # The episode is cut by the horizon rather than by the
                    # environment; store that transition only on request.
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)

                obs = nxt

                if memory.size >= self.memory_start_size:
                    for _ in range(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)

                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
                logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

        # self.env.terminate()
        # self.policy.terminate()

    def do_update(self, itr, batch):
        """Perform one Q-function update, one policy update and a soft
        update of both target networks.

        Args:
            itr: Global step counter supplied by ``train``; not used here.
            batch: 5-tuple ``(obss, acts, rwds, ends, nxts)`` as returned
                by the replay memory's ``get_batch``.
        """
        obss, acts, rwds, ends, nxts = batch

        # Actions of the target policy on the next observations.
        self.policy_target.arg_dict["obs"][:] = nxts
        self.policy_target.forward(is_train=False)
        next_acts = self.policy_target.outputs[0].asnumpy()

        policy_acts = self.policy.get_actions(obss)

        # Target Q-values for (next observation, target action).
        self.qfunc_target.arg_dict["obs"][:] = nxts
        self.qfunc_target.arg_dict["act"][:] = next_acts
        self.qfunc_target.forward(is_train=False)
        next_qvals = self.qfunc_target.outputs[0].asnumpy()

        # executor accepts 2D tensors
        rwds = rwds.reshape((-1, 1))
        ends = ends.reshape((-1, 1))
        # Bootstrapped target; the bootstrap term is dropped where the
        # transition is flagged as ended.
        ys = rwds + (1.0 - ends) * self.discount * next_qvals

        # since policy_executor shares the grad arrays with qfunc
        # the update order could not be changed
        self.qfunc.update_params(obss, acts, ys)

        # in update values all computed
        # no need to recompute qfunc_loss and qvals
        qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
        qvals = self.qfunc.exe.outputs[1].asnumpy()

        self.policy_executor.arg_dict["obs"][:] = obss
        self.policy_executor.arg_dict["act"][:] = policy_acts
        self.policy_executor.forward(is_train=True)
        policy_loss = self.policy_executor.outputs[0].asnumpy()
        self.policy_executor.backward()
        # The gradient of the policy loss w.r.t. the action input drives
        # the policy update.
        self.policy.update_params(self.policy_executor_grad_dict["act"])

        # update target networks
        tau = self.soft_target_tau
        for name, arr in self.policy_target.arg_dict.items():
            if name not in self.policy_input_shapes:
                arr[:] = (1.0 - tau) * arr[:] + \
                    tau * self.policy.arg_dict[name][:]
        for name, arr in self.qfunc_target.arg_dict.items():
            if name not in self.qfunc_input_shapes:
                arr[:] = (1.0 - tau) * arr[:] + \
                    tau * self.qfunc.arg_dict[name][:]

        self.qfunc_loss_averages.append(qfunc_loss)
        self.policy_loss_averages.append(policy_loss)
        self.q_averages.append(qvals)
        self.y_averages.append(ys)

    def evaluate(self, epoch, memory):
        """Record the statistics gathered during the epoch and reset them.

        Policy roll-outs via ``sample_rewards`` are collected only in the
        final epoch; every epoch records the loss, Q-value and target
        statistics accumulated by ``do_update``.

        Args:
            epoch: Index of the epoch that just finished.
            memory: Replay memory; accepted for interface compatibility and
                not used here.
        """
        is_last_epoch = epoch == self.n_epochs - 1

        if is_last_epoch:
            logger.log("Collecting samples for evaluation")
            rewards = sample_rewards(env=self.env,
                                     policy=self.policy,
                                     eval_samples=self.eval_samples,
                                     max_path_length=self.max_path_length)
            average_discounted_return = np.mean(
                [discount_return(reward, self.discount)
                 for reward in rewards])
            returns = [sum(reward) for reward in rewards]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_qfunc_loss = np.mean(self.qfunc_loss_averages)
        average_policy_loss = np.mean(self.policy_loss_averages)

        logger.record_tabular('Epoch', epoch)
        if is_last_epoch:
            logger.record_tabular('AverageReturn',
                                  np.mean(returns))
            logger.record_tabular('StdReturn',
                                  np.std(returns))
            logger.record_tabular('MaxReturn',
                                  np.max(returns))
            logger.record_tabular('MinReturn',
                                  np.min(returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  average_discounted_return)
        if self.strategy_path_returns:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.strategy_path_returns))
            logger.record_tabular('StdEsReturn',
                                  np.std(self.strategy_path_returns))
            logger.record_tabular('MaxEsReturn',
                                  np.max(self.strategy_path_returns))
            logger.record_tabular('MinEsReturn',
                                  np.min(self.strategy_path_returns))
        logger.record_tabular('AverageQLoss', average_qfunc_loss)
        logger.record_tabular('AveragePolicyLoss', average_policy_loss)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))

        # Start the next epoch with empty statistics buffers.
        self.qfunc_loss_averages = []
        self.policy_loss_averages = []
        self.q_averages = []
        self.y_averages = []
        self.strategy_path_returns = []