From 99f9716b24b77949c09d6ff992369a3c4e7c5ff5 Mon Sep 17 00:00:00 2001
From: Welly Zhang
Date: Tue, 22 Nov 2016 12:57:09 +0800
Subject: [PATCH] add sampling

---
 ddpg.py   | 29 ++++++++++++++++++++++++++---
 qfuncs.py | 13 -------------
 utils.py  | 39 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/ddpg.py b/ddpg.py
index a138c24..d50fbce 100644
--- a/ddpg.py
+++ b/ddpg.py
@@ -1,4 +1,5 @@
 from replay_mem import ReplayMem
+from utils import discount_return, sample_rewards
 import rllab.misc.logger as logger
 import pyprind
 import mxnet as mx
@@ -21,6 +22,7 @@ def __init__(
             memory_start_size=1000,
             discount=0.99,
             max_path_length=1000,
+            eval_samples=10000,
             qfunc_updater="adam",
             qfunc_lr=1e-4,
             policy_updater="adam",
@@ -44,6 +46,7 @@ def __init__(
         self.memory_start_size = memory_start_size
         self.discount = discount
         self.max_path_length = max_path_length
+        self.eval_samples = eval_samples
         self.qfunc_updater = qfunc_updater
         self.qfunc_lr = qfunc_lr
         self.policy_updater = policy_updater
@@ -217,7 +220,6 @@ def train(self):
 
     def do_update(self, itr, batch):
         obss, acts, rwds, ends, nxts = batch
-
         self.policy_target.arg_dict["obs"][:] = nxts
         self.policy_target.forward(is_train=False)
 
@@ -228,7 +230,7 @@ def do_update(self, itr, batch):
         self.qfunc_target.forward(is_train=False)
         next_qvals = self.qfunc_target.outputs[0].asnumpy()
 
-        # executor accepts tensors
+        # the executor expects 2D tensors
         rwds = rwds.reshape((-1, 1))
         ends = ends.reshape((-1, 1))
         ys = rwds + (1.0 - ends) * self.discount * next_qvals
@@ -237,6 +239,8 @@ def do_update(self, itr, batch):
 
         # the update order could not be changed
         self.qfunc.update_params(obss, acts, ys)
+        # update_params has already computed these values,
+        # so there is no need to recompute qfunc_loss and qvals
         qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
         qvals = self.qfunc.exe.outputs[1].asnumpy()
         self.policy_executor.arg_dict["obs"][:] = obss
@@ -263,12 +267,29 @@ def do_update(self, itr, batch):
 
     def evaluate(self, epoch, memory):
 
+        logger.log("Collecting samples for evaluation")
+        rewards = sample_rewards(env=self.env, policy=self.policy,
+                                 eval_samples=self.eval_samples,
+                                 max_path_length=self.max_path_length)
+        average_discounted_return = np.mean(
+            [discount_return(reward, self.discount) for reward in rewards])
+        returns = [sum(reward) for reward in rewards]
+
         all_qs = np.concatenate(self.q_averages)
         all_ys = np.concatenate(self.y_averages)
 
         average_qfunc_loss = np.mean(self.qfunc_loss_averages)
         average_policy_loss = np.mean(self.policy_loss_averages)
 
+        logger.record_tabular('Epoch', epoch)
+        logger.record_tabular('AverageReturn',
+                              np.mean(returns))
+        logger.record_tabular('StdReturn',
+                              np.std(returns))
+        logger.record_tabular('MaxReturn',
+                              np.max(returns))
+        logger.record_tabular('MinReturn',
+                              np.min(returns))
         if len(self.strategy_path_returns) > 0:
             logger.record_tabular('AverageEsReturn',
                                   np.mean(self.strategy_path_returns))
@@ -278,8 +299,10 @@ def evaluate(self, epoch, memory):
                                   np.max(self.strategy_path_returns))
             logger.record_tabular('MinEsReturn',
                                   np.min(self.strategy_path_returns))
+        logger.record_tabular('AverageDiscountedReturn',
+                              average_discounted_return)
         logger.record_tabular('AverageQLoss', average_qfunc_loss)
-        logger.record_tabular('AveragePolicySurr', average_policy_loss)
+        logger.record_tabular('AveragePolicyLoss', average_policy_loss)
         logger.record_tabular('AverageQ', np.mean(all_qs))
         logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
         logger.record_tabular('AverageY', np.mean(all_ys))
diff --git a/qfuncs.py b/qfuncs.py
index 7247c34..af9bdb6 100644
--- a/qfuncs.py
+++ b/qfuncs.py
@@ -62,19 +62,6 @@ def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
 
         self.updater = updater
 
-        """
-        # define an executor for qval only
-        # used for q values without computing the loss
-        # note the parameters are shared
-        args = {}
-        for name, arr in self.exe.arg_dict.items():
-            if name in self.qval.list_arguments():
-                args[name] = arr
-        self.exe_qval = self.qval.bind(ctx=ctx,
-                                       args=args,
-                                       grad_req="null")
-        """
-
     def update_params(self, obs, act, yval):
 
         self.arg_dict["obs"][:] = obs
diff --git a/utils.py b/utils.py
index 2a5f62f..3a7153d 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
 import mxnet as mx
+import numpy as np
 
 # random seed for reproduction
 SEED = 12345
@@ -66,4 +67,40 @@ def define_policy(obs, action_dim):
                                   name="act",
                                   act_type="tanh")
 
-    return action
\ No newline at end of file
+    return action
+
+
+def discount_return(x, discount):
+    # total discounted return: sum_t discount^t * x[t]
+    return np.sum(x * (discount ** np.arange(len(x))))
+
+
+def rollout(env, agent, max_path_length=np.inf):
+    # run the agent in env for one episode and collect the per-step rewards
+    reward = []
+    o = env.reset()
+    # agent.reset()
+    path_length = 0
+    while path_length < max_path_length:
+        a = agent.get_action(o)
+        next_o, r, d, _ = env.step(a)
+        reward.append(r)
+        path_length += 1
+        if d:
+            break
+        o = next_o
+
+    return reward
+
+
+def sample_rewards(env, policy, eval_samples, max_path_length=np.inf):
+    # collect eval_samples reward sequences by rolling out the policy
+    rewards = []
+    for _ in range(eval_samples):
+        rewards.append(rollout(env, policy, max_path_length))
+
+    return rewards
+
+
+
+
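For reference, a minimal, self-contained sketch of how the new utils.py helpers are intended to be used. ToyEnv and RandomAgent are hypothetical stand-ins, not part of this repository: any environment whose reset() returns an observation and whose step() returns an (obs, reward, done, info) tuple, and any agent exposing get_action(obs), would work the same way.

import numpy as np

from utils import discount_return, sample_rewards


class ToyEnv(object):
    # hypothetical 1D environment with the reset()/step() interface rollout expects
    def __init__(self, horizon=20):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros(1)

    def step(self, action):
        self.t += 1
        obs = np.zeros(1)
        reward = float(-np.sum(action ** 2))
        done = self.t >= self.horizon
        return obs, reward, done, {}


class RandomAgent(object):
    # hypothetical stand-in for the DDPG policy: only needs get_action(obs)
    def get_action(self, obs):
        return np.random.uniform(-1.0, 1.0, size=(1,))


if __name__ == "__main__":
    env, agent = ToyEnv(), RandomAgent()
    # collect 10 reward sequences, each capped at 100 steps
    rewards = sample_rewards(env=env, policy=agent,
                             eval_samples=10, max_path_length=100)
    returns = [sum(r) for r in rewards]
    discounted = [discount_return(r, 0.99) for r in rewards]
    print("AverageReturn:", np.mean(returns))
    print("AverageDiscountedReturn:", np.mean(discounted))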