diff --git a/advertorch_examples/attack_madry_et_al_models/madry_et_al_utils.py b/advertorch_examples/attack_madry_et_al_models/madry_et_al_utils.py index 8c79da8..abfb4c3 100644 --- a/advertorch_examples/attack_madry_et_al_models/madry_et_al_utils.py +++ b/advertorch_examples/attack_madry_et_al_models/madry_et_al_utils.py @@ -46,7 +46,7 @@ def __init__(self, weights_path, model_class): bw_loss = tf.reduce_sum(self.logits * self.bw_gradient_pre) self.bw_gradients = tf.gradients(bw_loss, self.inputs)[0] - def backward(self, logits_grad_val, inputs_val): + def backward(self, inputs_val, logits_grad_val): inputs_grad_val = self.session.run( self.bw_gradients, feed_dict={ @@ -80,10 +80,10 @@ def forward(self, inputs_val): rval = self.tfmodel.forward(self._to_numpy(inputs_val)) return self._to_torch(rval) - def backward(self, logits_grad_val, inputs_val): + def backward(self, inputs_val, logits_grad_val): rval = self.tfmodel.backward( - self._to_numpy(logits_grad_val), self._to_numpy(inputs_val), + self._to_numpy(logits_grad_val), ) return self._to_torch(rval) @@ -91,7 +91,6 @@ def backward(self, logits_grad_val, inputs_val): def get_madry_et_al_tf_model(dataname, device="cuda"): if dataname == "MNIST": - # XXX: weights_path = os.path.join( MODEL_PATH, 'mnist_challenge/models/secret') @@ -112,7 +111,6 @@ def _process_grads_val(val): elif dataname == "CIFAR10": - # XXX: weights_path = os.path.join( MODEL_PATH, 'cifar10_challenge/models/model_0') @@ -147,7 +145,7 @@ def new_forward(inputs_val): def _wrap_backward(backward): def new_backward(inputs_val, logits_grad_val): return _process_grads_val(backward( - *logits_grad_val, _process_inputs_val(*inputs_val))) + _process_inputs_val(*inputs_val), *logits_grad_val)) return new_backward