from replay_mem import ReplayMem
from utils import discount_return, sample_rewards
import rllab.misc.logger as logger
import pyprind
import mxnet as mx
import numpy as np


class DDPG(object):
    """Deep Deterministic Policy Gradient (DDPG) implemented with MXNet
    symbols and executors."""

    def __init__(
            self,
            env,
            policy,
            qfunc,
            strategy,
            ctx=mx.gpu(0),
            batch_size=32,
            n_epochs=1000,
            epoch_length=1000,
            memory_size=1000000,
            memory_start_size=1000,
            discount=0.99,
            max_path_length=1000,
            eval_samples=10000,
            qfunc_updater="adam",
            qfunc_lr=1e-4,
            policy_updater="adam",
            policy_lr=1e-4,
            soft_target_tau=1e-3,
            n_updates_per_sample=1,
            include_horizon_terminal=False,
            seed=12345):

        mx.random.seed(seed)
        np.random.seed(seed)

        self.env = env
        self.ctx = ctx
        self.policy = policy
        self.qfunc = qfunc
        self.strategy = strategy
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.memory_size = memory_size
        self.memory_start_size = memory_start_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.eval_samples = eval_samples
        self.qfunc_updater = qfunc_updater
        self.qfunc_lr = qfunc_lr
        self.policy_updater = policy_updater
        self.policy_lr = policy_lr
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal = include_horizon_terminal

        self.init_net()

        # logging
        self.qfunc_loss_averages = []
        self.policy_loss_averages = []
        self.q_averages = []
        self.y_averages = []
        self.strategy_path_returns = []

    def init_net(self):

        # qfunc init
        qfunc_init = mx.initializer.Normal()
        loss_symbols = self.qfunc.get_loss_symbols()
        qval_sym = loss_symbols["qval"]
        yval_sym = loss_symbols["yval"]
        # define the critic loss: mean squared TD error over the batch
        loss = 1.0 / self.batch_size * mx.symbol.sum(
            mx.symbol.square(qval_sym - yval_sym))
        qfunc_loss = loss
        qfunc_updater = mx.optimizer.get_updater(
            mx.optimizer.create(self.qfunc_updater,
                                learning_rate=self.qfunc_lr))
        self.qfunc_input_shapes = {
            "obs": (self.batch_size, self.env.observation_space.flat_dim),
            "act": (self.batch_size, self.env.action_space.flat_dim),
            "yval": (self.batch_size, 1)}
        self.qfunc.define_loss(qfunc_loss)
        self.qfunc.define_exe(
            ctx=self.ctx,
            init=qfunc_init,
            updater=qfunc_updater,
            input_shapes=self.qfunc_input_shapes)

        # qfunc_target init
        qfunc_target_shapes = {
            "obs": (self.batch_size, self.env.observation_space.flat_dim),
            "act": (self.batch_size, self.env.action_space.flat_dim)}
        self.qfunc_target = qval_sym.simple_bind(ctx=self.ctx,
                                                 **qfunc_target_shapes)
        # parameters are not shared but initialized the same
        for name, arr in self.qfunc_target.arg_dict.items():
            if name not in self.qfunc_input_shapes:
                self.qfunc.arg_dict[name].copyto(arr)

        # policy init
        policy_init = mx.initializer.Normal()
        loss_symbols = self.policy.get_loss_symbols()
        act_sym = loss_symbols["act"]
        policy_qval = qval_sym
        # note the negative sign: minimizing this loss maximizes the
        # average Q-value (i.e., the estimated return)
        loss = -1.0 / self.batch_size * mx.symbol.sum(policy_qval)
        policy_loss = loss
        policy_loss = mx.symbol.MakeLoss(policy_loss, name="policy_loss")
        policy_updater = mx.optimizer.get_updater(
            mx.optimizer.create(self.policy_updater,
                                learning_rate=self.policy_lr))
        self.policy_input_shapes = {
            "obs": (self.batch_size, self.env.observation_space.flat_dim)}
        self.policy.define_exe(
            ctx=self.ctx,
            init=policy_init,
            updater=policy_updater,
            input_shapes=self.policy_input_shapes)

        # the policy network and the q-value network are combined to
        # backpropagate gradients from the policy loss;
        # since the loss is different, yval is not needed
        args = {}
        for name, arr in self.qfunc.arg_dict.items():
            if name != "yval":
                args[name] = arr
        args_grad = {}
        policy_grad_dict = dict(zip(self.qfunc.loss.list_arguments(),
                                    self.qfunc.exe.grad_arrays))
        for name, arr in policy_grad_dict.items():
            if name != "yval":
                args_grad[name] = arr
        self.policy_executor = policy_loss.bind(
            ctx=self.ctx,
            args=args,
            args_grad=args_grad,
            grad_req="write")
        self.policy_executor_arg_dict = self.policy_executor.arg_dict
        self.policy_executor_grad_dict = dict(zip(
            policy_loss.list_arguments(),
            self.policy_executor.grad_arrays))

        # policy_target init
        # the target policy only needs to produce actions, not a loss
        # parameters are not shared but initialized the same
        self.policy_target = act_sym.simple_bind(ctx=self.ctx,
                                                 **self.policy_input_shapes)
        for name, arr in self.policy_target.arg_dict.items():
            if name not in self.policy_input_shapes:
                self.policy.arg_dict[name].copyto(arr)

    def train(self):

        memory = ReplayMem(
            obs_dim=self.env.observation_space.flat_dim,
            act_dim=self.env.action_space.flat_dim,
            memory_size=self.memory_size)

        itr = 0
        path_length = 0
        path_return = 0
        end = False
        obs = self.env.reset()

        for epoch in xrange(self.n_epochs):
            logger.push_prefix("epoch #%d | " % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # run the policy
                if end:
                    # reset the environment and strategy when an episode ends
                    obs = self.env.reset()
                    self.strategy.reset()
                    # self.policy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                # note the action is sampled from the policy, not the target policy
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _ = self.env.step(act)

                path_length += 1
                path_return += rwd

                if not end and path_length >= self.max_path_length:
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)

                obs = nxt

                if memory.size >= self.memory_start_size:
                    for update_time in xrange(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)

                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

        # self.env.terminate()
        # self.policy.terminate()

    def do_update(self, itr, batch):

        obss, acts, rwds, ends, nxts = batch

        self.policy_target.arg_dict["obs"][:] = nxts
        self.policy_target.forward(is_train=False)
        next_acts = self.policy_target.outputs[0].asnumpy()

        policy_acts = self.policy.get_actions(obss)

        self.qfunc_target.arg_dict["obs"][:] = nxts
        self.qfunc_target.arg_dict["act"][:] = next_acts
        self.qfunc_target.forward(is_train=False)
        next_qvals = self.qfunc_target.outputs[0].asnumpy()

        # the executor accepts 2D tensors
        rwds = rwds.reshape((-1, 1))
        ends = ends.reshape((-1, 1))

        ys = rwds + (1.0 - ends) * self.discount * next_qvals

        # since policy_executor shares the grad arrays with qfunc,
        # the update order must not be changed
        self.qfunc.update_params(obss, acts, ys)
        # the update above already computed these values, so there is
        # no need to recompute qfunc_loss and qvals
        qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
        qvals = self.qfunc.exe.outputs[1].asnumpy()

        self.policy_executor.arg_dict["obs"][:] = obss
        self.policy_executor.arg_dict["act"][:] = policy_acts
        self.policy_executor.forward(is_train=True)
        policy_loss = self.policy_executor.outputs[0].asnumpy()
        self.policy_executor.backward()
        self.policy.update_params(self.policy_executor_grad_dict["act"])

        # update the target networks with a soft (Polyak) update
        for name, arr in self.policy_target.arg_dict.items():
            if name not in self.policy_input_shapes:
                arr[:] = (1.0 - self.soft_target_tau) * arr[:] + \
                         self.soft_target_tau * self.policy.arg_dict[name][:]
        for name, arr in self.qfunc_target.arg_dict.items():
            if name not in self.qfunc_input_shapes:
                arr[:] = (1.0 - self.soft_target_tau) * arr[:] + \
                         self.soft_target_tau * self.qfunc.arg_dict[name][:]

        self.qfunc_loss_averages.append(qfunc_loss)
        self.policy_loss_averages.append(policy_loss)
        self.q_averages.append(qvals)
        self.y_averages.append(ys)

    def evaluate(self, epoch, memory):

        if epoch == self.n_epochs - 1:
            logger.log("Collecting samples for evaluation")
            rewards = sample_rewards(env=self.env,
                                     policy=self.policy,
                                     eval_samples=self.eval_samples,
                                     max_path_length=self.max_path_length)
            average_discounted_return = np.mean(
                [discount_return(reward, self.discount) for reward in rewards])
            returns = [sum(reward) for reward in rewards]

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_qfunc_loss = np.mean(self.qfunc_loss_averages)
        average_policy_loss = np.mean(self.policy_loss_averages)

        logger.record_tabular('Epoch', epoch)
        if epoch == self.n_epochs - 1:
            logger.record_tabular('AverageReturn', np.mean(returns))
            logger.record_tabular('StdReturn', np.std(returns))
            logger.record_tabular('MaxReturn', np.max(returns))
            logger.record_tabular('MinReturn', np.min(returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  average_discounted_return)
        if len(self.strategy_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.strategy_path_returns))
            logger.record_tabular('StdEsReturn',
                                  np.std(self.strategy_path_returns))
            logger.record_tabular('MaxEsReturn',
                                  np.max(self.strategy_path_returns))
            logger.record_tabular('MinEsReturn',
                                  np.min(self.strategy_path_returns))
        logger.record_tabular('AverageQLoss', average_qfunc_loss)
        logger.record_tabular('AveragePolicyLoss', average_policy_loss)
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))

        self.qfunc_loss_averages = []
        self.policy_loss_averages = []
        self.q_averages = []
        self.y_averages = []
        self.strategy_path_returns = []
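

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original example): a minimal NumPy
# demonstration of the two update rules used in do_update(), run only when
# this module is executed directly. The synthetic arrays below are
# hypothetical stand-ins for a replay batch and for network parameters.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # TD target: y = r + (1 - done) * gamma * Q_target(s', pi_target(s'))
    rwds = np.array([[1.0], [0.5]])
    ends = np.array([[0.0], [1.0]])            # 1.0 marks a terminal transition
    next_qvals = np.array([[2.0], [3.0]])
    discount = 0.99
    ys = rwds + (1.0 - ends) * discount * next_qvals
    print("TD targets:", ys.ravel())           # terminal row keeps only its reward

    # Soft (Polyak) target update:
    # theta_target <- (1 - tau) * theta_target + tau * theta
    tau = 1e-3
    theta = np.ones(4)
    theta_target = np.zeros(4)
    theta_target = (1.0 - tau) * theta_target + tau * theta
    print("target params after one soft update:", theta_target)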