Skip to content

Commit

Permalink
Works, but the result isn't as expected.
Browse files Browse the repository at this point in the history
  • Loading branch information
corenel committed Aug 22, 2017
1 parent 34a84f6 commit aad4d3c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 45 deletions.
17 changes: 8 additions & 9 deletions core/adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@ def train_tgt(src_encoder, tgt_encoder, critic,
# 1. setup network #
####################

# welcome message
print("=== Training encoder for target domain ===")

# print model architecture
print(tgt_encoder)
print(critic)

# set train state for Dropout and BN layers
tgt_encoder.train()
critic.train()
Expand Down Expand Up @@ -122,11 +115,17 @@ def train_tgt(src_encoder, tgt_encoder, critic,
# 2.4 save model parameters #
#############################
if ((epoch + 1) % params.save_step == 0):
if not os.path.exists(params.model_root):
os.makedirs(params.model_root)
torch.save(critic.state_dict(), os.path.join(
params.model_root,
"ADDA-critic-{}.pt".format(epoch + 1)))
torch.save(tgt_encoder.state_dict(), os.path.join(
params.model_root,
"ADDA-target-encoder-{}.pt".format(epoch + 1)))

torch.save(critic.state_dict(), os.path.join(
params.model_root,
"ADDA-critic-final.pt"))
torch.save(tgt_encoder.state_dict(), os.path.join(
params.model_root,
"ADDA-target-encoder-final.pt"))
return tgt_encoder
15 changes: 2 additions & 13 deletions core/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ def train_src(encoder, classifier, data_loader):
# 1. setup network #
####################

# print welcome message and model architecture
print("=== Training classifier for source domain ===")
print(encoder)
print(classifier)

# set train state for Dropout and BN layers
encoder.train()
classifier.train()
Expand Down Expand Up @@ -61,7 +56,7 @@ def train_src(encoder, classifier, data_loader):

# eval model on test set
if ((epoch + 1) % params.eval_step_pre == 0):
eval_src(encoder, classifier, data_loader, welcome_msg=False)
eval_src(encoder, classifier, data_loader)

# save model parameters
if ((epoch + 1) % params.save_step_pre == 0):
Expand All @@ -76,14 +71,8 @@ def train_src(encoder, classifier, data_loader):
return encoder, classifier


def eval_src(encoder, classifier, data_loader, welcome_msg=True):
def eval_src(encoder, classifier, data_loader):
"""Evaluate classifier for source domain."""
# print welcome message and model architecture
if welcome_msg:
print("=== Evaluating classifier for source domain ===")
print(encoder)
print(classifier)

# set eval state for Dropout and BN layers
encoder.eval()
classifier.eval()
Expand Down
15 changes: 10 additions & 5 deletions core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
from utils import make_variable


def eval_tgt(model_src, model_tgt, data_loader):
def eval_tgt(encoder, classifier, data_loader):
"""Evaluation for target encoder by source classifier on target dataset."""
print("=== Evaluating classifier for encoded target domain ===")
model_src.eval()
model_tgt.eval()
# set eval state for Dropout and BN layers
encoder.eval()
classifier.eval()

# init loss and accuracy
loss = 0
acc = 0

# set loss function
criterion = nn.CrossEntropyLoss()

# evaluate network
for (images, labels) in data_loader:
images = make_variable(images, volatile=True)
labels = make_variable(labels).squeeze_()

_, preds = model_tgt(images)
preds = classifier(encoder(images))
loss += criterion(preds, labels).data[0]

pred_cls = preds.data.max(1)[1]
Expand Down
39 changes: 30 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# load models
src_encoder = init_model(net=LeNetEncoder(),
restore=params.src_encoder_restore)
classifier_src = init_model(net=LeNetClassifier(),
src_classifier = init_model(net=LeNetClassifier(),
restore=params.src_classifier_restore)
tgt_encoder = init_model(net=LeNetEncoder(),
restore=params.tgt_encoder_restore)
Expand All @@ -27,16 +27,37 @@
output_dims=params.d_output_dims),
restore=params.d_model_restore)

# train and eval source model
if not (src_encoder.restored and classifier_src.restored and
# train source model
print("=== Training classifier for source domain ===")
print(">>> Source Encoder <<<")
print(src_encoder)
print(">>> Source Classifier <<<")
print(src_classifier)

if not (src_encoder.restored and src_classifier.restored and
params.src_model_trained):
model_src = train_src(src_encoder, classifier_src, src_data_loader)
eval_src(src_encoder, classifier_src, src_data_loader_eval)
src_encoder, src_classifier = train_src(
src_encoder, src_classifier, src_data_loader)

# eval source model
print("=== Evaluating classifier for source domain ===")
eval_src(src_encoder, src_classifier, src_data_loader_eval)

# train target encoder by GAN
# if not (tgt_encoder.restored and params.tgt_encoder_trained):
# model_tgt = train_tgt(src_encoder, tgt_encoder, critic,
# src_data_loader, tgt_data_loader)
print("=== Training encoder for target domain ===")
print(">>> Target Encoder <<<")
print(tgt_encoder)
print(">>> Critic <<<")
print(critic)

if not (tgt_encoder.restored and critic.restored and
params.tgt_model_trained):
tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,
src_data_loader, tgt_data_loader)

# eval target encoder on test set of target dataset
# eval_tgt(classifier_src, tgt_encoder, tgt_data_loader_eval)
print("=== Evaluating classifier for encoded target domain ===")
print(">>> source only <<<")
eval_tgt(src_encoder, src_classifier, tgt_data_loader_eval)
print(">>> domain adaption <<<")
eval_tgt(tgt_encoder, src_classifier, tgt_data_loader_eval)
18 changes: 9 additions & 9 deletions params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,31 @@

# params for source dataset
src_dataset = "MNIST"
src_encoder_restore = None
src_classifier_restore = None
src_model_trained = False
src_encoder_restore = "snapshots/ADDA-source-encoder-final.pt"
src_classifier_restore = "snapshots/ADDA-source-classifier-final.pt"
src_model_trained = True

# params for target dataset
tgt_dataset = "USPS"
tgt_encoder_restore = None
tgt_model_trained = False
tgt_encoder_restore = "snapshots/ADDA-target-encoder-500.pt"
tgt_model_trained = True

# params for setting up models
model_root = "snapshots"
d_input_dims = 500
d_hidden_dims = 500
d_output_dims = 2
d_model_restore = None
d_model_restore = "snapshots/ADDA-critic-500.pt"

# params for training network
num_gpu = 1
num_epochs_pre = 100
log_step_pre = 20
eval_step_pre = 20
save_step_pre = 100
num_epochs = 20000
log_step = 20
save_step = 1000
num_epochs = 2000
log_step = 100
save_step = 100
manual_seed = None

# params for optimizing models
Expand Down

0 comments on commit aad4d3c

Please sign in to comment.