Skip to content

Commit

Permalink
Add training and evaluation for the source model
Browse files Browse the repository at this point in the history
  • Loading branch information
corenel committed Aug 19, 2017
1 parent 53b997a commit c98bf21
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 60 deletions.
4 changes: 2 additions & 2 deletions core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Public API of the core package: source-model training and evaluation."""

from .pretrain import eval_src, train_src

# __all__ entries must be *strings*; a tuple of function objects makes
# `from core import *` raise TypeError. Also note the pre-fix form
# `(train_src)` was not a tuple at all, just the bare function.
__all__ = ("eval_src", "train_src")
44 changes: 36 additions & 8 deletions core/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@
from utils import make_variable, save_model


def train_src(net, data_loader):
def train_src(model, data_loader):
"""Train classifier for source domain."""
print("Classifier for source domain:")
print(net)
print("=== Training classifier for source domain ===")
print(model)

optimizer = optim.Adam(net.parameters(),
optimizer = optim.Adam(model.parameters(),
lr=params.c_learning_rate,
betas=(params.beta1, params.beta2))
criterion = nn.NLLLoss()

for epoch in range(params.num_epochs_pre):
net.train()
model.train()
for step, (images, labels) in enumerate(data_loader):
images = make_variable(images)
labels = make_variable(labels.squeeze_())

optimizer.zero_grad()

_, preds = net(images)
_, preds = model(images)
loss = criterion(preds, labels)

loss.backward()
Expand All @@ -43,6 +43,34 @@ def train_src(net, data_loader):
loss.data[0]))

if ((epoch + 1) % params.save_step == 0):
save_model(net, "classifier_src-{}.pt".format(epoch + 1))
save_model(model, "classifier_src-{}.pt".format(epoch + 1))

save_model(net, "classifier_src-final.pt")
save_model(model, "classifier_src-final.pt")

return model


def eval_src(model, data_loader):
    """Evaluate classifier for source domain.

    Args:
        model: network whose forward returns (features, predictions);
            predictions are fed to NLLLoss, so they are assumed to be
            log-probabilities (log-softmax output) — TODO confirm against
            the Classifier definition.
        data_loader: iterable of (images, labels) batches.

    Prints the average per-batch loss and per-sample accuracy.
    """
    print("=== Evaluating classifier for source domain ===")
    print(model)

    # switch Dropout/BatchNorm layers to inference behavior
    model.eval()

    # accumulate as floats so the final averages can never hit
    # integer division (e.g. tensor .sum() yielding an int)
    loss = 0.0
    acc = 0.0
    criterion = nn.NLLLoss()

    for (images, labels) in data_loader:
        # volatile=True: no autograd graph is built during evaluation
        # (legacy pre-0.4 PyTorch idiom, matching the rest of the file)
        images = make_variable(images, volatile=True)
        labels = make_variable(labels)

        _, preds = model(images)
        loss += criterion(preds, labels).data[0]

        # index of the max log-probability is the predicted class
        pred_cls = preds.data.max(1)[1]
        acc += pred_cls.eq(labels.data).cpu().sum()

    # loss is averaged over batches, accuracy over individual samples
    loss /= len(data_loader)
    acc /= len(data_loader.dataset)

    # fix: "{:2%}" is a width-2 percent spec; "{:.2%}" gives 2 decimals
    print("Avg Loss = {}, Avg Accuracy = {:.2%}".format(loss, acc))
34 changes: 18 additions & 16 deletions datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@

import params

# image pre-processing
pre_process = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])

# dataset and data loader
mnist_dataset = datasets.MNIST(root=params.data_root,
transform=pre_process,
download=True)

mnist_data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset,
batch_size=params.batch_size,
shuffle=True)


def get_mnist(train):
    """Build a shuffled DataLoader over the MNIST dataset.

    Args:
        train: if True, load the training split; otherwise the test split.

    Returns:
        torch DataLoader yielding normalized (image, label) batches.
    """
    # convert to tensor, then normalize with the statistics from params
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean,
                             std=params.dataset_std),
    ])

    dataset = datasets.MNIST(root=params.data_root,
                             train=train,
                             transform=transform,
                             download=True)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=params.batch_size,
                                       shuffle=True)
36 changes: 18 additions & 18 deletions datasets/usps.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,23 @@ def load_samples(self):
return images, labels


# image pre-processing
pre_process = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])

# dataset and data loader
usps_dataset = USPS(root=params.data_root,
train=True,
transform=pre_process,
download=True)

usps_data_loader = torch.utils.data.DataLoader(dataset=usps_dataset,
batch_size=params.batch_size,
shuffle=True)


def get_usps(train):
    """Build a shuffled DataLoader over the USPS dataset.

    Args:
        train: if True, load the training split; otherwise the test split.

    Returns:
        torch DataLoader yielding normalized (image, label) batches.
    """
    # convert to tensor, then normalize with the statistics from params
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean,
                             std=params.dataset_std),
    ])

    dataset = USPS(root=params.data_root,
                   train=train,
                   transform=transform,
                   download=True)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=params.batch_size,
                                       shuffle=True)
27 changes: 15 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main script for ADDA."""

