Skip to content

Commit

Permalink
Rename `exec` to `exe` (`define_exec` → `define_exe`, `self.exec` → `self.exe`, `exec_one` → `exe_one`)
Browse files: browse the repository at this point in the history
  • Loading branch information
WellyZhang committed Nov 14, 2016
1 parent d68de79 commit c57b0f5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
8 changes: 4 additions & 4 deletions ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def init_net(self):
act=(self.batch_size, self.env.action_space.flat_dim),
yval=(self.batch_size, )}
self.qfunc.define_loss(qfunc_loss)
self.qfunc.define_exec(
self.qfunc.define_exe(
ctx=self.ctx,
init=qfunc_init,
updater=qfunc_updater,
Expand Down Expand Up @@ -125,7 +125,7 @@ def init_net(self):
learning_rate=self.policy_lr))
self.policy_input_shapes = {
obs=(self.batch_size, self.env.observation_space.flat_dim)}
self.policy.define_exec(
self.policy.define_exe(
ctx=self.ctx,
init=policy_init,
updater=policy_updater,
Expand Down Expand Up @@ -205,8 +205,8 @@ def do_update(self, itr, batch):
ys = rwds + (1.0 - ends) * self.discount * next_qvals

self.qfunc.update_params(obss, acts, ys)
qfunc_loss = self.qfunc.exec.outputs[0].asnumpy()
qvals = self.qfunc.exec.outputs[1].asnumpy()
qfunc_loss = self.qfunc.exe.outputs[0].asnumpy()
qvals = self.qfunc.exe.outputs[1].asnumpy()
self.policy_executor.arg_dict["obs"][:] = obss
self.policy_executor.arg_dict["act"][:] = policy_acts
self.policy_executor.forward(is_train=True)
Expand Down
24 changes: 12 additions & 12 deletions policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def define_loss(self, loss_exp):

raise NotImplementedError

def define_exec(self, ctx, init, updater, input_shapes=None, args=None,
def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
grad_req=None):

self.exec = self.act.simple_bind(ctx=ctx, **input_shapes)
self.arg_arrays = self.exec.arg_arrays
self.grad_arrays = self.exec.grad_arrays
self.arg_dict = self.exec.arg_dict
self.exe = self.act.simple_bind(ctx=ctx, **input_shapes)
self.arg_arrays = self.exe.arg_arrays
self.grad_arrays = self.exe.grad_arrays
self.arg_dict = self.exe.arg_dict

for name, arr in self.arg_dict.items():
if name not in input_shapes:
Expand All @@ -65,12 +65,12 @@ def define_exec(self, ctx, init, updater, input_shapes=None, args=None,
self.updater = updater

new_input_shapes = {obs=(1, input_shapes["obs"][1])}
self.exec_one = self.exec.reshape(**new_input_shapes)
self.exe_one = self.exe.reshape(**new_input_shapes)

def update_params(self, grad_from_top):

self.exec.forward(is_train=True)
self.exec.backward([grad_from_top])
self.exe.forward(is_train=True)
self.exe.backward([grad_from_top])

for i, pair in enumerate(zip(self.arg_arrays, self.grad_arrays)):
weight, grad = pair
Expand All @@ -79,16 +79,16 @@ def update_params(self, grad_from_top):
def get_actions(self, obs):

self.arg_dict["obs"][:] = obs
self.exec.forward(is_train=False)
self.exe.forward(is_train=False)

return self.exec.outputs[0].asnumpy()
return self.exe.outputs[0].asnumpy()

def get_action(self, obs):

self.arg_dict_one["obs"][:] = obs
self.exec_one.forward(is_train=False)
self.exe_one.forward(is_train=False)

return self.exec_one.outputs[0].asnumpy()
return self.exe_one.outputs[0].asnumpy()



Expand Down
18 changes: 9 additions & 9 deletions qfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def define_loss(self, loss_exp):
self.loss = mx.symbol.MakeLoss(loss_exp, name="qfunc_loss")
self.loss = mx.symbol.Group([self.loss, self.qval])

def define_exec(self, ctx, init, updater, input_shapes=None, args=None,
def define_exe(self, ctx, init, updater, input_shapes=None, args=None,
grad_req=None):

self.exec = self.loss.simple_bind(ctx=ctx, **input_shapes)
self.arg_arrays = self.exec.arg_arrays
self.grad_arrays = self.exec.grad_arrays
self.arg_dict = self.exec.arg_dict
self.exe = self.loss.simple_bind(ctx=ctx, **input_shapes)
self.arg_arrays = self.exe.arg_arrays
self.grad_arrays = self.exe.grad_arrays
self.arg_dict = self.exe.arg_dict

for name, arr in self.arg_dict.items():
if name not in input_shapes:
Expand All @@ -61,8 +61,8 @@ def update_params(self, obs, act, yval):
self.arg_dict["act"][:] = act
self.arg_dict["yval"][:] = yval

self.exec.forward(is_train=True)
self.exec.backward()
self.exe.forward(is_train=True)
self.exe.backward()

for i, pair in enumerate(zip(self.arg_arrays, self.grad_arrays)):
weight, grad = pair
Expand All @@ -72,8 +72,8 @@ def get_qvals(self, obs, act):

self.arg_dict["obs"][:] = obs
self.arg_dict["act"][:] = act
self.exec.forward(is_train=False)
self.exe.forward(is_train=False)

return self.exec.outputs[1].asnumpy()
return self.exe.outputs[1].asnumpy()


0 comments on commit c57b0f5

Please sign in to comment.