Skip to content

Commit

Permalink
pipeline issue debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
Alpay Sedat Durukan committed Apr 19, 2022
1 parent 4eae85d commit d608f46
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test(data):
return accs


def train(data):
def train(model, data):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
Expand All @@ -101,15 +101,15 @@ def train(data):
data_list = get_data(graph_data, "train_data")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("models/gat_300_2")
print(f"model: {model}")
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
accuracy_graph = {}
for i, data in zip(range(len(data_list)), data_list):
data = data.to(device)
accuracy_epoch = []
for epoch in range(1, 200):
loss = train(data)
loss = train(model, data)
train_acc = test(data)
print(f"epoch: {epoch} with type {type(epoch)}")
if epoch == 1:
accuracy_epoch.append(train_acc[0])
elif epoch == 199:
Expand Down

0 comments on commit d608f46

Please sign in to comment.