Commit

add sampling
WellyZhang committed Nov 22, 2016
1 parent 63af7ce commit 99f9716
Showing 3 changed files with 64 additions and 17 deletions.
29 changes: 26 additions & 3 deletions ddpg.py
@@ -1,4 +1,5 @@
from replay_mem import ReplayMem
from utils import discount_return, sample_rewards
import rllab.misc.logger as logger
import pyprind
import mxnet as mx
@@ -21,6 +22,7 @@ def __init__(
memory_start_size=1000,
discount=0.99,
max_path_length=1000,
eval_samples=10000,
qfunc_updater="adam",
qfunc_lr=1e-4,
policy_updater="adam",
@@ -44,6 +46,7 @@ def __init__(
self.memory_start_size = memory_start_size
self.discount = discount
self.max_path_length = max_path_length
self.eval_samples = eval_samples
self.qfunc_updater = qfunc_updater
self.qfunc_lr = qfunc_lr
self.policy_updater = policy_updater
@@ -217,7 +220,6 @@ def train(self):
def do_update(self, itr, batch):

obss, acts, rwds, ends, nxts = batch


self.policy_target.arg_dict["obs"][:] = nxts
self.policy_target.forward(is_train=False)
@@ -228,7 +230,7 @@ def do_update(self, itr, batch):
self.qfunc_target.forward(is_train=False)
next_qvals = self.qfunc_target.outputs[0].asnumpy()

# the executor expects 2D tensors, hence the reshapes below
rwds = rwds.reshape((-1, 1))
ends = ends.reshape((-1, 1))
ys = rwds + (1.0 - ends) * self.discount * next_qvals
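
As a side note, the line above is the standard DDPG Bellman target, y = r + (1 - done) * gamma * Q'(s', mu'(s')). A minimal NumPy sketch with toy arrays (standing in for the target executors' outputs, not the repo's code) shows why the reshape to (batch, 1) is needed for broadcasting:

import numpy as np

discount = 0.99
rwds = np.array([1.0, 0.5, -0.2])             # batch of rewards
ends = np.array([0.0, 0.0, 1.0])              # 1.0 marks a terminal transition
next_qvals = np.array([[2.0], [1.5], [3.0]])  # Q'(s', mu'(s')), shape (batch, 1)

rwds = rwds.reshape((-1, 1))                  # column vectors broadcast against next_qvals
ends = ends.reshape((-1, 1))
ys = rwds + (1.0 - ends) * discount * next_qvals
# ys == [[2.98], [1.985], [-0.2]]; the terminal row keeps only its immediate reward
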
@@ -237,6 +239,8 @@
# the update order must not be changed

self.qfunc.update_params(obss, acts, ys)
# update_params already computes these values,
# so there is no need to recompute qfunc_loss and qvals
qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
qvals = self.qfunc.exe.outputs[1].asnumpy()
self.policy_executor.arg_dict["obs"][:] = obss
@@ -263,12 +267,29 @@ def do_update(self, itr, batch):

def evaluate(self, epoch, memory):

logger.log("Collecting samples for evaluation")
# the environment is assumed to be stored as self.env in __init__
rewards = sample_rewards(env=self.env,
policy=self.policy,
eval_samples=self.eval_samples,
max_path_length=self.max_path_length)
average_discounted_return = np.mean(
[discount_return(reward, self.discount) for reward in rewards])
returns = [sum(reward) for reward in rewards]

all_qs = np.concatenate(self.q_averages)
all_ys = np.concatenate(self.y_averages)

average_qfunc_loss = np.mean(self.qfunc_loss_averages)
average_policy_loss = np.mean(self.policy_loss_averages)

logger.record_tabular('Epoch', epoch)
logger.record_tabular('AverageReturn',
np.mean(returns))
logger.record_tabular('StdReturn',
np.std(returns))
logger.record_tabular('MaxReturn',
np.max(returns))
logger.record_tabular('MinReturn',
np.min(returns))
if len(self.strategy_path_returns) > 0:
logger.record_tabular('AverageEsReturn',
np.mean(self.strategy_path_returns))
@@ -278,8 +299,10 @@ def evaluate(self, epoch, memory):
np.max(self.strategy_path_returns))
logger.record_tabular('MinEsReturn',
np.min(self.strategy_path_returns))
logger.record_tabular('AverageDiscountedReturn',
average_discounted_return)
logger.record_tabular('AverageQLoss', average_qfunc_loss)
logger.record_tabular('AveragePolicyLoss', average_policy_loss)
logger.record_tabular('AverageQ', np.mean(all_qs))
logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
logger.record_tabular('AverageY', np.mean(all_ys))
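
For intuition about the quantities logged here, a standalone NumPy sketch with toy data (no logger, no executors): each entry of rewards is the per-step reward list of one evaluation rollout, and the tabulated values are plain aggregates over those lists.

import numpy as np

discount = 0.99
rewards = [[1.0, 1.0, 0.0], [0.5, 0.5], [2.0]]     # three toy evaluation rollouts

returns = [sum(reward) for reward in rewards]      # undiscounted episode returns
discounted = [np.sum(np.asarray(r) * discount ** np.arange(len(r))) for r in rewards]

print(np.mean(returns), np.std(returns))           # AverageReturn, StdReturn
print(np.max(returns), np.min(returns))            # MaxReturn, MinReturn
print(np.mean(discounted))                         # AverageDiscountedReturn
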
13 changes: 0 additions & 13 deletions qfuncs.py
@@ -62,19 +62,6 @@ def define_exe(self, ctx, init, updater, input_shapes=None, args=None,

self.updater = updater

"""
# define an executor for qval only
# used for q values without computing the loss
# note the parameters are shared
args = {}
for name, arr in self.exe.arg_dict.items():
if name in self.qval.list_arguments():
args[name] = arr
self.exe_qval = self.qval.bind(ctx=ctx,
args=args,
grad_req="null")
"""

def update_params(self, obs, act, yval):

self.arg_dict["obs"][:] = obs
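
The hunk above shows only part of update_params. As a rough, hedged sketch of the executor pattern this file relies on (a generic MXNet symbolic example, not the repo's actual network or hyperparameters): bind a loss symbol once, copy each batch into exe.arg_dict, run forward and backward, and apply an optimizer updater to the parameter arrays in place.

import mxnet as mx
import numpy as np

obs = mx.sym.Variable("obs")
yval = mx.sym.Variable("yval")
net = mx.sym.FullyConnected(data=obs, name="fc1", num_hidden=32)
net = mx.sym.Activation(data=net, name="relu1", act_type="relu")
qval = mx.sym.FullyConnected(data=net, name="qval", num_hidden=1)
loss = mx.sym.LinearRegressionOutput(data=qval, label=yval, name="loss")

batch_size, obs_dim = 16, 4
exe = loss.simple_bind(ctx=mx.cpu(), obs=(batch_size, obs_dim),
                       yval=(batch_size, 1), grad_req="write")

# crude random init of the weights; the input arrays are overwritten every batch
for name, arr in exe.arg_dict.items():
    if name not in ("obs", "yval"):
        arr[:] = np.random.normal(0.0, 0.1, arr.shape)

updater = mx.optimizer.get_updater(mx.optimizer.create("adam", learning_rate=1e-4))

def update_params(obs_batch, y_batch):
    # copy the batch into the bound NDArrays, then take one gradient step
    exe.arg_dict["obs"][:] = obs_batch
    exe.arg_dict["yval"][:] = y_batch
    exe.forward(is_train=True)
    exe.backward()
    for i, name in enumerate(loss.list_arguments()):
        if name not in ("obs", "yval"):
            updater(i, exe.grad_dict[name], exe.arg_dict[name])
    # outputs[0] holds the predicted Q-values for this batch
    return exe.outputs[0].asnumpy()
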
39 changes: 38 additions & 1 deletion utils.py
@@ -1,4 +1,5 @@
import mxnet as mx
import numpy as np

# random seed for reproducibility
SEED = 12345
@@ -66,4 +67,40 @@ def define_policy(obs, action_dim):
name="act",
act_type="tanh")

return action


def discount_return(x, discount):
    # total discounted return: sum_t discount^t * x[t]
    return np.sum(x * (discount ** np.arange(len(x))))
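
Worked example: with rewards [1.0, 1.0, 1.0] and discount 0.9 this gives 1.0 + 0.9 + 0.81 = 2.71:

discount_return([1.0, 1.0, 1.0], 0.9)   # -> 2.71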


def rollout(env, agent, max_path_length=np.inf):
    # run one episode and return its list of per-step rewards
    reward = []
    o = env.reset()
    # agent.reset()
    path_length = 0
    while path_length < max_path_length:
        a = agent.get_action(o)
        next_o, r, d, _ = env.step(a)
        reward.append(r)
        path_length += 1
        if d:
            break
        o = next_o

    return reward


def sample_rewards(env, policy, eval_samples, max_path_length=np.inf):
    # collect per-step reward lists from eval_samples independent rollouts
    rewards = []
    for _ in range(eval_samples):
        rewards.append(rollout(env, policy, max_path_length))

    return rewards
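
A toy usage check for the two helpers above (hypothetical stubs, assuming rollout, sample_rewards, and discount_return are in scope, e.g. via from utils import ...): rollout only requires env.reset() and env.step() in the usual (next_obs, reward, done, info) form, plus a policy exposing get_action(obs).

import numpy as np

class DummyEnv(object):
    # fixed-length episode of 5 steps, reward 1.0 per step
    def reset(self):
        self.t = 0
        return np.zeros(2)

    def step(self, action):
        self.t += 1
        return np.zeros(2), 1.0, self.t >= 5, {}

class DummyPolicy(object):
    def get_action(self, obs):
        return np.zeros(1)

rewards = sample_rewards(DummyEnv(), DummyPolicy(),
                         eval_samples=3, max_path_length=10)
print([sum(r) for r in rewards])          # [5.0, 5.0, 5.0]
print(discount_return(rewards[0], 0.99))  # ~4.90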



