Skip to content

Commit

Permalink
Draw the sketch for ADDA.
Browse files Browse the repository at this point in the history
  • Loading branch information
corenel committed Aug 17, 2017
1 parent 538965f commit ebb5a76
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 2 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# pytorch dataset
data
model

# trash
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# pytorch-adda
PyTorch implementation for Adversarial Discriminative Domain Adaptation
# PyTorch-ADDA
A PyTorch implementation for [Adversarial Discriminative Domain Adaptation](https://arxiv.org/abs/1702.05464).
Empty file added core/__init__.py
Empty file.
1 change: 1 addition & 0 deletions core/adapt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Adversarial adaptation to train target encoder."""
1 change: 1 addition & 0 deletions core/pretrain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Pre-train encoder and classifier for source dataset."""
1 change: 1 addition & 0 deletions core/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test script to classify target data."""
Empty file added datasets/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions datasets/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Dataset setting and data loader for MNIST."""


import torch
from torchvision import datasets, transforms

import params

# image pre-processing
pre_process = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])

# dataset and data loader
mnist_dataset = datasets.mnist(root=params.data_root,
transform=pre_process,
download=True)

mnist_data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset,
batch_size=params.batch_size,
shuffle=True)


def get_mnist():
"""Inf data iterator."""
while True:
for images, _ in mnist_data_loader:
yield images
1 change: 1 addition & 0 deletions datasets/usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Dataset setting and data loader for USPS."""
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Main script for ADDA."""
1 change: 1 addition & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Utilities for ADDA."""

0 comments on commit ebb5a76

Please sign in to comment.