Merge pull request #140 from ventr1c/master
update UGBA, its example and changelog
ChandlerBang authored Jun 29, 2023
2 parents 5b97b44 + 6c2d04b commit d25d95b
Showing 3 changed files with 229 additions and 76 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -62,6 +62,7 @@ If our work could help your research, please cite:
```

# Changelog
* [06/2023] We have added a backdoor attack [UGBA, WWW'23](https://arxiv.org/abs/2303.01263) to the graph package. We can now use UGBA to conduct unnoticeable backdoor attacks on large-scale graphs such as ogb-arxiv (see the example in [test_ugba.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_ugba.py))! A minimal usage sketch follows right after this changelog.
* [02/2023] DeepRobust 0.2.8 Released. Please try `pip install deeprobust==0.2.8`! We have added a scalable attack [PRBCD, NeurIPS'21](https://arxiv.org/abs/2110.14038) to the graph package. We can now use PRBCD to attack large-scale graphs such as ogb-arxiv (see the example in [test_prbcd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prbcd.py))!
* [02/2023] We have added a robust model [AirGNN, NeurIPS'21](https://proceedings.neurips.cc/paper/2021/file/50abc3e730e36b387ca8e02c26dc0a22-Paper.pdf) to the graph package. Try `python examples/graph/test_airgnn.py`! See details in [test_airgnn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_airgnn.py)
* [11/2022] DeepRobust 0.2.6 Released. Please try `pip install deeprobust==0.2.6`! We have more updates coming. Please stay tuned!
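A minimal usage sketch of UGBA, distilled from the signatures visible in the ugba.py diff below. The import path, the `device`/`seed` keywords, and the Cora loading step are assumptions for illustration; [test_ugba.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_ugba.py) is the authoritative example.

```python
import torch
from torch_geometric.datasets import Planetoid

# Assumed import path: the class is defined in deeprobust/graph/targeted_attack/ugba.py.
from deeprobust.graph.targeted_attack import UGBA

data = Planetoid(root='./data', name='Cora')[0]   # any PyG graph with train/val/test masks
idx_train = data.train_mask.nonzero().flatten()

# vs_number = how many poisoned nodes to attach; device/seed keywords are assumptions.
agent = UGBA(data, vs_number=40, device='cpu', seed=42)

# Train the adaptive trigger generator; 'cluster' picks the poisoned nodes by
# KMeans over GCN-encoder embeddings (see cluster_selection in the diff below).
trigger_generator, idx_attach = agent.train_trigger_generator(
    idx_train, data.edge_index, selection_method='cluster')

poison_data = agent.get_poisoned_graph()            # backdoored training graph
agent.attack(target_node=0, x=data.x, y=data.y,     # inject the trigger at one node
             edge_index=data.edge_index)
```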
236 changes: 206 additions & 30 deletions deeprobust/graph/targeted_attack/ugba.py
@@ -9,7 +9,7 @@
from torch_geometric.utils import degree
from sklearn.cluster import KMeans
from copy import deepcopy
-from deeprobust.graph.defense_pyg import GCN, SAGE, GAT
+# from deeprobust.graph.defense_pyg import GCN, SAGE, GAT
from deeprobust.graph.targeted_attack import BaseAttack
from deeprobust.graph import utils

@@ -74,11 +74,26 @@ def __init__(self, data, vs_number,
        self.seed = seed
        self.debug = debug

        # filter out the unlabeled nodes (everything outside the training and test
        # sets); nonzero() returns their indices and flatten() yields a 1-D tensor
        self.unlabeled_idx = (torch.bitwise_not(data.test_mask) & torch.bitwise_not(data.train_mask)).nonzero().flatten()

        self.idx_val = utils.index_to_mask(data.val_mask, size=data.x.shape[0])

    def attack(self, target_node, x, y, edge_index, edge_weights=None):
        '''
        Inject the generated trigger into the target node (a single node).

        Parameters
        ----------
        target_node : int
            index of the target node
        x : tensor
            node features
        y : tensor
            node labels
        edge_index : tensor
            edge index of the graph
        edge_weights : tensor
            edge weights
        '''
        idx_target = torch.tensor([target_node])
        print(idx_target)
        if edge_weights is None:
@@ -103,19 +118,22 @@ def get_poisoned_graph(self):
        poison_data.train_mask = utils.index_to_mask(idx_bkd_tn, poison_data.x.shape[0])
        poison_data.val_mask = utils.index_to_mask(idx_val, poison_data.x.shape[0])
        poison_data.test_mask = utils.index_to_mask(idx_test, poison_data.x.shape[0])
        return poison_data
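    # Note: the returned poison_data is presumably a PyG Data object whose
    # train_mask now also covers the trigger-attached nodes (idx_bkd_tn),
    # so training on it implants the backdoor.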

    def train_trigger_generator(self, idx_train, edge_index, edge_weights=None, selection_method='cluster', **kwargs):
        """
        Train the adaptive trigger generator.

        Parameters
        ----------
        idx_train : tensor
            indices of the training nodes
        edge_index : tensor
            edge index of the graph
        edge_weights : tensor
            edge weights
        selection_method : str, one of ['none', 'cluster']
            the method used to select the poisoned nodes

        Returns
        -------
        trigger_generator :
            the trained trigger generator
        idx_attach : tensor
            indices of the selected poisoned nodes
        """
        self.idx_train = idx_train
@@ -134,6 +152,22 @@ def train_trigger_generator(self, idx_train, edge_index, edge_weights = None, se
        return self.trigger_generator, idx_attach

    def inject_trigger(self, idx_attach, x, y, edge_index, edge_weights):
        """
        Attach the generated triggers to the attached nodes.

        Parameters
        ----------
        idx_attach : tensor
            indices of the nodes to attach triggers to
        x : tensor
            node features
        y : tensor
            node labels
        edge_index : tensor
            edge index of the graph
        edge_weights : tensor
            edge weights
        """
        assert self.trigger_generator, "please call train_trigger_generator() first to train the trigger generator"

        update_x, update_edge_index, update_edge_weights, update_y = self.trigger_generator.inject_trigger(idx_attach, x, edge_index, edge_weights, y, self.device)
@@ -143,7 +177,7 @@ def select_idx_attach(self, selection_method, edge_index, edge_weights = None):
        if selection_method == 'none':
            idx_attach = self.obtain_attach_nodes(self.unlabeled_idx, self.size)
        elif selection_method == 'cluster':
-            idx_attach = self.cluster_selection(self.data,self.idx_train,self.unlabeled_idx,self.size,edge_index,edge_weights)
+            idx_attach = self.cluster_selection(self.data,self.idx_train,self.idx_val,self.unlabeled_idx,self.size,edge_index,edge_weights)
        idx_attach = torch.LongTensor(idx_attach).to(self.device)
        return idx_attach

@@ -155,13 +189,7 @@ def obtain_attach_nodes(self,node_idxs, size):
        rs.shuffle(choice)
        return node_idxs[choice[:size]]

-    def cluster_selection(self,data,idx_train,unlabeled_idx,size,edge_index,edge_weights = None):
-        # selected_nodes_path = "./selected_nodes/{}/Overall/seed{}/nodes.txt".format(args.dataset,args.seed)
-        # if(os.path.exists(selected_nodes_path)):
-        #     print(selected_nodes_path)
-        #     idx_attach = np.loadtxt(selected_nodes_path, delimiter=',').astype(int)
-        #     idx_attach = idx_attach[:size]
-        #     return idx_attach
+    def cluster_selection(self, data, idx_train, idx_val, unlabeled_idx, size, edge_index, edge_weights=None):
        gcn_encoder = GCN_Encoder(nfeat=data.x.shape[1],
                                  nhid=32,
                                  nclass=int(data.y.max() + 1),
@@ -174,18 +202,14 @@ def cluster_selection(self,data,idx_train,unlabeled_idx,size,edge_index,edge_wei
        t_total = time.time()
        print("Length of training set: {}".format(len(idx_train)))
-        gcn_encoder.fit(data.x, edge_index, edge_weights, data.y, idx_train, idx_val= None,train_iters=self.train_epochs,verbose=True)
+        gcn_encoder.fit(data.x, edge_index, edge_weights, data.y, idx_train, idx_val=idx_val, train_iters=self.train_epochs, verbose=True)
        print("Training encoder Finished!")
        print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

        seen_node_idx = torch.concat([idx_train, unlabeled_idx])
        nclass = np.unique(data.y.cpu().numpy()).shape[0]
        encoder_x = gcn_encoder.get_h(data.x, edge_index, edge_weights).clone().detach()

        kmeans = KMeans(n_clusters=nclass, random_state=1)
        kmeans.fit(encoder_x[seen_node_idx].detach().cpu().numpy())
        cluster_centers = kmeans.cluster_centers_
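        # The rest of cluster_selection is folded out of this diff, so the exact
        # selection rule is not visible here. Purely as a hedged illustration of
        # how KMeans centers over encoder embeddings can rank candidate nodes
        # (an assumption, not necessarily UGBA's published criterion):
        #
        #   own_centers = cluster_centers[kmeans.labels_]          # each node's center
        #   dists = np.linalg.norm(
        #       encoder_x[seen_node_idx].cpu().numpy() - own_centers, axis=1)
        #   idx_attach = seen_node_idx[np.argsort(dists)[-size:]]  # least typical nodes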
@@ -586,17 +610,12 @@ def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4,
        self.weight_decay = weight_decay

    def forward(self, x, edge_index, edge_weight=None):
        x = self.body(x, edge_index, edge_weight)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def get_h(self, x, edge_index, edge_weight):
        self.eval()
        x = self.body(x, edge_index, edge_weight)
        return x
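    # Note: get_h returns the hidden representation from the GCN body (before the
    # classification head); cluster_selection above clusters exactly these
    # embeddings with KMeans to choose the poisoned nodes.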

    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
Expand Down Expand Up @@ -693,9 +712,6 @@ def test(self, features, edge_index, edge_weight, labels,idx_test):
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            acc_test = accuracy(output[idx_test], labels[idx_test])
        return float(acc_test)

    def test_with_correct_nodes(self, features, edge_index, edge_weight, labels, idx_test):
@@ -735,3 +751,163 @@ def forward(self,x, edge_index,edge_weight=None):
        x = F.dropout(x, self.dropout, training=self.training)
        return x

class GCN(nn.Module):

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, lr=0.01, weight_decay=5e-4, layer=2, device=None, layer_norm_first=False, use_ln=False):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(nfeat, nhid))
        self.lns = nn.ModuleList()
        self.lns.append(torch.nn.LayerNorm(nfeat))
        for _ in range(layer - 2):
            self.convs.append(GCNConv(nhid, nhid))
            self.lns.append(nn.LayerNorm(nhid))
        self.lns.append(nn.LayerNorm(nhid))
        self.gc2 = GCNConv(nhid, nclass)
        self.dropout = dropout
        self.lr = lr
        self.output = None
        self.edge_index = None
        self.edge_weight = None
        self.features = None
        self.weight_decay = weight_decay

        self.layer_norm_first = layer_norm_first
        self.use_ln = use_ln
    def forward(self, x, edge_index, edge_weight=None):
        if self.layer_norm_first:
            x = self.lns[0](x)
        i = 0
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight))
            if self.use_ln:
                x = self.lns[i + 1](x)
            i += 1
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)

    def get_h(self, x, edge_index):
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return x

    def fit(self, features, edge_index, edge_weight, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
        """Train the GCN model. When idx_val is not None, pick the best model according to the validation accuracy.

        Parameters
        ----------
        features :
            node features
        edge_index :
            edge index of the graph
        edge_weight :
            edge weights
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), the training process will not adopt early stopping
        train_iters : int
            number of training epochs
        verbose : bool
            whether to show verbose logs
        """

        self.edge_index, self.edge_weight = edge_index, edge_weight
        self.features = features.to(self.device)
        self.labels = labels.to(self.device)

        if idx_val is None:
            self._train_without_val(self.labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(self.labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.edge_index, self.edge_weight)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        # keep an initial snapshot so `weights` is defined even if the validation
        # accuracy never improves over its starting value
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                print("acc_val: {:.4f}".format(acc_val))
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)


    def test(self, features, edge_index, edge_weight, labels, idx_test):
        """Evaluate GCN performance on the test set.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(features, edge_index, edge_weight)
            acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        return float(acc_test)

    def test_with_correct_nodes(self, features, edge_index, edge_weight, labels, idx_test):
        self.eval()
        output = self.forward(features, edge_index, edge_weight)
        correct_nids = (output.argmax(dim=1)[idx_test] == labels[idx_test]).nonzero().flatten()  # return a tensor
        acc_test = utils.accuracy(output[idx_test], labels[idx_test])
        return acc_test, correct_nids
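For reference, a minimal sketch of exercising this internal GCN class on its own. The dataset loading and device choice are assumptions for illustration; the `fit`/`test` signatures follow the code above.

```python
import torch
from torch_geometric.datasets import Planetoid

# Assumes GCN is the class defined at the end of ugba.py above.
data = Planetoid(root='./data', name='Cora')[0]
idx_train = data.train_mask.nonzero().flatten()
idx_val = data.val_mask.nonzero().flatten()
idx_test = data.test_mask.nonzero().flatten()

model = GCN(nfeat=data.x.shape[1], nhid=32,
            nclass=int(data.y.max()) + 1, device='cpu')
# With idx_val given, fit keeps the weights with the best validation accuracy.
model.fit(data.x, data.edge_index, None, data.y,
          idx_train, idx_val=idx_val, train_iters=200, verbose=True)
print('test accuracy: {:.4f}'.format(
    model.test(data.x, data.edge_index, None, data.y, idx_test)))
```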
