Skip to content

Commit

Permalink
implement test() for classifing target data.
Browse files Browse the repository at this point in the history
  • Loading branch information
corenel committed Aug 21, 2017
1 parent 332e662 commit 34e487d
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion core/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
"""Test script to classify target data."""

import torch
import torch.nn as nn

from utils import make_variable


def eval_tgt(model_src, model_tgt, data_loader):
"""Evaluation for target encoder by source classifier on target dataset."""
pass
print("=== Evaluating classifier for encoded target domain ===")
model_src.eval()
model_tgt.eval()
loss = 0
acc = 0
criterion = nn.NLLLoss()

for (images, labels) in data_loader:
images = make_variable(images, volatile=True)
labels = make_variable(labels).squeeze_()

_, preds = model_tgt(images)
loss += criterion(preds, labels).data[0]

pred_cls = preds.data.max(1)[1]
acc += pred_cls.eq(labels.data).cpu().sum()

loss /= len(data_loader)
acc /= len(data_loader.dataset)

print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc))

0 comments on commit 34e487d

Please sign in to comment.