
Commit e7e6afe
still not work 😢
corenel committed Aug 22, 2017
1 parent 063eec9 commit e7e6afe
Showing 5 changed files with 82 additions and 36 deletions.
55 changes: 38 additions & 17 deletions core/adapt.py
@@ -19,19 +19,21 @@ def train_tgt(model_src, model_tgt, model_critic,
 
     # welcome message
     print("=== Training encoder for target domain ===")
-    # set train state for Dropout and BN layers
-    model_tgt.train()
-    model_critic.train()
 
+    # print model architecture
+    print(model_tgt)
+    print(model_critic)
+
+    # set train state for Dropout and BN layers
+    model_tgt.train()
+    model_critic.train()
 
     # no need to compute gradients for source model
     for p in model_src.parameters():
         p.requires_grad = False
 
     # setup criterion and optimizer
-    criterion = nn.NLLLoss()
+    criterion = nn.CrossEntropyLoss()
     optimizer_tgt = optim.Adam(model_tgt.parameters(),
                                lr=params.c_learning_rate,
                                betas=(params.beta1, params.beta2))
@@ -52,46 +54,63 @@ def train_tgt(model_src, model_tgt, model_critic,
             # 2.1 train discriminator #
             ###########################
 
+            # make images variable
             images_src = make_variable(images_src)
             images_tgt = make_variable(images_tgt)
 
+            # zero gradients for optimizer
             optimizer_tgt.zero_grad()
             optimizer_critic.zero_grad()
 
+            # extract and concat features
             feat_src, _ = model_src(images_src)
             feat_tgt, _ = model_tgt(images_tgt)
             feat_concat = torch.cat((feat_src, feat_tgt), 0)
 
-            label_concat = torch.cat((
-                make_variable(torch.zeros(feat_concat.size(0) // 2).long()),
-                make_variable(torch.ones(feat_concat.size(0) // 2).long())
-            ), 0)
+            # prepare real and fake labels
+            label_src = make_variable(torch.ones(feat_src.size(0)).long())
+            label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long())
+            label_concat = torch.cat((label_src, label_tgt), 0)
 
+            # compute loss for critic
             pred_concat = model_critic(feat_concat)
             loss_critic = criterion(pred_concat, label_concat)
-            loss_critic.backward(retain_graph=True)
+            loss_critic.backward()
 
+            # optimize critic
             optimizer_critic.step()
 
             pred_cls = torch.squeeze(pred_concat.max(1)[1])
             acc = (pred_cls == label_concat).float().mean()
 
-            # train target encoder
+            ############################
+            # 2.2 train target encoder #
+            ############################
 
+            # zero gradients for optimizer
             optimizer_tgt.zero_grad()
             optimizer_critic.zero_grad()
 
-            loss_tgt = criterion(
-                feat_concat[feat_concat.size(0) // 2:, ...],
-                make_variable(torch.ones(feat_concat.size(0) // 2).long())
-            )
+            # extract target features
+            feat_tgt, _ = model_tgt(images_tgt)
+
+            # prepare fake labels
+            label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())
+
+            # compute loss for target encoder
+            pred_tgt = model_critic(feat_tgt)
+            loss_tgt = criterion(pred_tgt, label_tgt)
             loss_tgt.backward()
 
+            # optimize target encoder
             optimizer_tgt.step()
 
-            # print step info
+            #######################
+            # 2.3 print step info #
+            #######################
             if ((step + 1) % params.log_step == 0):
                 print("Epoch [{}/{}] Step [{}/{}]:"
-                      "d_loss={:.3f} g_loss={:.3f} acc={:.3f}"
+                      "d_loss={:.5f} g_loss={:.5f} acc={:.5f}"
                       .format(epoch + 1,
                               params.num_epochs,
                               step + 1,
@@ -100,7 +119,9 @@ def train_tgt(model_src, model_tgt, model_critic,
                               loss_tgt.data[0],
                               acc.data[0]))
 
-        # save model parameters
+        #############################
+        # 2.4 save model parameters #
+        #############################
         if ((epoch + 1) % params.save_step == 0):
             if not os.path.exists(params.model_root):
                 os.makedirs(params.model_root)
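
Note on the substantive fix above: the old step 2.2 never actually ran the critic. It sliced the target half out of the stale feat_concat (raw features, not 2-class critic scores) and fed that straight into the criterion, which is also why the critic pass needed backward(retain_graph=True). The rewritten step re-runs model_tgt and model_critic on a fresh graph and labels the target batch as 1, the critic's "source" class from step 2.1, i.e. the standard inverted-label GAN update. A minimal sketch of the corrected two-phase step, using hypothetical stand-in modules in place of the repo's encoders and critic (PyTorch 0.2-era Variable API, matching the rest of the codebase):

import torch
import torch.nn as nn
from torch.autograd import Variable

# hypothetical stand-ins for the repo's encoders and critic
encoder_tgt = nn.Linear(784, 500)
critic = nn.Sequential(nn.Linear(500, 500), nn.ReLU(), nn.Linear(500, 2))
criterion = nn.CrossEntropyLoss()

feat_src = Variable(torch.randn(50, 500))   # from the frozen source encoder
images_tgt = Variable(torch.randn(50, 784))

# phase 1 (2.1): critic learns to separate source (label 1) from target (label 0)
feat_tgt = encoder_tgt(images_tgt)
feat_concat = torch.cat((feat_src, feat_tgt), 0)
label_concat = Variable(torch.cat((torch.ones(50), torch.zeros(50)), 0).long())
loss_critic = criterion(critic(feat_concat), label_concat)
loss_critic.backward()                      # graph can be freed: no retain_graph needed

# phase 2 (2.2): fresh forward pass, target relabelled as 1 to fool the critic
feat_tgt = encoder_tgt(images_tgt)          # rebuilds the graph for the encoder update
loss_tgt = criterion(critic(feat_tgt), Variable(torch.ones(50).long()))
loss_tgt.backward()

The zero_grad()/step() calls are omitted here; in the real loop each phase zeroes both optimizers and steps only its own, as the diff shows.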
Expand Down
36 changes: 30 additions & 6 deletions core/pretrain.py
@@ -1,8 +1,5 @@
 """Pre-train encoder and classifier for source dataset."""
 
-import os
-
-import torch
 import torch.nn as nn
 import torch.optim as optim
 
@@ -12,58 +9,85 @@
 
 def train_src(model, data_loader):
     """Train classifier for source domain."""
+    ####################
+    # 1. setup network #
+    ####################
 
+    # print welcome message and model architecture
     print("=== Training classifier for source domain ===")
+    print(model)
 
+    # set train state for Dropout and BN layers
     model.train()
 
+    # setup criterion and optimizer
     optimizer = optim.Adam(model.parameters(),
                            lr=params.c_learning_rate,
                            betas=(params.beta1, params.beta2))
     criterion = nn.CrossEntropyLoss()
 
+    ####################
+    # 2. train network #
+    ####################
 
     for epoch in range(params.num_epochs_pre):
         for step, (images, labels) in enumerate(data_loader):
+            # make images and labels variable
             images = make_variable(images)
             labels = make_variable(labels.squeeze_())
 
+            # zero gradients for optimizer
             optimizer.zero_grad()
 
+            # compute loss for source classifier
             _, preds = model(images)
             loss = criterion(preds, labels)
 
+            # optimize source classifier
             loss.backward()
             optimizer.step()
 
-            if ((step + 1) % params.log_step == 0):
+            # print step info
+            if ((step + 1) % params.log_step_pre == 0):
                 print("Epoch [{}/{}] Step [{}/{}]: loss={}"
                       .format(epoch + 1,
                               params.num_epochs_pre,
                               step + 1,
                               len(data_loader),
                               loss.data[0]))
 
-        if ((epoch + 1) % params.eval_step == 0):
+        # eval model on test set
+        if ((epoch + 1) % params.eval_step_pre == 0):
             eval_src(model, data_loader, welcome_msg=False)
 
-        if ((epoch + 1) % params.save_step == 0):
+        # save model parameters
+        if ((epoch + 1) % params.save_step_pre == 0):
             save_model(model, "classifier_src-{}.pt".format(epoch + 1))
 
+    # save final model
     save_model(model, "classifier_src-final.pt")
 
     return model
 
 
 def eval_src(model, data_loader, welcome_msg=True):
     """Evaluate classifier for source domain."""
+    # print welcome message and model architecture
+    if welcome_msg:
+        print("=== Evaluating classifier for source domain ===")
+        print(model)
 
+    # set eval state for Dropout and BN layers
     model.eval()
 
+    # init loss and accuracy
     loss = 0
     acc = 0
 
+    # set loss function
     criterion = nn.CrossEntropyLoss()
 
+    # evaluate network
     for (images, labels) in data_loader:
         images = make_variable(images, volatile=True)
         labels = make_variable(labels)
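
One idiom worth flagging here: make_variable(images, volatile=True) marks the eval inputs so that no autograd graph is built during the forward pass, which was the standard pre-0.4 way to run inference cheaply. A standalone illustration of the flag's effect (not repo code; PyTorch 0.2-era semantics):

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Linear(10, 2)

x = Variable(torch.randn(4, 10), volatile=True)  # inference mode: no graph is built
y = model(x)
print(y.volatile)       # True: volatility propagates to every downstream output
# y.sum().backward()    # would raise an error; volatile results are not differentiable

# In PyTorch >= 0.4 the same intent is written as:
#     with torch.no_grad():
#         y = model(x)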
2 changes: 1 addition & 1 deletion core/test.py
@@ -13,7 +13,7 @@ def eval_tgt(model_src, model_tgt, data_loader):
     model_tgt.eval()
     loss = 0
     acc = 0
-    criterion = nn.NLLLoss()
+    criterion = nn.CrossEntropyLoss()
 
     for (images, labels) in data_loader:
         images = make_variable(images, volatile=True)
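
Same loss swap as in core/adapt.py, and it is only interchangeable in one direction: nn.CrossEntropyLoss is LogSoftmax followed by NLLLoss, so it expects raw logits, while nn.NLLLoss expects log-probabilities. If the classifiers in this repo emit raw scores, the old NLLLoss was silently computing a meaningless quantity. A quick standalone equivalence check (written in current PyTorch spelling, not the 0.2 API used above):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)        # raw, unnormalized class scores
target = torch.tensor([1, 0, 3, 9])

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
assert torch.allclose(ce, nll)     # identical by construction
# nn.NLLLoss()(logits, target) would compute a different, meaningless value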
6 changes: 3 additions & 3 deletions main.py
@@ -29,9 +29,9 @@
     eval_src(model_src, src_data_loader_eval)
 
     # train target encoder by GAN
-    # if not (model_tgt.restored and params.tgt_model_trained):
-    #     model_tgt = train_tgt(model_src, model_tgt, model_critic,
-    #                           src_data_loader, tgt_data_loader)
+    if not (model_tgt.restored and params.tgt_model_trained):
+        model_tgt = train_tgt(model_src, model_tgt, model_critic,
+                              src_data_loader, tgt_data_loader)
 
     # eval target encoder on test set of target dataset
     # eval_tgt(model_src, model_tgt, tgt_data_loader_eval)
19 changes: 10 additions & 9 deletions params.py
@@ -6,14 +6,13 @@
 dataset_std_value = 0.5
 dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
 dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
-batch_size = 64
+batch_size = 50
 image_size = 64
-num_classes = 10
 
 # params for source dataset
 src_dataset = "MNIST"
-src_model_restore = None
-src_model_trained = False
+src_model_restore = "snapshots/classifier_src-final.pt"
+src_model_trained = True
 
 # params for target dataset
 tgt_dataset = "USPS"
@@ -22,18 +21,20 @@
 
 # params for setting up models
 model_root = "snapshots"
-d_input_dims = 50
-d_hidden_dims = 512
+d_input_dims = 500
+d_hidden_dims = 500
 d_output_dims = 2
 d_model_restore = None
 
 # params for training network
 num_gpu = 1
 num_epochs_pre = 100
-num_epochs = 500
+log_step_pre = 20
+eval_step_pre = 20
+save_step_pre = 100
+num_epochs = 20000
 log_step = 20
 eval_step = 20
-save_step = 100
+save_step = 1000
 manual_seed = None
 
 # params for optimizing models
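
The discriminator dims deserve a comment: d_input_dims has to match the feature width the encoders emit, and the jump from 50 to 500 suggests the encoders output 500-d features (500 is the width of the classic LeNet fully-connected layer used in the ADDA paper). A quick shape sanity check with a hypothetical stand-in critic (the repo's actual Discriminator may differ in depth):

import torch
import torch.nn as nn

d_input_dims, d_hidden_dims, d_output_dims = 500, 500, 2

critic = nn.Sequential(
    nn.Linear(d_input_dims, d_hidden_dims),
    nn.ReLU(),
    nn.Linear(d_hidden_dims, d_output_dims),
)

feat = torch.randn(50, 500)             # batch_size x encoder feature dim
assert critic(feat).size() == (50, 2)   # binary source-vs-target logits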
