Add NESAttack and refactor Bandit
CaesarQ committed Mar 31, 2022
1 parent 2d88fcb commit 71e3690
Showing 7 changed files with 195 additions and 56 deletions.
9 changes: 7 additions & 2 deletions advertorch/attacks/__init__.py
@@ -55,6 +55,11 @@
from .blackbox.gen_attack import LinfGenAttack
from .blackbox.gen_attack import L2GenAttack

from .blackbox.grad_estimators import FDWrapper, NESWrapper
from .blackbox.nattack import NAttack
from .blackbox.bandits_t import BanditAttack
from .blackbox.nattack import LinfNAttack
from .blackbox.nattack import L2NAttack

from .blackbox.estimators import FDWrapper, NESWrapper

from .blackbox.bandits_t import BanditAttack
from .blackbox.iterative_gradient_approximation import NESAttack
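The net effect of this hunk: FDWrapper and NESWrapper are now exported from .blackbox.estimators rather than .blackbox.grad_estimators, and NESAttack joins the package exports. A quick smoke test of the new top-level import paths (a sketch, assuming the package is installed at this commit):

# New public import surface after the refactor.
from advertorch.attacks import FDWrapper, NESWrapper    # moved: grad_estimators -> estimators
from advertorch.attacks import BanditAttack, NESAttack  # NESAttack is new in this commit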
15 changes: 11 additions & 4 deletions advertorch/attacks/blackbox/__init__.py
@@ -9,8 +9,15 @@
from .gen_attack import LinfGenAttack
from .gen_attack import L2GenAttack

from .grad_estimators import GradientWrapper
from .grad_estimators import FDWrapper, NESWrapper

from .nattack import NAttack
from .bandits_t import BanditAttack
from .nattack import LinfNAttack
from .nattack import L2NAttack

from .estimators import GradientWrapper
from .estimators import FDWrapper, NESWrapper

from .bandits_t import BanditAttack

from .iterative_gradient_approximation import NESAttack

from .utils import pytorch_wrapper
56 changes: 28 additions & 28 deletions advertorch/attacks/blackbox/bandits_t.py
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from math import inf
from typing import Optional

import numpy as np
@@ -16,12 +16,12 @@
from advertorch.attacks.base import Attack
from advertorch.attacks.base import LabelMixin

from advertorch.attacks.blackbox.utils import _check_param, _flatten
from .utils import _check_param, _flatten, _make_projector

def bandit_attack(
x, loss_fn, eps, clip_min, clip_max, order='l2',
delta_init=None, prior_init=None, fd_eta=0.01, exploration=0.01,
online_lr=0.1, nb_iter=40, eps_iter=0.01
x, loss_fn, eps, order, projector, delta_init=None, prior_init=None,
fd_eta=0.01, exploration=0.01, online_lr=0.1, nb_iter=40,
eps_iter=0.01
):
"""
Performs the BanditAttack
Expand All @@ -30,9 +30,9 @@ def bandit_attack(
:param x: input data.
:param loss_fn: loss function.
:param eps: maximum distortion.
:param order: (optional) the order of maximum distortion ('linf' or 'l2').
:param clip_min: mininum value per input dimension (default 0.)
:param clip_max: mininum value per input dimension (default 1.)
:param order: (optional) the order of maximum distortion (2 or inf).
:param projector: function to project the perturbation into the eps-ball
- must accept tensors of shape [nbatch, pop_size, ndim]
:param delta_init: (default None)
:param prior_init: (default None)
:param fd_eta: step-size used for finite difference grad estimate (default 0.01)
@@ -74,21 +74,17 @@

delta_L = (L1 - L2)/(fd_eta * exploration) #[nbatch]

grad_est = delta_L * exp_noise
if order == 'l2':
grad_est = delta_L[:, None] * exp_noise
if order == 2:
#update prior
prior = prior + online_lr * grad_est
#make step with prior
#note the (+): this indicates gradient ascent on the loss
adv = adv + eps_iter * F.normalize(prior, dim=-1)
#project
delta = adv - x
norms = torch.sqrt( (delta ** 2).sum(-1))
out_of_bounds_mask = (norms > eps).float().unsqueeze(-1)
#if out_of_bounds, use the clipped values
clipped = (x + eps[:, None] * F.normalize(delta, dim=-1))
adv = clipped * out_of_bounds_mask + adv * (1 - out_of_bounds_mask)
elif order == 'linf':
delta = projector(delta[:, None, :]).squeeze(1)
elif order == inf:
#update prior (exponentiated gradients)
prior = (prior + 1) / 2 # from [-1, 1] to [0, 1]
pos = prior * torch.exp(online_lr * grad_est)
@@ -97,15 +93,13 @@
#make step with prior
adv = adv + eps_iter * torch.sign(prior)
#project
delta = torch.minimum(adv - x, eps[:, None])
delta = torch.maximum(delta, -eps[:, None])
adv = x + delta
delta = adv - x
delta = projector(delta[:, None, :]).squeeze(1)
else:
error = "Only order = 'linf', order = 'l2' have been implemented"
error = "Only order=inf, order=2 have been implemented"
raise NotImplementedError(error)

adv = torch.maximum(adv, clip_min)
adv = torch.minimum(adv, clip_max)
adv = x + delta

return adv, prior

@@ -117,7 +111,7 @@ class BanditAttack(Attack, LabelMixin):
:param predict: forward pass function.
:param eps: maximum distortion.
:param order: the order of maximum distortion ('linf' or 'l2', default l2)
:param order: the order of maximum distortion (inf or 2)
:param fd_eta: step-size used for finite difference grad estimate (default 0.01)
:param exploration: scales the exploration around prior (default 0.01)
:param online_lr: learning rate for the prior (default 0.1)
@@ -169,24 +163,30 @@ def perturb( # type: ignore
"""
x, y = self._verify_and_process_inputs(x, y)
shape, flat_x = _flatten(x)
data_shape = tuple(shape[1:])

eps = _check_param(self.eps, x.new_full((x.shape[0],), 1), 'eps')
clip_min = _check_param(self.clip_min, flat_x, 'clip_min')
clip_max = _check_param(self.clip_max, flat_x, 'clip_max')

projector = _make_projector(
eps, self.order, flat_x, clip_min, clip_max
)

scale = -1 if self.targeted else 1
def L(x): #loss func
#new_shape = (x.shape[0],) + data_shape
#input = x.reshape(new_shape)
input = x.reshape(shape)
output = self.predict(input)
loss = scale * self.loss_fn(output, y)
return loss

adv, _ = bandit_attack(
flat_x, loss_fn=L, eps=eps, order=self.order, clip_min=clip_min,
clip_max=clip_max, delta_init=None, prior_init=None,
fd_eta=self.fd_eta, exploration=self.exploration,
online_lr=self.online_lr, nb_iter=self.nb_iter,
eps_iter=self.eps_iter
flat_x, loss_fn=L, eps=eps, order=self.order, projector=projector,
delta_init=None, prior_init=None, fd_eta=self.fd_eta,
exploration=self.exploration, online_lr=self.online_lr,
nb_iter=self.nb_iter, eps_iter=self.eps_iter
)

adv = adv.reshape(shape)
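For reference, a minimal usage sketch of the refactored BanditAttack. The constructor is truncated in the diff above, so the keyword names below are inferred from the class docstring and perturb; the model and data are illustrative stand-ins, not part of the commit.

import torch
import torch.nn as nn
from advertorch.attacks import BanditAttack

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))  # toy stand-in network
x = torch.rand(8, 1, 28, 28)                             # inputs in [0, 1]
y = torch.randint(0, 10, (8,))

# Note the refactor: order is now numeric (2 or float('inf')) rather than the
# old 'l2'/'linf' strings, and projection is delegated to _make_projector.
adversary = BanditAttack(
    model, eps=1.5, order=2, fd_eta=0.01, exploration=0.01,
    online_lr=0.1, nb_iter=40, eps_iter=0.01,
    clip_min=0., clip_max=1., targeted=False,
)
x_adv = adversary.perturb(x, y)  # returns a tensor with the same shape as x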
advertorch/attacks/blackbox/estimators.py
@@ -6,30 +6,28 @@
#

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

def pytorch_wrapper(func):
def wrapped_func(x):
x_numpy = x.cpu().data.numpy()
output = func(x_numpy)
output = torch.from_numpy(output)
output = output.to(x.device)

return output

return wrapped_func

def norm(v):
return torch.sqrt( (v ** 2).sum(-1) )

class GradientWrapper(torch.nn.Module):
#facility for doing things in batch?
"""
Define a backward pass for a blackbox function using extra queries.
Once wrapped, the blackbox function will become compatible with any attack
in Advertorch, so long as self.training is True.
Disclaimer: This wrapper assumes inputs will have shape [nbatch, ndim].
For models that operate on images, you will need to wrap the function
inside a reshaper. See NESAttack for an example.
:param func: A blackbox function.
- This function must accept, and output, torch tensors.
"""
def __init__(self, func):
super().__init__()
self.func = pytorch_wrapper(func)
self.func = func

#Based on:
#https://pytorch.org/docs/stable/notes/extending.html
@@ -44,7 +42,7 @@ def forward(ctx, input):

@staticmethod
def backward(ctx, grad_output):
#TODO: this is not general! May not work for images
#Note: this is not general! May not work for images
#Be careful about dimensions
grad_est, = ctx.saved_tensors
grad_input = None
@@ -57,7 +55,9 @@ def backward(ctx, grad_output):
self.diff_func = _Func.apply

def batch_query(self, x):
#TODO: accomodate images...
"""
Reshapes the queries for efficient, parallel estimation.
"""
n_batch, n_dim, nb_samples = x.shape
x = x.permute(0, 2, 1).reshape(-1, n_dim)
outputs = self.func(x) #shape [..., n_output]
@@ -69,7 +69,6 @@ def estimate_grad(self, x):
raise NotImplementedError

def forward(self, x):
#TODO: check compatibility with torch.no_grad()
if not self.training:
output = self.func(x)
else:
@@ -79,7 +78,13 @@

class FDWrapper(GradientWrapper):
"""
Finite-Difference Estimator
Finite-Difference Estimator.
For every backward pass, this module makes 2 * n_dim queries per
instance.
:param func: A blackbox function.
- This function must accept, and output, torch tensors.
:param fd_eta: Step-size used for the finite-difference estimation.
"""
def __init__(self, func, fd_eta=1e-3):
super().__init__(func)
@@ -101,14 +106,22 @@ def estimate_grad(self, x):


class NESWrapper(GradientWrapper):
#NES Attack: https://arxiv.org/pdf/1804.08598.pdf
"""
Natural-evolutionary strategy for gradient estimation.
For every backward pass, this module makes 2 * nb_samples
queries per instance.
:param func: A blackbox function.
- This function must accept, and output, torch tensors.
:param nb_samples: Number of samples to use in the grad estimation.
:param fd_eta: Step-size used for the finite-difference estimation.
"""
def __init__(self, func, nb_samples, fd_eta=1e-3):
super().__init__(func)
self.nb_samples = nb_samples
self.fd_eta = fd_eta

def estimate_grad(self, x, prior=None):
#TODO: adjust this so that it works with images...
#x shape: [nbatch, ndim]
ndim = np.prod(list(x.shape[1:]))

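The wrappers above can also be exercised standalone. A sketch under the post-refactor API; note that GradientWrapper no longer applies pytorch_wrapper internally (self.func = func), so a numpy blackbox must now be wrapped explicitly by the caller. The score function here is a toy assumption, and the sketch assumes the training-mode forward routes through the custom autograd function as the docstring describes.

import torch
from advertorch.attacks.blackbox import pytorch_wrapper, FDWrapper, NESWrapper

def score(x_np):                      # toy numpy blackbox: numpy in, numpy out
    return (x_np ** 2).sum(axis=-1, keepdims=True)

func = pytorch_wrapper(score)         # caller wraps explicitly after the refactor
fd = FDWrapper(func, fd_eta=1e-3)     # 2 * n_dim queries per backward pass
nes = NESWrapper(func, nb_samples=50, fd_eta=1e-3)  # 2 * nb_samples queries

x = torch.rand(4, 16, requires_grad=True)  # inputs must be [nbatch, ndim]
nes(x).sum().backward()               # backward fills x.grad with the NES estimate
print(x.grad.shape)                   # torch.Size([4, 16])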
2 changes: 1 addition & 1 deletion advertorch/attacks/blackbox/gen_attack.py
@@ -331,7 +331,7 @@ def perturb( # type: ignore
x, y = self._verify_and_process_inputs(x, y)
shape, flat_x = _flatten(x)
data_shape = tuple(shape[1:])

#[B]
eps = _check_param(self.eps, x.new_full((x.shape[0],), 1), 'eps')
#[B, F]
103 changes: 103 additions & 0 deletions advertorch/attacks/blackbox/iterative_gradient_approximation.py
@@ -0,0 +1,103 @@
# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn

from advertorch.utils import clamp, to_one_hot, is_float_or_torch_tensor

from advertorch.attacks.utils import rand_init_delta

from advertorch.attacks.base import Attack
from advertorch.attacks.base import LabelMixin

from advertorch.attacks.iterative_projected_gradient import LinfPGDAttack
from advertorch.attacks.iterative_projected_gradient import perturb_iterative

from .estimators import NESWrapper
from .utils import _flatten

class NESAttack(LinfPGDAttack):
"""
Implements NES Attack https://arxiv.org/abs/1804.08598
Employs Natural Evolutionary Strategies for Gradient Estimation.
Generates Adversarial Examples using Projected Gradient Descent.
Disclaimer: Computations are broadcasted, so it is advisable to use
smaller batch sizes when nb_samples is large.
:param predict: forward pass function.
:param loss_fn: loss function.
:param eps: maximum distortion.
:param nb_samples: number of samples to use for gradient estimation
:param fd_eta: step-size used for Finite Difference gradient estimation
:param nb_iter: number of iterations.
:param eps_iter: attack step size.
:param rand_init: (optional bool) random initialization.
:param clip_min: minimum value per input dimension.
:param clip_max: maximum value per input dimension.
:param targeted: if the attack is targeted.
"""

def __init__(
self, predict, loss_fn=None, eps=0.3,
nb_samples=100, fd_eta=1e-2, nb_iter=40,
eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
targeted=False):

super(NESAttack, self).__init__(
predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter,
eps_iter=eps_iter, rand_init=rand_init, clip_min=clip_min,
clip_max=clip_max, targeted=targeted)

self.nb_samples = nb_samples
self.fd_eta = fd_eta

def perturb(self, x, y=None):
"""
Given examples (x, y), returns their adversarial counterparts with
an attack length of eps.
:param x: input tensor.
:param y: label tensor.
- if None and self.targeted=False, compute y as predicted
labels.
- if self.targeted=True, then y must be the targeted labels.
:return: tensor containing perturbed inputs.
"""
x, y = self._verify_and_process_inputs(x, y)
shape, flat_x = _flatten(x)
data_shape = tuple(shape[1:])
def f(x):
new_shape = (x.shape[0],) + data_shape
input = x.reshape(new_shape)
return self.predict(input)
f_nes = NESWrapper(
f, nb_samples=self.nb_samples, fd_eta=self.fd_eta
)

delta = torch.zeros_like(flat_x)
delta = nn.Parameter(delta)
if self.rand_init:
rand_init_delta(
delta, flat_x, self.ord, self.eps, self.clip_min, self.clip_max
)
delta.data = clamp(
flat_x + delta.data, min=self.clip_min, max=self.clip_max
) - flat_x

rval = perturb_iterative(
flat_x, y, f_nes, nb_iter=self.nb_iter,
eps=self.eps, eps_iter=self.eps_iter,
loss_fn=self.loss_fn, minimize=self.targeted,
ord=self.ord, clip_min=self.clip_min,
clip_max=self.clip_max, delta_init=delta,
l1_sparsity=None
)

return rval.data.reshape(shape)
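Since the new file's constructor is shown in full above, a usage sketch is straightforward; the model and data here are stand-ins, not from the commit.

import torch
import torch.nn as nn
from advertorch.attacks import NESAttack

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in net
x = torch.rand(4, 3, 32, 32)   # keep batches small: queries broadcast over nb_samples
y = torch.randint(0, 10, (4,))

adversary = NESAttack(
    model, eps=0.3, nb_samples=100, fd_eta=1e-2,
    nb_iter=40, eps_iter=0.01, clip_min=0., clip_max=1.,
)
x_adv = adversary.perturb(x, y)  # linf-bounded adversarial examples
assert x_adv.shape == x.shape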