Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update UGBA, its example and changelog #140

Merged
merged 1 commit into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
update UGBA, its example and changelog
  • Loading branch information
ventr1c committed Jun 29, 2023
commit 6c2d04bf1aa39de649f4200be4646fe52318d16f
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ If our work could help your research, please cite:
```

# Changelog
* [06/2023] We have added a backdoor attack [UGBA, WWW'23](https://arxiv.org/abs/2303.01263) to the graph package. We can now use UGBA to conduct unnoticeable backdoor attacks on large-scale graphs such as ogb-arxiv (see example in [test_ugba.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_ugba.py))!
* [02/2023] DeepRobust 0.2.8 Released. Please try `pip install deeprobust==0.2.8`! We have added a scalable attack [PRBCD, NeurIPS'21](https://arxiv.org/abs/2110.14038) to graph package. We can now use PRBCD to attack large-scale graphs such as ogb-arxiv (see example in [test_prbcd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prbcd.py))!
* [02/2023] Add a robust model [AirGNN, NeurIPS'21](https://proceedings.neurips.cc/paper/2021/file/50abc3e730e36b387ca8e02c26dc0a22-Paper.pdf) to graph package. Try `python examples/graph/test_airgnn.py`! See details in [test_airgnn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_airgnn.py)
* [11/2022] DeepRobust 0.2.6 Released. Please try `pip install deeprobust==0.2.6`! We have more updates coming. Please stay tuned!
Expand Down
236 changes: 206 additions & 30 deletions deeprobust/graph/targeted_attack/ugba.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch_geometric.utils import degree
from sklearn.cluster import KMeans
from copy import deepcopy
from deeprobust.graph.defense_pyg import GCN, SAGE, GAT
# from deeprobust.graph.defense_pyg import GCN, SAGE, GAT
from deeprobust.graph.targeted_attack import BaseAttack
from deeprobust.graph import utils

Expand Down Expand Up @@ -74,11 +74,26 @@ def __init__(self, data, vs_number,
self.seed = seed
self.debug = debug

# self.train_edge_index, _, self.edge_mask = subgraph(torch.bitwise_not(data.test_mask),data.edge_index,relabel_nodes=False)
# filter out the unlabeled nodes except from training nodes and testing nodes, nonzero() is to get index, flatten is to get 1-d tensor
self.unlabeled_idx = (torch.bitwise_not(data.test_mask)&torch.bitwise_not(data.train_mask)).nonzero().flatten()

self.idx_val = utils.index_to_mask(data.val_mask, size=data.x.shape[0])
def attack(self, target_node, x, y, edge_index, edge_weights = None):
'''
inject the generated trigger to the target node (a single node)

Parameters
----------
target_node: int
the index of target node
x: tensor:
features of nodes
y: tensor:
node labels
edge_index: tensor:
edge index of the graph
edge_weights: tensor:
the weights of edges
'''
idx_target = torch.tensor([target_node])
print(idx_target)
if(edge_weights == None):
Expand All @@ -103,19 +118,22 @@ def get_poisoned_graph(self):
poison_data.train_mask = utils.index_to_mask(idx_bkd_tn, poison_data.x.shape[0])
poison_data.val_mask = utils.index_to_mask(idx_val, poison_data.x.shape[0])
poison_data.test_mask = utils.index_to_mask(idx_test, poison_data.x.shape[0])
# return poison_x, poison_edge_index, poison_edge_weights, poison_labels, idx_bkd_tn
return poison_data

def train_trigger_generator(self, idx_train, edge_index, edge_weights = None, selection_method = 'cluster', **kwargs):
"""
Description
Train the adaptive trigger generator

Parameters
----------
target_node : int
target node index to be attacked
idx_attach : tensor
indexs of selected poisoned nodes
idx_train: tensor:
indexs of training nodes
edge_index: tensor:
edge index of the graph
edge_weights: tensor:
the weights of edges
selection method : ['none', 'cluster']
the method to select poisoned nodes
"""
self.idx_train = idx_train
# self.data = data
Expand All @@ -134,6 +152,22 @@ def train_trigger_generator(self, idx_train, edge_index, edge_weights = None, se
return self.trigger_generator, idx_attach

def inject_trigger(self, idx_attach, x, y, edge_index, edge_weights):
"""
Attach the generated triggers to the selected nodes

Parameters
----------
idx_attach: tensor:
indexs of to-be attached nodes
x: tensor:
features of nodes
y: tensor:
node labels
edge_index: tensor:
edge index of the graph
edge_weights: tensor:
the weights of edges
"""
assert self.trigger_generator, "please first use train_trigger_generator() to train trigger generator"

update_x, update_edge_index,update_edge_weights, update_y = self.trigger_generator.inject_trigger(idx_attach,x,edge_index,edge_weights,y,self.device)
Expand All @@ -143,7 +177,7 @@ def select_idx_attach(self, selection_method, edge_index, edge_weights = None):
if(selection_method == 'none'):
idx_attach = self.obtain_attach_nodes(self.unlabeled_idx,self.size)
elif(selection_method == 'cluster'):
idx_attach = self.cluster_selection(self.data,self.idx_train,self.unlabeled_idx,self.size,edge_index,edge_weights)
idx_attach = self.cluster_selection(self.data,self.idx_train,self.idx_val,self.unlabeled_idx,self.size,edge_index,edge_weights)
idx_attach = torch.LongTensor(idx_attach).to(self.device)
return idx_attach

Expand All @@ -155,13 +189,7 @@ def obtain_attach_nodes(self,node_idxs, size):
rs.shuffle(choice)
return node_idxs[choice[:size]]

def cluster_selection(self,data,idx_train,unlabeled_idx,size,edge_index,edge_weights = None):
# selected_nodes_path = "./selected_nodes/{}/Overall/seed{}/nodes.txt".format(args.dataset,args.seed)
# if(os.path.exists(selected_nodes_path)):
# print(selected_nodes_path)
# idx_attach = np.loadtxt(selected_nodes_path, delimiter=',').astype(int)
# idx_attach = idx_attach[:size]
# return idx_attach
def cluster_selection(self,data,idx_train,idx_val,unlabeled_idx,size,edge_index,edge_weights = None):
gcn_encoder = GCN_Encoder(nfeat=data.x.shape[1],
nhid=32,
nclass= int(data.y.max()+1),
Expand All @@ -174,18 +202,14 @@ def cluster_selection(self,data,idx_train,unlabeled_idx,size,edge_index,edge_wei
t_total = time.time()
# edge_weights = torch.ones([data.edge_index.shape[1]],device=device,dtype=torch.float)
print("Length of training set: {}".format(len(idx_train)))
gcn_encoder.fit(data.x, edge_index, edge_weights, data.y, idx_train, idx_val= None,train_iters=self.train_epochs,verbose=True)
gcn_encoder.fit(data.x, edge_index, edge_weights, data.y, idx_train, idx_val= idx_val,train_iters=self.train_epochs,verbose=True)
print("Training encoder Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# encoder_clean_test_ca = gcn_encoder.test(data.x, data.edge_index, None, data.y,idx_clean_test)
# print("Encoder CA on clean test nodes: {:.4f}".format(encoder_clean_test_ca))
# from sklearn import cluster
seen_node_idx = torch.concat([idx_train,unlabeled_idx])
nclass = np.unique(data.y.cpu().numpy()).shape[0]
encoder_x = gcn_encoder.get_h(data.x, edge_index,edge_weights).clone().detach()

# _, cluster_centers = kmeans(X=encoder_x[seen_node_idx], num_clusters=nclass, distance='euclidean', device=device)
kmeans = KMeans(n_clusters=nclass,random_state=1)
kmeans.fit(encoder_x[seen_node_idx].detach().cpu().numpy())
cluster_centers = kmeans.cluster_centers_
Expand Down Expand Up @@ -586,17 +610,12 @@ def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
self.weight_decay = weight_decay

def forward(self, x, edge_index, edge_weight=None):
    """Return row-wise log-probabilities for the given graph inputs.

    The inputs are passed through ``self.body`` and then ``self.fc``;
    the result is normalised with a log-softmax along dimension 1.

    Parameters
    ----------
    x :
        input features, forwarded to ``self.body``
    edge_index :
        edge index of the graph, forwarded to ``self.body``
    edge_weight :
        optional edge weights, forwarded to ``self.body``
    """
    hidden = self.body(x, edge_index, edge_weight)
    logits = self.fc(hidden)
    return F.log_softmax(logits, dim=1)
def get_h(self, x, edge_index, edge_weight):
    """Return the output of ``self.body`` for the given inputs.

    The module is switched to evaluation mode first (and is left in
    that mode afterwards).
    """
    self.eval()
    return self.body(x, edge_index, edge_weight)

def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
Expand Down Expand Up @@ -693,9 +712,6 @@ def test(self, features, edge_index, edge_weight, labels,idx_test):
with torch.no_grad():
output = self.forward(features, edge_index, edge_weight)
acc_test = accuracy(output[idx_test], labels[idx_test])
# print("Test set results:",
# "loss= {:.4f}".format(loss_test.item()),
# "accuracy= {:.4f}".format(acc_test.item()))
return float(acc_test)

def test_with_correct_nodes(self, features, edge_index, edge_weight, labels,idx_test):
Expand Down Expand Up @@ -735,3 +751,163 @@ def forward(self,x, edge_index,edge_weight=None):
x = F.dropout(x, self.dropout, training=self.training)
return x

class GCN(nn.Module):
    """GCN classifier built from ``GCNConv`` layers.

    The network consists of one input ``GCNConv``, ``layer - 2`` additional
    hidden ``GCNConv`` layers (none when ``layer <= 2``) and an output
    ``GCNConv`` (``self.gc2``) followed by a log-softmax. Every hidden
    convolution is followed by ReLU, an optional LayerNorm and dropout.

    Parameters
    ----------
    nfeat : int
        size of the input features
    nhid : int
        size of the hidden representations
    nclass : int
        number of output classes
    dropout : float
        dropout probability applied after each hidden convolution
    lr : float
        learning rate of the Adam optimizer used in ``fit``
    weight_decay : float
        weight decay of the Adam optimizer used in ``fit``
    layer : int
        total number of convolutions (values below 2 behave like 2)
    device :
        device that ``fit`` moves features and labels to; required.
        The module itself is not moved by this class.
    layer_norm_first : bool
        whether to apply LayerNorm to the raw input features
    use_ln : bool
        whether to apply LayerNorm after each hidden convolution
    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4, layer=2, device=None, layer_norm_first=False, use_ln=False):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass

        # lns[0] normalises the raw input; lns[k + 1] follows convs[k].
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(nfeat, nhid))
        self.lns = nn.ModuleList()
        self.lns.append(nn.LayerNorm(nfeat))
        for _ in range(layer - 2):
            self.convs.append(GCNConv(nhid, nhid))
            self.lns.append(nn.LayerNorm(nhid))
        self.lns.append(nn.LayerNorm(nhid))
        self.gc2 = GCNConv(nhid, nclass)

        self.dropout = dropout
        self.lr = lr
        self.weight_decay = weight_decay

        # Populated by fit() / the training helpers.
        self.output = None
        self.edge_index = None
        self.edge_weight = None
        self.features = None

        self.layer_norm_first = layer_norm_first
        self.use_ln = use_ln

    def forward(self, x, edge_index, edge_weight=None):
        """Return log-probabilities (log-softmax over dimension 1)."""
        if self.layer_norm_first:
            x = self.lns[0](x)
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index, edge_weight))
            if self.use_ln:
                x = self.lns[i + 1](x)
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

    def get_h(self, x, edge_index, edge_weight=None):
        """Return the hidden representation after the last hidden convolution.

        Only the hidden convolutions and ReLU are applied (no LayerNorm,
        no dropout, no output layer).

        Parameters
        ----------
        x :
            node features
        edge_index :
            edge index of the graph
        edge_weight :
            optional edge weights; ``None`` (the default) keeps the
            previous two-argument behaviour
        """
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight))
        return x

    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
        """Train the gcn model, when idx_val is not None, pick the best model according to the validation accuracy.

        Parameters
        ----------
        features :
            node features (moved to ``self.device``)
        edge_index :
            edge index of the graph (stored as given)
        edge_weight :
            the weights of edges, or None (stored as given)
        labels :
            node labels (moved to ``self.device``)
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), the model is
            trained for ``train_iters`` epochs without model selection
        train_iters : int
            number of training epochs
        verbose : bool
            whether to show verbose logs
        """

        self.edge_index, self.edge_weight = edge_index, edge_weight
        self.features = features.to(self.device)
        self.labels = labels.to(self.device)

        if idx_val is None:
            self._train_without_val(self.labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(self.labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        """Run ``train_iters`` Adam steps on the training nodes and cache the final output."""
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.edge_index, self.edge_weight)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train and restore the weights of the epoch with the best validation accuracy.

        If no epoch ever improves on the initial best accuracy of 0 (or
        ``train_iters`` is 0), there is no snapshot to restore and the
        current weights are kept.
        """
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        best_weights = None  # snapshot of the best epoch, if any

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                print("acc_val: {:.4f}".format(acc_val))
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                best_weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        if best_weights is not None:
            self.load_state_dict(best_weights)

    def test(self, features, edge_index, edge_weight, labels, idx_test):
        """Evaluate GCN performance on test set.

        Parameters
        ----------
        features :
            node features
        edge_index :
            edge index of the graph
        edge_weight :
            the weights of edges, or None
        labels :
            node labels
        idx_test :
            node testing indices

        Returns
        -------
        float
            accuracy on ``idx_test`` as computed by ``utils.accuracy``
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        return float(acc_test)

    def test_with_correct_nodes(self, features, edge_index, edge_weight, labels, idx_test):
        """Evaluate on ``idx_test`` and also report which test nodes are classified correctly.

        Returns
        -------
        tuple
            ``(acc_test, correct_nids)`` where ``acc_test`` is the value
            returned by ``utils.accuracy`` and ``correct_nids`` is a 1-d
            tensor of positions within ``idx_test`` whose prediction
            matches the label.
        """
        self.eval()
        # Evaluation only: no autograd graph is needed, same as in test().
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            predictions = output.argmax(dim=1)[idx_test]
            correct_nids = (predictions == labels[idx_test]).nonzero().flatten()
            acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        return acc_test, correct_nids
Loading