Skip to content

Commit

Permalink
rearrangements
Browse files Browse the repository at this point in the history
  • Loading branch information
gwding committed Mar 4, 2020
1 parent 92ab584 commit d8def16
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, weights_path, model_class):
bw_loss = tf.reduce_sum(self.logits * self.bw_gradient_pre)
self.bw_gradients = tf.gradients(bw_loss, self.inputs)[0]

def backward(self, logits_grad_val, inputs_val):
def backward(self, inputs_val, logits_grad_val):
inputs_grad_val = self.session.run(
self.bw_gradients,
feed_dict={
Expand Down Expand Up @@ -80,18 +80,17 @@ def forward(self, inputs_val):
rval = self.tfmodel.forward(self._to_numpy(inputs_val))
return self._to_torch(rval)

def backward(self, logits_grad_val, inputs_val):
def backward(self, inputs_val, logits_grad_val):
rval = self.tfmodel.backward(
self._to_numpy(logits_grad_val),
self._to_numpy(inputs_val),
self._to_numpy(logits_grad_val),
)
return self._to_torch(rval)



def get_madry_et_al_tf_model(dataname, device="cuda"):
if dataname == "MNIST":
# XXX:
weights_path = os.path.join(
MODEL_PATH, 'mnist_challenge/models/secret')

Expand All @@ -112,7 +111,6 @@ def _process_grads_val(val):


elif dataname == "CIFAR10":
# XXX:
weights_path = os.path.join(
MODEL_PATH, 'cifar10_challenge/models/model_0')

Expand Down Expand Up @@ -147,7 +145,7 @@ def new_forward(inputs_val):
def _wrap_backward(backward):
def new_backward(inputs_val, logits_grad_val):
return _process_grads_val(backward(
*logits_grad_val, _process_inputs_val(*inputs_val)))
_process_inputs_val(*inputs_val), *logits_grad_val))
return new_backward


Expand Down

0 comments on commit d8def16

Please sign in to comment.