Skip to content

Commit

Permalink
Support Latest Pytorch (1.11)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhui-zh15 committed Apr 12, 2022
1 parent 96f2689 commit 0f98f5d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions core/adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def train_tgt(src_encoder, tgt_encoder, critic,
params.num_epochs,
step + 1,
len_data_loader,
loss_critic.data[0],
loss_tgt.data[0],
acc.data[0]))
loss_critic.item(),
loss_tgt.item(),
acc.item()))

#############################
# 2.4 save model parameters #
Expand Down
8 changes: 4 additions & 4 deletions core/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def train_src(encoder, classifier, data_loader):
params.num_epochs_pre,
step + 1,
len(data_loader),
loss.data[0]))
loss.item()))

# eval model on test set
if ((epoch + 1) % params.eval_step_pre == 0):
Expand All @@ -78,8 +78,8 @@ def eval_src(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.
acc = 0.

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -90,7 +90,7 @@ def eval_src(encoder, classifier, data_loader):
labels = make_variable(labels)

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down
6 changes: 3 additions & 3 deletions core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def eval_tgt(encoder, classifier, data_loader):
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0
loss = 0.
acc = 0.

# set loss function
criterion = nn.CrossEntropyLoss()
Expand All @@ -25,7 +25,7 @@ def eval_tgt(encoder, classifier, data_loader):
labels = make_variable(labels).squeeze_()

preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]
loss += criterion(preds, labels).item()

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()
Expand Down
4 changes: 2 additions & 2 deletions params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
data_root = "data"
dataset_mean_value = 0.5
dataset_std_value = 0.5
dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
dataset_mean = dataset_mean_value
dataset_std = dataset_std_value
batch_size = 50
image_size = 64

Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def make_variable(tensor, volatile=False):
"""Convert Tensor to Variable."""
if torch.cuda.is_available():
tensor = tensor.cuda()
return Variable(tensor, volatile=volatile)
return tensor


def make_cuda(tensor):
Expand Down

0 comments on commit 0f98f5d

Please sign in to comment.