Skip to content

Commit

Permalink
负采样根据度进行采样 (negative sampling: draw negatives proportionally to node degree)
Browse files Browse the repository at this point in the history
  • Loading branch information
farkguidao committed Jan 26, 2022
1 parent 247cdea commit 53405cd
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
3 changes: 2 additions & 1 deletion dataloader/link_pre_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def read_data(self):
# mask除去自环
mask = self.edge_type>0
self.train_dataset = TensorDataset(self.edge_index.T[mask],self.edge_type[mask],self.edge_id[mask])

adj = torch.sparse_coo_tensor(self.edge_index, torch.ones(self.E)).coalesce()
self.degree = torch.sparse.sum(adj, 0).to_dense()
def train_dataloader(self) -> TRAIN_DATALOADERS:
    """Build the training DataLoader over the (edge_index, edge_type, edge_id) dataset.

    Re-shuffles every epoch and drops the last incomplete batch so every
    training step sees a full ``batch_size`` batch.
    """
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        drop_last=True,
    )

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_trainer_model_dataloader_from_yaml(yaml_path):
settings = dict(yaml.load(f,yaml.FullLoader))

dl = LinkPredictionDataloader(**settings['data'])
model = LinkPredictionTask(dl.edge_index,dl.edge_type,dl.feature_data,dl.N, **settings['model'])
model = LinkPredictionTask(dl.edge_index,dl.edge_type,dl.feature_data,dl.N,dl.degree, **settings['model'])
checkpoint_callback = pl.callbacks.ModelCheckpoint(**settings['callback'])
trainer = pl.Trainer(callbacks=[checkpoint_callback], **settings['train'])
return trainer,model,dl
Expand Down
14 changes: 8 additions & 6 deletions models/LinkPreTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


class LinkPredictionTask(pl.LightningModule):
def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
def __init__(self,edge_index,edge_type,feature,N,degree,use_feature,feature_dim,d_model,type_num, L,use_gradient_checkpointing,neg_num,dropout,lr,wd):
super(LinkPredictionTask, self).__init__()
# 工程类组件
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N'])
self.save_hyperparameters(ignore=['edge_index','edge_type','feature','N','degree'])
self.register_buffer('edge_index',edge_index)
self.register_buffer('edge_type',edge_type)
self.register_buffer('edge_feature',torch.eye(type_num+1))
Expand All @@ -25,7 +25,7 @@ def __init__(self,edge_index,edge_type,feature,N,use_feature,feature_dim,d_model
self.feature = nn.Parameter(torch.randn(N,d_model))
# nn.init.xavier_uniform_(self.feature)
self.loss2 = nn.CrossEntropyLoss()
self.loss1 = NCELoss(N)
self.loss1 = NCELoss(N,degree)
self.val_best_auc = 0
self.val_best_aupr = 0
self.test_best_auc = 0
Expand Down Expand Up @@ -108,13 +108,15 @@ def configure_optimizers(self):
return optimizer

class NCELoss(nn.Module):
def __init__(self,N):
def __init__(self, N, degree):
    """Negative-sampling (NCE-style) loss with degree-proportional negatives.

    Args:
        N: total number of nodes (upper bound of the negative-sample id range).
        degree: 1-D tensor of per-node degrees, used as unnormalized
            weights when drawing negative samples.
    """
    super(NCELoss, self).__init__()
    # Registered as a buffer (not a Parameter): it follows .to(device)/state_dict
    # with the module but receives no gradient updates.
    self.register_buffer('degree', degree)
    self.bce = nn.BCEWithLogitsLoss()
    self.N = N
def forward(self,inputs,weights,labels,neg_num):
neg_batch = torch.randint(0, self.N, (neg_num*inputs.shape[0],),
dtype=torch.long,device=inputs.device)
# neg_batch = torch.randint(0, self.N, (neg_num*inputs.shape[0],),
# dtype=torch.long,device=inputs.device)
neg_batch = torch.multinomial(self.degree,neg_num*inputs.shape[0],True)
target = weights[torch.cat([labels,neg_batch],dim=0)]
label = torch.zeros(target.shape[0],device=inputs.device)
label[:labels.shape[0]]=1
Expand Down
2 changes: 1 addition & 1 deletion settings/ama_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: True
feature_dim: 1156
d_model: 128
d_model: 64
type_num: 2
L: 3
use_gradient_checkpointing: False
Expand Down

0 comments on commit 53405cd

Please sign in to comment.