Skip to content

Commit

Permalink
Revert "负采样根据度进行采样"
Browse files Browse the repository at this point in the history
This reverts commit 53405cd
  • Loading branch information
farkguidao committed Jan 26, 2022
1 parent 53405cd commit a699240
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions dataloader/link_pre_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def read_data(self):
# mask除去自环
mask = self.edge_type>0
self.train_dataset = TensorDataset(self.edge_index.T[mask],self.edge_type[mask],self.edge_id[mask])
adj = torch.sparse_coo_tensor(self.edge_index, torch.ones(self.E)).coalesce()
self.degree = torch.sparse.sum(adj, 0).to_dense()

def train_dataloader(self) -> TRAIN_DATALOADERS:
return DataLoader(self.train_dataset,self.batch_size,shuffle=True,num_workers=self.num_workers,drop_last=True)

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_trainer_model_dataloader_from_yaml(yaml_path):
settings = dict(yaml.load(f,yaml.FullLoader))

dl = LinkPredictionDataloader(**settings['data'])
model = LinkPredictionTask(dl.edge_index,dl.edge_type,dl.feature_data,dl.N,dl.degree, **settings['model'])
model = LinkPredictionTask(dl.edge_index,dl.edge_type,dl.feature_data,dl.N, **settings['model'])
checkpoint_callback = pl.callbacks.ModelCheckpoint(**settings['callback'])
trainer = pl.Trainer(callbacks=[checkpoint_callback], **settings['train'])
return trainer,model,dl
Expand Down
2 changes: 1 addition & 1 deletion settings/ama_settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ model:
# use_feature,feature_dim,d_model,type_num, L,neg_num,dropout,lr,wd
use_feature: True
feature_dim: 1156
d_model: 64
d_model: 128
type_num: 2
L: 3
use_gradient_checkpointing: False
Expand Down

0 comments on commit a699240

Please sign in to comment.