separate encoder and classifier of LeNet
corenel committed Aug 22, 2017
1 parent e7e6afe commit 34a84f6
Showing 7 changed files with 91 additions and 63 deletions.
41 changes: 20 additions & 21 deletions core/adapt.py
@@ -10,7 +10,7 @@
 from utils import make_variable


-def train_tgt(model_src, model_tgt, model_critic,
+def train_tgt(src_encoder, tgt_encoder, critic,
               src_data_loader, tgt_data_loader):
     """Train encoder for target domain."""
     ####################
@@ -21,23 +21,19 @@ def train_tgt(model_src, model_tgt, model_critic,
     print("=== Training encoder for target domain ===")

     # print model architecture
-    print(model_tgt)
-    print(model_critic)
+    print(tgt_encoder)
+    print(critic)

     # set train state for Dropout and BN layers
-    model_tgt.train()
-    model_critic.train()
-
-    # no need to compute gradients for source model
-    for p in model_src.parameters():
-        p.requires_grad = False
+    tgt_encoder.train()
+    critic.train()

     # setup criterion and optimizer
     criterion = nn.CrossEntropyLoss()
-    optimizer_tgt = optim.Adam(model_tgt.parameters(),
+    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                                lr=params.c_learning_rate,
                                betas=(params.beta1, params.beta2))
-    optimizer_critic = optim.Adam(model_tgt.parameters(),
+    optimizer_critic = optim.Adam(critic.parameters(),
                                   lr=params.d_learning_rate,
                                   betas=(params.beta1, params.beta2))
     len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
@@ -59,21 +55,22 @@ def train_tgt(model_src, model_tgt, model_critic,
             images_tgt = make_variable(images_tgt)

             # zero gradients for optimizer
-            optimizer_tgt.zero_grad()
             optimizer_critic.zero_grad()

             # extract and concat features
-            feat_src, _ = model_src(images_src)
-            feat_tgt, _ = model_tgt(images_tgt)
+            feat_src = src_encoder(images_src)
+            feat_tgt = tgt_encoder(images_tgt)
             feat_concat = torch.cat((feat_src, feat_tgt), 0)

+            # predict on discriminator
+            pred_concat = critic(feat_concat.detach())
+
             # prepare real and fake label
             label_src = make_variable(torch.ones(feat_src.size(0)).long())
             label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long())
             label_concat = torch.cat((label_src, label_tgt), 0)

             # compute loss for critic
-            pred_concat = model_critic(feat_concat)
             loss_critic = criterion(pred_concat, label_concat)
             loss_critic.backward()

@@ -88,17 +85,19 @@ def train_tgt(model_src, model_tgt, model_critic,
             ############################

             # zero gradients for optimizer
-            optimizer_tgt.zero_grad()
             optimizer_critic.zero_grad()
+            optimizer_tgt.zero_grad()

             # extract and target features
-            feat_tgt, _ = model_tgt(images_tgt)
+            feat_tgt = tgt_encoder(images_tgt)

+            # predict on discriminator
+            pred_tgt = critic(feat_tgt)
+
             # prepare fake labels
             label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())

             # compute loss for target encoder
-            pred_tgt = model_critic(feat_tgt)
             loss_tgt = criterion(pred_tgt, label_tgt)
             loss_tgt.backward()

@@ -125,9 +124,9 @@ def train_tgt(model_src, model_tgt, model_critic,
         if ((epoch + 1) % params.save_step == 0):
             if not os.path.exists(params.model_root):
                 os.makedirs(params.model_root)
-            torch.save(model_critic.state_dict(), os.path.join(
+            torch.save(critic.state_dict(), os.path.join(
                 params.model_root,
                 "ADDA-critic-{}.pt".format(epoch + 1)))
-            torch.save(model_tgt.state_dict(), os.path.join(
+            torch.save(tgt_encoder.state_dict(), os.path.join(
                 params.model_root,
-                "ADDA-target-{}.pt".format(epoch + 1)))
+                "ADDA-target-encoder-{}.pt".format(epoch + 1)))
38 changes: 23 additions & 15 deletions core/pretrain.py
@@ -7,23 +7,26 @@
 from utils import make_variable, save_model


-def train_src(model, data_loader):
+def train_src(encoder, classifier, data_loader):
     """Train classifier for source domain."""
     ####################
     # 1. setup network #
     ####################

     # print welcome message and model architecture
     print("=== Training classifier for source domain ===")
-    print(model)
+    print(encoder)
+    print(classifier)

     # set train state for Dropout and BN layers
-    model.train()
+    encoder.train()
+    classifier.train()

     # setup criterion and optimizer
-    optimizer = optim.Adam(model.parameters(),
-                           lr=params.c_learning_rate,
-                           betas=(params.beta1, params.beta2))
+    optimizer = optim.Adam(
+        list(encoder.parameters()) + list(classifier.parameters()),
+        lr=params.c_learning_rate,
+        betas=(params.beta1, params.beta2))
     criterion = nn.CrossEntropyLoss()

     ####################
@@ -40,7 +43,7 @@ def train_src(model, data_loader):
             optimizer.zero_grad()

             # compute loss for critic
-            _, preds = model(images)
+            preds = classifier(encoder(images))
             loss = criterion(preds, labels)

             # optimize source classifier
@@ -58,27 +61,32 @@ def train_src(model, data_loader):

         # eval model on test set
         if ((epoch + 1) % params.eval_step_pre == 0):
-            eval_src(model, data_loader, welcome_msg=False)
+            eval_src(encoder, classifier, data_loader, welcome_msg=False)

         # save model parameters
         if ((epoch + 1) % params.save_step_pre == 0):
-            save_model(model, "classifier_src-{}.pt".format(epoch + 1))
+            save_model(encoder, "ADDA-source-encoder-{}.pt".format(epoch + 1))
+            save_model(
+                classifier, "ADDA-source-classifier-{}.pt".format(epoch + 1))

     # # save final model
-    save_model(model, "classifier_src-final.pt")
+    save_model(encoder, "ADDA-source-encoder-final.pt")
+    save_model(classifier, "ADDA-source-classifier-final.pt")

-    return model
+    return encoder, classifier


-def eval_src(model, data_loader, welcome_msg=True):
+def eval_src(encoder, classifier, data_loader, welcome_msg=True):
     """Evaluate classifier for source domain."""
     # print welcome message and model architecture
     if welcome_msg:
         print("=== Evaluating classifier for source domain ===")
-        print(model)
+        print(encoder)
+        print(classifier)

     # set eval state for Dropout and BN layers
-    model.eval()
+    encoder.eval()
+    classifier.eval()

     # init loss and accuracy
     loss = 0
@@ -92,7 +100,7 @@ def eval_src(model, data_loader, welcome_msg=True):
         images = make_variable(images, volatile=True)
         labels = make_variable(labels)

-        _, preds = model(images)
+        preds = classifier(encoder(images))
         loss += criterion(preds, labels).data[0]

         pred_cls = preds.data.max(1)[1]
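Since the encoder and classifier are now separate nn.Modules, the pretraining optimizer must be handed both parameter sets, which is why the diff concatenates the two parameter lists. An equivalent idiom, sketched under the assumption that encoder, classifier, and params are set up as above:

    import itertools
    from torch import optim

    # itertools.chain streams both parameter generators into one optimizer,
    # equivalent to list(encoder.parameters()) + list(classifier.parameters())
    optimizer = optim.Adam(
        itertools.chain(encoder.parameters(), classifier.parameters()),
        lr=params.c_learning_rate,
        betas=(params.beta1, params.beta2))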
33 changes: 19 additions & 14 deletions main.py
@@ -2,7 +2,7 @@

 import params
 from core import eval_src, eval_tgt, train_src, train_tgt
-from models import Discriminator, LeNet
+from models import Discriminator, LeNetClassifier, LeNetEncoder
 from utils import get_data_loader, init_model, init_random_seed

 if __name__ == '__main__':
@@ -16,22 +16,27 @@
     tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

     # load models
-    model_src = init_model(net=LeNet(), restore=params.src_model_restore)
-    model_tgt = init_model(net=LeNet(), restore=params.tgt_model_restore)
-    model_critic = init_model(Discriminator(input_dims=params.d_input_dims,
-                                            hidden_dims=params.d_hidden_dims,
-                                            output_dims=params.d_output_dims),
-                              restore=params.d_model_restore)
+    src_encoder = init_model(net=LeNetEncoder(),
+                             restore=params.src_encoder_restore)
+    classifier_src = init_model(net=LeNetClassifier(),
+                                restore=params.src_classifier_restore)
+    tgt_encoder = init_model(net=LeNetEncoder(),
+                             restore=params.tgt_encoder_restore)
+    critic = init_model(Discriminator(input_dims=params.d_input_dims,
+                                      hidden_dims=params.d_hidden_dims,
+                                      output_dims=params.d_output_dims),
+                        restore=params.d_model_restore)

     # train and eval source model
-    if not (model_src.restored and params.src_model_trained):
-        model_src = train_src(model_src, src_data_loader)
-    eval_src(model_src, src_data_loader_eval)
+    if not (src_encoder.restored and classifier_src.restored and
+            params.src_model_trained):
+        model_src = train_src(src_encoder, classifier_src, src_data_loader)
+    eval_src(src_encoder, classifier_src, src_data_loader_eval)

     # train target encoder by GAN
-    if not (model_tgt.restored and params.tgt_model_trained):
-        model_tgt = train_tgt(model_src, model_tgt, model_critic,
-                              src_data_loader, tgt_data_loader)
+    # if not (tgt_encoder.restored and params.tgt_encoder_trained):
+    #     model_tgt = train_tgt(src_encoder, tgt_encoder, critic,
+    #                           src_data_loader, tgt_data_loader)

     # eval target encoder on test set of target dataset
-    # eval_tgt(model_src, model_tgt, tgt_data_loader_eval)
+    # eval_tgt(classifier_src, tgt_encoder, tgt_data_loader_eval)
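The adaptation stage is commented out in this commit, so main.py currently stops after source pretraining and evaluation. Once train_tgt is wired back in, the end-to-end flow would presumably look like the sketch below (hypothetical; the argument order for eval_tgt is taken from the commented call):

    # pretrain on source, adapt the target encoder, then evaluate on target
    train_src(src_encoder, classifier_src, src_data_loader)
    train_tgt(src_encoder, tgt_encoder, critic,
              src_data_loader, tgt_data_loader)
    eval_tgt(classifier_src, tgt_encoder, tgt_data_loader_eval)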
4 changes: 2 additions & 2 deletions models/__init__.py
@@ -1,4 +1,4 @@
 from .discriminator import Discriminator
-from .lenet import LeNet
+from .lenet import LeNetClassifier, LeNetEncoder

-__all__ = (LeNet, Discriminator)
+__all__ = (LeNetClassifier, LeNetEncoder, Discriminator)
2 changes: 2 additions & 0 deletions models/discriminator.py
@@ -15,6 +15,8 @@ def __init__(self, input_dims, hidden_dims, output_dims):
         self.layer = nn.Sequential(
             nn.Linear(input_dims, hidden_dims),
             nn.ReLU(),
+            nn.Linear(hidden_dims, hidden_dims),
+            nn.ReLU(),
             nn.Linear(hidden_dims, output_dims),
             nn.LogSoftmax()
         )
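With the extra hidden layer, the critic is now a three-layer MLP. A hypothetical instantiation for the 500-dimensional features produced by LeNetEncoder's fc1 (the real values live in params.py; the dims below are assumptions, with output_dims=2 matching the binary source-vs-target labels used in adapt.py):

    critic = Discriminator(input_dims=500,   # LeNetEncoder's fc1 output size
                           hidden_dims=500,  # assumed hidden width
                           output_dims=2)    # "real" source vs "fake" target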
29 changes: 21 additions & 8 deletions models/lenet.py
@@ -4,12 +4,12 @@
 from torch import nn


-class LeNet(nn.Module):
-    """LeNet model for source domain."""
+class LeNetEncoder(nn.Module):
+    """LeNet encoder model for ADDA."""

     def __init__(self):
-        """Init LeNet."""
-        super(LeNet, self).__init__()
+        """Init LeNet encoder."""
+        super(LeNetEncoder, self).__init__()

         self.restored = False

@@ -24,16 +24,29 @@ def __init__(self):
             # input [20 x 12 x 12]
             # output [50 x 4 x 4]
             nn.Conv2d(20, 50, kernel_size=5),
+            nn.Dropout2d(),
             nn.MaxPool2d(kernel_size=2),
             nn.ReLU()
         )
         self.fc1 = nn.Linear(50 * 4 * 4, 500)
-        self.fc2 = nn.Linear(500, 10)

     def forward(self, input):
         """Forward the LeNet."""
         conv_out = self.encoder(input)
-        feature = self.fc1(conv_out.view(-1, 50 * 4 * 4))
-        out = F.dropout(F.relu(feature), training=self.training)
+        feat = self.fc1(conv_out.view(-1, 50 * 4 * 4))
+        return feat
+
+
+class LeNetClassifier(nn.Module):
+    """LeNet classifier model for ADDA."""
+
+    def __init__(self):
+        """Init LeNet encoder."""
+        super(LeNetClassifier, self).__init__()
+        self.fc2 = nn.Linear(500, 10)
+
+    def forward(self, feat):
+        """Forward the LeNet classifier."""
+        out = F.dropout(F.relu(feat), training=self.training)
         out = self.fc2(out)
-        return feature, out
+        return out
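Composed back together, the two modules reproduce the original LeNet forward pass. A quick shape check, sketched assuming MNIST-sized 1 x 28 x 28 inputs and the old Variable API the repo uses:

    import torch
    from torch.autograd import Variable
    from models import LeNetClassifier, LeNetEncoder

    encoder = LeNetEncoder()
    classifier = LeNetClassifier()

    images = Variable(torch.randn(8, 1, 28, 28))  # dummy MNIST batch
    feat = encoder(images)       # conv stack + fc1 -> [8, 500]
    logits = classifier(feat)    # relu + dropout + fc2 -> [8, 10]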
7 changes: 4 additions & 3 deletions params.py
@@ -11,12 +11,13 @@

 # params for source dataset
 src_dataset = "MNIST"
-src_model_restore = "snapshots/classifier_src-final.pt"
-src_model_trained = True
+src_encoder_restore = None
+src_classifier_restore = None
+src_model_trained = False

 # params for target dataset
 tgt_dataset = "USPS"
-tgt_model_restore = None
+tgt_encoder_restore = None
 tgt_model_trained = False

 # params for setting up models