import params
from core import train_src
from core import eval_src, train_src
from models import Classifier, Discriminator
from utils import get_data_loader, init_model, init_random_seed

Expand All @@ -11,22 +11,25 @@

# load dataset
src_data_loader = get_data_loader(params.src_dataset)
src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset)

# load models
C_src = init_model(net=Classifier(num_channels=params.num_channels,
conv_dims=params.c_conv_dims,
num_classes=params.num_classes,
fc_dims=params.c_fc_dims),
restore=params.src_model_restore)
C_tgt = init_model(net=Classifier(num_channels=params.num_channels,
conv_dims=params.c_conv_dims,
num_classes=params.num_classes,
fc_dims=params.c_fc_dims),
restore=params.tgt_model_restore)
model_src = init_model(net=Classifier(num_channels=params.num_channels,
conv_dims=params.c_conv_dims,
num_classes=params.num_classes,
fc_dims=params.c_fc_dims),
restore=params.src_model_restore)
model_tgt = init_model(net=Classifier(num_channels=params.num_channels,
conv_dims=params.c_conv_dims,
num_classes=params.num_classes,
fc_dims=params.c_fc_dims),
restore=params.tgt_model_restore)
D = init_model(Discriminator(input_dims=params.d_input_dims,
hidden_dims=params.d_hidden_dims,
output_dims=params.d_output_dims),
restore=params.d_model_restore)

train_src(C_src, src_data_loader)
if not model_src.restored:
model_src = train_src(model_src, src_data_loader)
eval_src(model_src, src_data_loader_eval)
1 change: 1 addition & 0 deletions models/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, num_channels, conv_dims, num_classes, fc_dims):
self.conv_dims = conv_dims
self.num_classes = num_classes
self.fc_dims = fc_dims
self.restored = False

self.encoder = nn.Sequential(
# 1st conv layer
Expand Down
3 changes: 3 additions & 0 deletions models/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class Discriminator(nn.Module):
def __init__(self, input_dims, hidden_dims, output_dims):
"""Init discriminator."""
super(Discriminator, self).__init__()

self.restored = False

self.layer = nn.Sequential(
nn.Linear(input_dims, hidden_dims),
nn.ReLU(),
Expand Down
2 changes: 1 addition & 1 deletion params.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# params for source dataset
src_dataset = "MNIST"
src_model_restore = None
src_model_restore = "snapshots/classifier_src-100.pt"

# params for target dataset
tgt_dataset = "USPS"
Expand Down
10 changes: 7 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def init_random_seed(manual_seed):
torch.cuda.manual_seed_all(seed)


def get_data_loader(name, train=True):
    """Get data loader by dataset name.

    Args:
        name: dataset identifier, either "MNIST" or "USPS".
        train: whether to load the training split (default) or test split.

    Returns:
        torch DataLoader for the requested dataset.

    Raises:
        ValueError: if ``name`` is not a known dataset. The original
            silently returned None here, deferring the failure to the
            caller's first iteration attempt.
    """
    if name == "MNIST":
        return get_mnist(train)
    elif name == "USPS":
        return get_usps(train)
    raise ValueError("Unknown dataset: {}".format(name))


def init_model(net, restore):
Expand All @@ -71,6 +71,8 @@ def init_model(net, restore):
# restore model weights
if restore is not None and os.path.exists(restore):
net.load_state_dict(torch.load(restore))
net.restored = True
print("Restore model from: {}".format(os.path.abspath(restore)))

# check if cuda is available
if torch.cuda.is_available():
Expand All @@ -86,3 +88,5 @@ def save_model(net, filename):
os.makedirs(params.model_root)
torch.save(net.state_dict(),
os.path.join(params.model_root, filename))
print("save pretrained model to: {}".format(os.path.join(params.model_root,
filename)))

0 comments on commit c98bf21

Please sign in to comment.