Hi, I am opening this issue to ask a question. I am trying to use one of your examples on graph classification (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).

What I have recently done is modify the example so that a TorchScript version of the model can be created successfully, because I am trying to lower the model to torch-mlir. When I try to do so, I hit an error which, according to the torch-mlir devs, means that my model builds a tuple somewhere, like x = (0, 0). They suggested replacing that tuple with a list, like x = [0, 0].
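To be sure I understand the suggestion, this is the kind of change they mean (a made-up illustration, not a line from the actual model code):

```python
# Illustration of the suggested change (made-up example, not actual model code):
x = (0, 0)   # a Python tuple; TorchScript keeps this as a tuple type,
             # which torch-mlir reportedly cannot lower in my case
x = [0, 0]   # the suggested replacement: a list instead of a tuple
```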
Unfortunately, I am new to this and I have not been able to spot the problem. I leave the torch-mlir error here for completeness.
Can you please help me spot this tuple, so that your models can be made compatible with torch-mlir?
I leave here the files of your model that I am using. Some changes have been made, purely to make the code compatible with TorchScript (and, for simplicity, only the code for the GIN model has been preserved).
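Concretely, the changes follow the PyTorch Geometric TorchScript guidelines: typed forward signatures, a `propagate_type` annotation on the message-passing class, and `.jittable()` convs. Here is a condensed sketch of the pattern (MyConv is an illustrative name; the real class is GINConv in conv.py below):

```python
# Condensed sketch of the TorchScript-compatibility pattern I applied;
# MyConv is an illustrative name, the real class is GINConv in conv.py.
import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    # tells TorchScript the types of the kwargs forwarded to propagate()
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self):
        super(MyConv, self).__init__(aggr="add")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        return x_j + edge_attr

conv = MyConv().jittable()  # generates a scriptable variant of the class
```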
Thank you in advance for your help.
main.py
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import sys

from PIL import Image
import requests
import torch
from torchvision import transforms
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from gnn import GNN
import torch.optim as optim
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


def train(model, device, loader, optimizer, task_type):
    model.train()

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            pred = model(batch)
            optimizer.zero_grad()
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y
            if "classification" in task_type:
                loss = cls_criterion(pred.to(torch.float32)[is_labeled],
                                     batch.y.to(torch.float32)[is_labeled])
            else:
                loss = reg_criterion(pred.to(torch.float32)[is_labeled],
                                     batch.y.to(torch.float32)[is_labeled])
            loss.backward()
            optimizer.step()


def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            pass
        else:
            x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch
            with torch.no_grad():
                pred = model(x, edge_index, edge_attr, batch_f)

            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)


def predictions(torch_model, jit_model):
    pytorch_prediction = eval(torch_model, device, test_loader, evaluator)
    print("PyTorch prediction")
    print(pytorch_prediction)
    mlir_prediction = eval(jit_model, device, test_loader, evaluator)
    print("torch-mlir prediction")
    print(mlir_prediction)


cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

### automatic data loading and splitting
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator("ogbg-molhiv")

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=1, shuffle=True,
                          num_workers=0)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=1, shuffle=False,
                          num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=1, shuffle=False,
                         num_workers=0)

gin = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=5, emb_dim=300,
          drop_ratio=0.5).to("cpu")
optimizer = optim.Adam(gin.parameters(), lr=0.001)
device = torch.device("cpu")

train(gin, device, train_loader, optimizer, dataset.task_type)
eval(gin, device, valid_loader, evaluator)

gin.eval()
for step, batch in enumerate(tqdm(test_loader, desc="Iteration")):
    batch = batch.to(device)
    x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch
    module = torch_mlir.compile(gin, (x, edge_index, edge_attr, batch_f),
                                output_type="linalg-on-tensors")
    break

backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

predictions(gin, jit_module)
```
gnn.py
```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform
from conv import GNN_node
from torch_scatter import scatter_mean
import time


class GNN(torch.nn.Module):

    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        """
        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio,
                                 residual=residual, gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
                                            torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor, batch) -> torch.Tensor:
        h_node = self.gnn_node(x, edge_index, edge_attr)
        h_graph = self.pool(h_node, batch)
        return self.graph_pred_linear(h_graph)


if __name__ == '__main__':
    GNN(num_tasks=10)
```
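For reference, this is the minimal smoke test I run against gnn.py after the signature change (my own sketch, not part of the repo; the dummy shapes assume ogbg-mol*-style integer features):

```python
# Minimal smoke test for the typed forward() (my own sketch; ogbg-mol* graphs
# carry 9 integer atom features per node and 3 integer bond features per edge).
import torch
from gnn import GNN

model = GNN(gnn_type='gin', num_tasks=1, num_layer=5, emb_dim=300, drop_ratio=0.5)
model.eval()

num_nodes, num_edges = 10, 20
x = torch.zeros(num_nodes, 9, dtype=torch.long)            # atom features
edge_index = torch.randint(0, num_nodes, (2, num_edges))   # COO connectivity
edge_attr = torch.zeros(num_edges, 3, dtype=torch.long)    # bond features
batch = torch.zeros(num_nodes, dtype=torch.long)           # all nodes in graph 0

with torch.no_grad():
    out = model(x, edge_index, edge_attr, batch)
print(out.shape)  # torch.Size([1, 1]) -> (num_graphs, num_tasks)
```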
conv.py
```python
import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree
import math
import time


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
                                       torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        assert isinstance(edge_embedding, torch.Tensor)
        return self.mlp((1 + self.eps) * x
                        + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last",
                 residual=False, gnn_type='gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
        '''
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            self.convs.append(GINConv(emb_dim).jittable())
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            assert isinstance(h_list[layer], torch.Tensor)
            h = conv(h_list[layer], edge_index, edge_attr)
            h = norm(h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        #if self.JK == "last":
        node_representation = h_list[-1]
        #elif self.JK == "sum":
        #    node_representation = 0
        #    for layer in range(self.num_layer + 1):
        #        node_representation += h_list[layer]

        return node_representation


if __name__ == "__main__":
    pass
```
mol_encoder.py
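Finally, for what it is worth, this is how I have been trying to locate the tuple myself (my own debugging sketch, not part of the repo): scripting the model and scanning the TorchScript IR for prim::TupleConstruct nodes. Note that this only walks the top-level forward graph, so a tuple built inside a submodule would not show up here.

```python
# Debugging sketch (my own): look for tuple construction in the TorchScript IR.
import torch
from gnn import GNN

model = GNN(gnn_type='gin', num_tasks=1).eval()
scripted = torch.jit.script(model)

# prim::TupleConstruct is the IR node that builds a tuple such as (0, 0).
for node in scripted.graph.nodes():
    if node.kind() == "prim::TupleConstruct":
        print(node)
```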