-
Notifications
You must be signed in to change notification settings - Fork 140
/
Copy pathtest.py
36 lines (25 loc) · 938 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Test script to classify target data."""
import torch
import torch.nn as nn
from utils import make_variable
def eval_tgt(encoder, classifier, data_loader):
    """Evaluate the target encoder with the source classifier on a dataset.

    Puts both networks into eval mode, runs every ``(images, labels)`` batch
    from ``data_loader`` through ``classifier(encoder(images))`` without
    tracking gradients, then prints the average cross-entropy loss and the
    classification accuracy.

    Args:
        encoder: network whose output on an image batch is fed directly to
            ``classifier``.
        classifier: network producing class logits from encoder output.
        data_loader: iterable of ``(images, labels)`` batches. Labels are
            flattened to 1-D before use (presumably they arrive as ``(N,)``
            or ``(N, 1)`` — confirm against the dataset).

    Returns:
        tuple: ``(avg_loss, accuracy)`` as Python floats, both averaged over
        every sample seen. Both are ``nan`` if the loader yields no samples.
    """
    # set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()

    criterion = nn.CrossEntropyLoss()

    total_loss = 0.0  # sum of per-sample losses over all batches
    n_correct = 0
    n_seen = 0

    # Evaluation needs no gradients; torch.no_grad() replaces the removed
    # ``volatile=True`` Variable flag.
    with torch.no_grad():
        for images, labels in data_loader:
            images = make_variable(images)
            # reshape(-1) rather than squeeze_(): squeeze_ turns a size-1
            # batch's labels into a 0-dim tensor, which no longer matches the
            # (1, C) logits expected by CrossEntropyLoss.
            labels = make_variable(labels).reshape(-1)

            preds = classifier(encoder(images))
            batch_size = labels.size(0)

            # The criterion returns the batch mean; scale it back to a sum so
            # a smaller final batch is not over-weighted in the average.
            total_loss += criterion(preds, labels).item() * batch_size
            n_correct += preds.argmax(dim=1).eq(labels).sum().item()
            n_seen += batch_size

    if n_seen:
        loss = total_loss / n_seen
        acc = float(n_correct) / n_seen
    else:
        # Empty loader: report NaN instead of dividing by zero.
        loss = acc = float("nan")

    print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))
    return loss, acc