Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][PyG] Link prediction example. #7752

Merged
merged 7 commits into from
Aug 27, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
linting
  • Loading branch information
mfbalin committed Aug 27, 2024
commit 1e4bbf3bed68746caa4d15e0fdc563bda4cba46d
19 changes: 10 additions & 9 deletions examples/graphbolt/pyg/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
└───> Validation and test set evaluation
"""
import argparse
from functools import partial
import time
from functools import partial

import dgl.graphbolt as gb
import torch
Expand All @@ -95,8 +95,8 @@

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from tqdm import tqdm
from torchmetrics.retrieval import RetrievalMRR
from tqdm import tqdm


class GraphSAGE(torch.nn.Module):
Expand Down Expand Up @@ -232,14 +232,16 @@ def create_dataloader(
# Create and return a DataLoader to handle data loading.
return gb.DataLoader(datapipe, num_workers=args.num_workers)


@torch.compile
def predictions_step(model, h_src, h_dst):
return model.predictor(h_src * h_dst).squeeze()


def compute_predictions(model, node_emb, seeds, device):
"""Compute the predictions for given source and destination nodes.

This function computes the predictions for a set of node pairs, dividing the
This function computes the predictions for a set of node pairs, dividing the
task into batches to handle potentially large graphs.
"""

Expand All @@ -260,6 +262,7 @@ def compute_predictions(model, node_emb, seeds, device):
preds[start:end] = predictions_step(model, h_src, h_dst)
return preds


@torch.no_grad()
def evaluate(model, graph, features, all_nodes_set, valid_set, test_set):
"""Evaluate the model on validation and test sets."""
Expand Down Expand Up @@ -314,9 +317,7 @@ def train_helper(dataloader, model, optimizer, device):
total_samples = 0 # Accumulator for the total number of samples processed
start = time.time()
for step, minibatch in tqdm(enumerate(dataloader), "Training"):
loss, num_samples = train_step(
minibatch, optimizer, model
)
loss, num_samples = train_step(minibatch, optimizer, model)
total_loss += loss * num_samples
total_samples += num_samples
if step + 1 == args.early_stop:
Expand Down Expand Up @@ -486,9 +487,9 @@ def main():

in_channels = features.size("node", None, "feat")[0]
hidden_channels = 256
model = GraphSAGE(
in_channels, hidden_channels, len(args.fanout)
).to(args.device)
model = GraphSAGE(in_channels, hidden_channels, len(args.fanout)).to(
args.device
)
assert len(args.fanout) == len(model.layers)

train(train_dataloader, model, args.device)
Expand Down
Loading