fix bandit bugs
CaesarQ committed Jul 22, 2021
1 parent bb2e010 commit c6cdc00
Showing 3 changed files with 24 additions and 14 deletions.
3 changes: 2 additions & 1 deletion advertorch/attacks/blackbox/__init__.py
@@ -7,4 +7,5 @@
from .grad_estimators import GradientWrapper
from .grad_estimators import FDWrapper, NESWrapper

from .nattack import NAttack
from .nattack import NAttack
from .iterative_projected_gradient import BanditAttack
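
Note: with BanditAttack now exported from advertorch.attacks.blackbox, the intended call pattern under the updated constructor signature looks roughly like the sketch below. The toy model, shapes, and the per-example cross-entropy loss are illustrative assumptions, not part of this commit, and the snippet has not been verified against this revision.

import torch
import torch.nn as nn
from advertorch.attacks.blackbox import BanditAttack

model = nn.Linear(784, 10)                      # toy stand-in classifier (assumption)
x = torch.rand(8, 784)                          # batch of flattened inputs (assumption)
y = torch.randint(0, 10, (8,))

attack = BanditAttack(
    model, eps=0.05, order='l2',                # eps and order are now the leading args
    fd_eta=0.01, exploration=0.01, online_lr=0.1,
    loss_fn=nn.CrossEntropyLoss(reduction='none'),  # per-example loss (assumption)
    nb_iter=40, eps_iter=0.01,
    clip_min=0., clip_max=1., targeted=False,
)
x_adv = attack.perturb(x, y)                    # returns the adversarial batch after this fix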
33 changes: 21 additions & 12 deletions advertorch/attacks/blackbox/iterative_projected_gradient.py
@@ -1,5 +1,7 @@
from typing import List, Tuple, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -19,21 +21,25 @@ def eg_step(x, g, lr):
new_x = pos/(pos+neg)
return new_x*2-1


def norm(x):
return torch.sqrt( (x ** 2).sum(-1))

def gd_prior_step(x, g, lr):
return x + lr*g

def l2_image_step(x, g, lr):
return x + lr*g/norm(g)
def l2_data_step(x, g, lr):
return x + lr*F.normalize(g, dim=-1)

def linf_step(x, g, lr):
return x + lr*ch.sign(g)
return x + lr*torch.sign(g)

def l2_proj(image, eps):
orig = image.clone()
def proj(new_x):
delta = new_x - orig
out_of_bounds_mask = (norm(delta) > eps).float()
x = (orig + eps*delta/norm(delta))*out_of_bounds_mask
out_of_bounds_mask = (norm(delta) > eps).float().unsqueeze(-1)
x = (orig + eps[:, None] * F.normalize(delta, dim=-1))*out_of_bounds_mask
x += new_x*(1-out_of_bounds_mask)
return x
return proj
@@ -49,13 +55,14 @@ def proj(new_x):
#https://github.com/MadryLab/blackbox-bandits/blob/master/src/main.py
class BanditAttack(Attack, LabelMixin):
def __init__(
self, predict, eps: float,
fd_eta, exploration, online_lr, order,
self, predict, eps: float, order,
fd_eta=0.01, exploration=0.01, online_lr=0.1,
loss_fn=None,
nb_iter=40,
eps_iter=0.01,
clip_min=0., clip_max=1.,
targeted : bool = False
targeted : bool = False,
query_limit=None
):

super().__init__(predict, loss_fn, clip_min, clip_max)
@@ -74,6 +81,8 @@ def __init__(

self.proj_maker = l2_proj if order == 'l2' else linf_proj

self.query_limit = None

def perturb( # type: ignore
self,
x: torch.FloatTensor,
@@ -93,7 +102,7 @@ def perturb( # type: ignore
x, y = self._verify_and_process_inputs(x, y)
x_adv = x.clone()

eps = _check_param(self.eps, x.new_full((x.shape[0], )), 1, 'eps')
eps = _check_param(self.eps, x.new_full((x.shape[0],), 1), 'eps')
clip_min = _check_param(self.clip_min, x, 'clip_min')
clip_max = _check_param(self.clip_max, x, 'clip_max')
#sample using mean param
@@ -124,7 +133,7 @@ def L(x): #loss func
for t in range(self.nb_iter):
#before: # [nbatch, ndim, nsamples]
#now: # [nbatch, ndim]
exp_noise = exploration * torch.randn_like(prior)/(ndim**0.5)
exp_noise = self.exploration * torch.randn_like(prior)/(ndim**0.5)

# Query deltas for finite difference estimator
##...this step needs to change
@@ -134,7 +143,7 @@ def L(x): #loss func
L1 = L(x_adv + self.fd_eta * q1) # L(prior + c*noise)
L2 = L(x_adv + self.fd_eta * q2) # L(prior - c*noise)

delta_L = (f1 - f2)/(self.fd_eta * self.exploration) #[nbatch]
delta_L = (L1 - L2)/(self.fd_eta * self.exploration) #[nbatch]

grad_est = delta_L * exp_noise

@@ -145,7 +154,7 @@ def L(x): #loss func

x_adv = torch.clamp(x_adv, self.clip_min, self.clip_max)

return prior
return x_adv
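
Two of the fixes above benefit from a standalone illustration. First, l2_proj now uses a per-example eps and keeps the out-of-bounds mask broadcastable; the sketch below mirrors the corrected projection with made-up shapes (the batch size, dimension, and radius are assumptions).

import torch
import torch.nn.functional as F

def norm(x):
    return torch.sqrt((x ** 2).sum(-1))

nbatch, ndim = 4, 16
eps = torch.full((nbatch,), 0.5)                 # per-example radius (assumption)
orig = torch.rand(nbatch, ndim)
new_x = orig + torch.randn(nbatch, ndim)         # candidate iterate (assumption)

delta = new_x - orig
out_of_bounds_mask = (norm(delta) > eps).float().unsqueeze(-1)   # [nbatch, 1]
x = (orig + eps[:, None] * F.normalize(delta, dim=-1)) * out_of_bounds_mask
x += new_x * (1 - out_of_bounds_mask)
assert (norm(x - orig) <= eps + 1e-4).all()      # every example stays inside its ball

Second, the delta_L line previously referenced f1/f2, which do not exist; the corrected version uses the two loss queries L1/L2. The sketch below shows that two-query finite-difference estimate in isolation, following the MadryLab bandits reference linked above; the toy loss and the construction of q1/q2 are assumptions, and delta_L is unsqueezed explicitly so the broadcast against exp_noise is unambiguous.

import torch
import torch.nn.functional as F

def L(x):                                        # toy per-example loss, shape [nbatch] (assumption)
    return (x ** 2).sum(-1)

nbatch, ndim = 4, 16
fd_eta, exploration = 0.01, 0.01
x_adv = torch.rand(nbatch, ndim)
prior = torch.zeros(nbatch, ndim)

exp_noise = exploration * torch.randn_like(prior) / (ndim ** 0.5)
q1 = F.normalize(prior + exp_noise, dim=-1)      # query directions (assumed construction)
q2 = F.normalize(prior - exp_noise, dim=-1)
L1 = L(x_adv + fd_eta * q1)                      # L(prior + c*noise)
L2 = L(x_adv + fd_eta * q2)                      # L(prior - c*noise)

delta_L = (L1 - L2) / (fd_eta * exploration)     # [nbatch]
grad_est = delta_L.unsqueeze(-1) * exp_noise     # [nbatch, ndim]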



2 changes: 1 addition & 1 deletion advertorch/attacks/blackbox/nattack.py
@@ -165,7 +165,7 @@ def perturb(self, x, y):
n_batch, n_dim = x.shape

#[B]
eps = _check_param(self.eps, x.new_full((x.shape[0],) , 1), 'eps')
eps = _check_param(self.eps, x.new_full((x.shape[0],), 1), 'eps')
#[B, F]
clip_min = _check_param(self.clip_min, x, 'clip_min')
clip_max = _check_param(self.clip_max, x, 'clip_max')
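
The eps default fix in both perturb methods is the same parenthesis move: x.new_full((x.shape[0],)) omits the required fill value and raises a TypeError before _check_param is even called, whereas x.new_full((x.shape[0],), 1) builds the per-example default tensor that is passed to it. A quick illustration with an assumed input shape:

import torch

x = torch.rand(8, 784)                       # assumed batch of flattened inputs
# x.new_full((x.shape[0], ))                 # TypeError: fill_value is required
default_eps = x.new_full((x.shape[0],), 1)   # tensor of ones, shape [8]
print(default_eps.shape, default_eps[0].item())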
