Skip to content

Commit

Permalink
- data[0] not supported on Pytorch V.2.4.1 replaced by item()
Browse files Browse the repository at this point in the history
  • Loading branch information
franck-armand committed Oct 11, 2024
1 parent 96f2689 commit 19dd90c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions core/adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def train_tgt(src_encoder, tgt_encoder, critic,
params.num_epochs,
step + 1,
len_data_loader,
loss_critic.data[0],
loss_tgt.data[0],
acc.data[0]))
loss_critic.item(),
loss_tgt.item(),
acc.item()))

#############################
# 2.4 save model parameters #
Expand Down
8 changes: 4 additions & 4 deletions core/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def train_src(encoder, classifier, data_loader):
params.num_epochs_pre,
step + 1,
len(data_loader),
loss.data[0]))
loss.item()))

# eval model on test set
if ((epoch + 1) % params.eval_step_pre == 0):
Expand All @@ -78,8 +78,8 @@ def eval_src(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.0
acc = 0.0

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -90,7 +90,7 @@ def eval_src(encoder, classifier, data_loader):
labels = make_variable(labels)

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down
6 changes: 3 additions & 3 deletions core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def eval_tgt(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.0
acc = 0.0

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -25,7 +25,7 @@ def eval_tgt(encoder, classifier, data_loader):
labels = make_variable(labels).squeeze_()

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down

0 comments on commit 19dd90c

Please sign in to comment.