From 99f9716b24b77949c09d6ff992369a3c4e7c5ff5 Mon Sep 17 00:00:00 2001
From: Welly Zhang
Date: Tue, 22 Nov 2016 12:57:09 +0800
Subject: [PATCH] add sampling
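
Add sampling-based evaluation to DDPG. evaluate() now rolls out the
current policy for eval_samples environment steps via the new helpers
in utils.py (discount_return, rollout, sample_rewards) and logs the
average, std, max and min return together with the average discounted
return. The commented-out qval-only executor in qfuncs.py is removed,
and the 'AveragePolicySurr' key is renamed to 'AveragePolicyLoss'.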
---
ddpg.py | 29 ++++++++++++++++++++++++++---
qfuncs.py | 13 -------------
utils.py | 39 ++++++++++++++++++++++++++++++++++++++-
3 files changed, 64 insertions(+), 17 deletions(-)
diff --git a/ddpg.py b/ddpg.py
index a138c24..d50fbce 100644
--- a/ddpg.py
+++ b/ddpg.py
@@ -1,4 +1,5 @@
from replay_mem import ReplayMem
+from utils import discount_return, sample_rewards
import rllab.misc.logger as logger
import pyprind
import mxnet as mx
@@ -21,6 +22,7 @@ def __init__(
memory_start_size=1000,
discount=0.99,
max_path_length=1000,
+ eval_samples=10000,
qfunc_updater="adam",
qfunc_lr=1e-4,
policy_updater="adam",
@@ -44,6 +46,7 @@ def __init__(
self.memory_start_size = memory_start_size
self.discount = discount
self.max_path_length = max_path_length
+ self.eval_samples = eval_samples
self.qfunc_updater = qfunc_updater
self.qfunc_lr = qfunc_lr
self.policy_updater = policy_updater
@@ -217,7 +220,6 @@ def train(self):
def do_update(self, itr, batch):
obss, acts, rwds, ends, nxts = batch
-
self.policy_target.arg_dict["obs"][:] = nxts
self.policy_target.forward(is_train=False)
@@ -228,7 +230,7 @@ def do_update(self, itr, batch):
self.qfunc_target.forward(is_train=False)
next_qvals = self.qfunc_target.outputs[0].asnumpy()
- # executor accepts tensors
+ # executor accepts 2D tensors
rwds = rwds.reshape((-1, 1))
ends = ends.reshape((-1, 1))
ys = rwds + (1.0 - ends) * self.discount * next_qvals
@@ -237,6 +239,8 @@ def do_update(self, itr, batch):
# the update order could not be changed
self.qfunc.update_params(obss, acts, ys)
+        # the update above already computes these values,
+        # so qfunc_loss and qvals need not be recomputed
qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
qvals = self.qfunc.exe.outputs[1].asnumpy()
self.policy_executor.arg_dict["obs"][:] = obss
@@ -263,12 +267,29 @@ def do_update(self, itr, batch):
def evaluate(self, epoch, memory):
+        logger.log("Collecting samples for evaluation")
+        # NOTE: assumes the algorithm keeps its environment in self.env
+        rewards = sample_rewards(env=self.env,
+                                 policy=self.policy,
+                                 eval_samples=self.eval_samples,
+                                 max_path_length=self.max_path_length)
+        average_discounted_return = np.mean(
+            [discount_return(reward, self.discount) for reward in rewards])
+        returns = [sum(reward) for reward in rewards]
+
all_qs = np.concatenate(self.q_averages)
all_ys = np.concatenate(self.y_averages)
average_qfunc_loss = np.mean(self.qfunc_loss_averages)
average_policy_loss = np.mean(self.policy_loss_averages)
+ logger.record_tabular('Epoch', epoch)
+ logger.record_tabular('AverageReturn',
+ np.mean(returns))
+ logger.record_tabular('StdReturn',
+ np.std(returns))
+ logger.record_tabular('MaxReturn',
+ np.max(returns))
+ logger.record_tabular('MinReturn',
+ np.min(returns))
if len(self.strategy_path_returns) > 0:
logger.record_tabular('AverageEsReturn',
np.mean(self.strategy_path_returns))
@@ -278,8 +299,10 @@ def evaluate(self, epoch, memory):
np.max(self.strategy_path_returns))
logger.record_tabular('MinEsReturn',
np.min(self.strategy_path_returns))
+ logger.record_tabular('AverageDiscountedReturn',
+ average_discounted_return)
logger.record_tabular('AverageQLoss', average_qfunc_loss)
- logger.record_tabular('AveragePolicySurr', average_policy_loss)
+ logger.record_tabular('AveragePolicyLoss', average_policy_loss)
logger.record_tabular('AverageQ', np.mean(all_qs))
logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
logger.record_tabular('AverageY', np.mean(all_ys))
diff --git a/qfuncs.py b/qfuncs.py
index 7247c34..af9bdb6 100644
--- a/qfuncs.py
+++ b/qfuncs.py
@@ -62,19 +62,6 @@ def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
self.updater = updater
- """
- # define an executor for qval only
- # used for q values without computing the loss
- # note the parameters are shared
- args = {}
- for name, arr in self.exe.arg_dict.items():
- if name in self.qval.list_arguments():
- args[name] = arr
- self.exe_qval = self.qval.bind(ctx=ctx,
- args=args,
- grad_req="null")
- """
-
def update_params(self, obs, act, yval):
self.arg_dict["obs"][:] = obs
diff --git a/utils.py b/utils.py
index 2a5f62f..3a7153d 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
import mxnet as mx
+import numpy as np
# random seed for reproduction
SEED = 12345
@@ -66,4 +67,40 @@ def define_policy(obs, action_dim):
name="act",
act_type="tanh")
- return action
\ No newline at end of file
+ return action
+
+
+def discount_return(x, discount):
+    # discounted return of a reward sequence: sum_t discount^t * x[t]
+    return np.sum(x * (discount ** np.arange(len(x))))
+
+
+def rollout(env, agent, max_path_length=np.inf):
+    # run a single episode with the given agent and
+    # collect the reward received at every step
+    reward = []
+    o = env.reset()
+    # agent.reset()
+    path_length = 0
+    while path_length < max_path_length:
+        a = agent.get_action(o)
+        next_o, r, d, _ = env.step(a)
+        reward.append(r)
+        path_length += 1
+        if d:
+            break
+        o = next_o
+
+    return reward
+
+
+def sample_rewards(env, policy, eval_samples, max_path_length=np.inf):
+    # roll out the policy until about eval_samples environment steps
+    # have been collected; returns one reward list per sampled path
+    rewards = []
+    n_samples = 0
+    while n_samples < eval_samples:
+        reward = rollout(env, policy, max_path_length)
+        rewards.append(reward)
+        n_samples += len(reward)
+
+    return rewards
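
Usage sketch for the new evaluation helpers: a minimal, self-contained
example of how sample_rewards and discount_return are expected to be
driven, assuming an environment with gym-style reset()/step() and a
policy exposing the get_action() interface used by rollout(). ToyEnv
and ToyPolicy are hypothetical stand-ins, not code from this repository.

import numpy as np

from utils import discount_return, sample_rewards


class ToyEnv(object):
    # 1-D random walk; the episode ends once the state drifts past +/- 3
    def __init__(self):
        self.state = 0.0

    def reset(self):
        self.state = 0.0
        return np.array([self.state])

    def step(self, action):
        self.state += float(action)
        reward = -abs(self.state)
        done = abs(self.state) > 3.0
        return np.array([self.state]), reward, done, {}


class ToyPolicy(object):
    # uniform random actions; only the get_action interface matters here
    def get_action(self, obs):
        return np.random.uniform(-1.0, 1.0)


env, policy = ToyEnv(), ToyPolicy()
rewards = sample_rewards(env=env, policy=policy,
                         eval_samples=200, max_path_length=50)
returns = [sum(r) for r in rewards]
discounted = [discount_return(r, 0.99) for r in rewards]
print("paths collected: %d" % len(rewards))
print("average return: %.3f" % np.mean(returns))
print("average discounted return: %.3f" % np.mean(discounted))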