Skip to content

Commit

Permalink
init train_tgt and eval_tgt
Browse files Browse the repository at this point in the history
  • Loading branch information
corenel committed Aug 19, 2017
1 parent c98bf21 commit 4f354ff
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 3 deletions.
4 changes: 3 additions & 1 deletion core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .adapt import train_tgt
from .pretrain import eval_src, train_src
from .test import eval_tgt

__all__ = (eval_src, train_src)
__all__ = (eval_src, train_src, train_tgt, eval_tgt)
5 changes: 5 additions & 0 deletions core/adapt.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
"""Adversarial adaptation to train target encoder."""


def train_tgt(model_src, model_tgt, data_loader):
"""Train encoder for target domain."""
pass
5 changes: 5 additions & 0 deletions core/test.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
"""Test script to classify target data."""


def eval_tgt(model_src, model_tgt, data_loader):
"""Evaluation for target encoder by source classifier on target dataset."""
pass
13 changes: 11 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main script for ADDA."""

import params
from core import eval_src, train_src
from core import eval_src, eval_tgt, train_src, train_tgt
from models import Classifier, Discriminator
from utils import get_data_loader, init_model, init_random_seed

Expand All @@ -13,6 +13,7 @@
src_data_loader = get_data_loader(params.src_dataset)
src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

# load models
model_src = init_model(net=Classifier(num_channels=params.num_channels,
Expand All @@ -30,6 +31,14 @@
output_dims=params.d_output_dims),
restore=params.d_model_restore)

if not model_src.restored:
# train and eval source model
if not (model_src.restored and params.src_model_trained):
model_src = train_src(model_src, src_data_loader)
eval_src(model_src, src_data_loader_eval)

# train target encoder by GAN
if not (model_tgt.restored and params.tgt_model_trained):
model_tgt = train_tgt(model_src, model_tgt, tgt_data_loader)

# eval target encoder on test set of target dataset
eval_tgt(model_src, model_tgt, tgt_data_loader_eval)
2 changes: 2 additions & 0 deletions params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# params for source dataset
src_dataset = "MNIST"
src_model_restore = "snapshots/classifier_src-100.pt"
src_model_trained = True

# params for target dataset
tgt_dataset = "USPS"
tgt_model_restore = None
tgt_model_trained = False

# params for setting up models
model_root = "snapshots"
Expand Down

0 comments on commit 4f354ff

Please sign in to comment.