Commit: testing

bkj committed Feb 17, 2018
1 parent 9191207 commit bf90e86
Showing 3 changed files with 118 additions and 41 deletions.
111 changes: 75 additions & 36 deletions basenet/basenet.py
@@ -17,6 +17,19 @@
from .helpers import to_numpy
from .lr import LRSchedule

# --
# Helpers

def _set_train(x, mode):
x.training = False if getattr(x, 'frozen', False) else mode
for module in x.children():
_set_train(module, mode)

return x

# --
# Model

class BaseNet(nn.Module):

def __init__(self, loss_fn=F.cross_entropy, verbose=False):
@@ -36,85 +49,108 @@ def init_optimizer(self, opt, params, lr_scheduler=None, **kwargs):
assert 'lr' not in kwargs, "BaseWrapper.init_optimizer: can't set LR and lr_scheduler"
self.lr_scheduler = lr_scheduler
self.opt = opt(params, lr=self.lr_scheduler(0), **kwargs)
self.set_progress(0)
else:
self.lr_scheduler = None
self.opt = opt(params, **kwargs)

def set_progress(self, progress):
self.progress = progress
self.epoch = np.floor(progress)
if self.lr_scheduler is not None:
self.progress, self.lr = progress, self.lr_scheduler(progress)
self.lr = self.lr_scheduler(progress)
LRSchedule.set_lr(self.opt, self.lr)

def zero_progress(self):
self.epoch = 0
self.set_progress(0.0)
# --
# Training states

def train(self, mode=True):
""" have to override this function to allow more finegrained control """
return _set_train(self, mode=mode)

# --
# Batch steps

def train_batch(self, data, target):
data, target = Variable(data.cuda()), Variable(target.cuda())

_ = self.train()

self.opt.zero_grad()
output = self(data)
loss = self.loss_fn(output, target)
loss.backward()
self.opt.step()
return output

return output, float(loss)

def eval_batch(self, data, target):
data, target = Variable(data.cuda(), volatile=True), Variable(target.cuda())

_ = self.eval()

output = self(data)
return (to_numpy(output).argmax(axis=1) == to_numpy(target)).mean()
loss = self.loss_fn(output, target)

return output, float(loss)

# --
# Epoch steps

def train_epoch(self, dataloaders, num_batches=np.inf):
def train_epoch(self, dataloaders, mode='train', num_batches=np.inf):
assert self.opt is not None, "BaseWrapper: self.opt is None"

loader = dataloaders['train']
gen = enumerate(loader)
if self.verbose:
gen = tqdm(gen, total=len(loader))

correct, total = 0, 0
for batch_idx, (data, target) in gen:
data, target = Variable(data.cuda()), Variable(target.cuda())

self.set_progress(self.epoch + batch_idx / len(loader))

output = self.train_batch(data, target)

correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
total += data.shape[0]
loader = dataloaders[mode]
if loader is None:
return None
else:
gen = enumerate(loader)
if self.verbose:
gen = tqdm(gen, total=len(loader))

if batch_idx > num_batches:
break
avg_mom = 0.98
avg_loss = 0.0
correct, total, loss_hist = 0, 0, []
for batch_idx, (data, target) in gen:
self.set_progress(self.epoch + batch_idx / len(loader))

output, loss = self.train_batch(data, target)
loss_hist.append(loss)

avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))

correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
total += data.shape[0]

if batch_idx > num_batches:
break

if self.verbose:
gen.set_postfix(acc=correct / total)

if self.verbose:
gen.set_postfix(acc=correct / total)
self.epoch += 1
return {
"acc" : correct / total,
"loss" : np.hstack(loss_hist),
"debias_loss" : debias_loss,
}

self.epoch += 1
return correct / total

def eval_epoch(self, dataloaders, mode='val', num_batches=np.inf):
assert self.opt is not None, "BaseWrapper: self.opt is None"

loader = dataloaders[mode]
if loader is None:
return None
else:
_ = self.eval()
correct, total = 0, 0

gen = enumerate(loader)
if self.verbose:
gen = tqdm(gen, total=len(loader))

correct, total, loss_hist = 0, 0, []
for batch_idx, (data, target) in gen:
data = Variable(data.cuda(), volatile=True)

output = self(data)
output, loss = self.eval_batch(data, target)
loss_hist.append(loss)

correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
total += data.shape[0]
@@ -125,7 +161,10 @@ def eval_epoch(self, dataloaders, mode='val', num_batches=np.inf):
if self.verbose:
gen.set_postfix(acc=correct / total)

return correct / total
return {
"acc" : correct / total,
"loss" : np.hstack(loss_hist),
}


class BaseWrapper(BaseNet):
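
A note on the smoothed loss introduced in train_epoch above: avg_loss is an exponentially weighted moving average of the per-batch loss, and debias_loss divides out the bias that comes from initializing the average at zero (the same correction Adam applies to its moment estimates). A standalone sketch of the pattern, with made-up loss values:

# Illustration only -- mirrors the avg_loss / debias_loss update in the new
# train_epoch, with a made-up sequence of batch losses.
avg_mom  = 0.98
avg_loss = 0.0
for batch_idx, loss in enumerate([2.30, 2.10, 1.95, 1.80]):
    avg_loss    = avg_loss * avg_mom + loss * (1 - avg_mom)
    debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))
    # Without the correction the average would start near zero; debias_loss
    # already equals 2.30 after the first batch and then tracks recent batches.
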
47 changes: 43 additions & 4 deletions basenet/helpers.py
@@ -7,19 +7,58 @@
from __future__ import print_function, division

import numpy as np
import random

import torch
from torch import nn
from torch.autograd import Variable

# --
# Utils

def set_seeds(seed=100):
_ = np.random.seed(seed)
_ = torch.manual_seed(seed + 123)
_ = torch.cuda.manual_seed(seed + 456)
_ = random.seed(seed + 789)

def to_numpy(x):
if isinstance(x, Variable):
return to_numpy(x.data)

return x.cpu().numpy() if x.is_cuda else x.numpy()

# --
# From `fastai`

def get_children(m):
return m if isinstance(m, (list, tuple)) else list(m.children())

# def apply_leaf(model, fn):
# children = get_children(model)
# if isinstance(model, nn.Module):
# fn(model)

# if len(children) > 0:
# for layer in children:
# apply_leaf(layer, fn)

# def _set_freeze(x, val):
# p.frozen = val
# for p in x.parameters():
# p.requires_grad = val

# def set_freeze(model, val):
# apply_leaf(model, lambda x: _set_freeze(x, val))

def set_freeze(x, mode):
x.frozen = mode
for p in x.parameters():
p.requires_grad = not mode

for module in x.children():
set_freeze(module, mode)




def set_seeds(seed=100):
np.random.seed(seed)
_ = torch.manual_seed(seed + 123)
_ = torch.cuda.manual_seed(seed + 456)
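
A minimal sketch of how set_freeze is meant to interact with the new train() override in basenet/basenet.py: freezing a submodule turns off its gradients and tags it with a frozen flag, which _set_train then respects, so the module stays in eval mode even when the rest of the network is switched back to training mode. The import paths and the toy model below are assumptions for illustration.

from torch import nn

from basenet.helpers import set_freeze   # import paths assumed from the repo layout
from basenet.basenet import _set_train

fc1, fc2 = nn.Linear(8, 8), nn.Linear(8, 2)
model = nn.Sequential(fc1, nn.ReLU(), fc2)

# Freeze the first layer: gradients off, frozen flag set on the module subtree.
set_freeze(fc1, True)
assert not any(p.requires_grad for p in fc1.parameters())
assert all(p.requires_grad for p in fc2.parameters())

# BaseNet.train() routes through _set_train, so frozen modules stay in eval mode
# while everything else returns to training mode.
_set_train(model, mode=True)
assert fc2.training and not fc1.training

# Unfreezing restores gradient flow.
set_freeze(fc1, False)
assert all(p.requires_grad for p in fc1.parameters())
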
1 change: 0 additions & 1 deletion basenet/lr.py
@@ -72,7 +72,6 @@ def f(progress):

@staticmethod
def sgdr(lr_init=0.1, period_length=50, lr_min=0, t_mult=1, **kwargs):
print('sgdr: period_length=%d | lr_init=%s' % (period_length, str(lr_init)), file=sys.stderr)
def f(progress):
""" SGDR learning rate annealing """
if t_mult > 1:
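
For context on what f(progress) computes: SGDR anneals the learning rate along a cosine curve within each period and restarts ("warm restarts") at period boundaries. A rough standalone sketch of the t_mult=1 case follows; it is an illustration of the standard SGDR schedule, not code copied from this repository, and the actual implementation in basenet/lr.py remains the source of truth.

import numpy as np

def sgdr_lr(progress, lr_init=0.1, period_length=50, lr_min=0):
    # Cosine annealing from lr_init down to lr_min within each period,
    # restarting at every period boundary (t_mult=1 case only).
    t = progress % period_length
    return lr_min + 0.5 * (lr_init - lr_min) * (1 + np.cos(np.pi * t / period_length))

assert abs(sgdr_lr(0) - 0.1) < 1e-9        # starts at lr_init
assert sgdr_lr(25) < sgdr_lr(10)           # decays within a period
assert abs(sgdr_lr(50) - 0.1) < 1e-9       # warm restart at the period boundary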

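Taken together, a hypothetical end-to-end sketch of how the API looks after this commit: train_epoch and eval_epoch now return dicts (accuracy, loss history, and a debiased loss for training) instead of bare accuracies. Everything below is illustrative rather than code from the repository: the import paths, the toy Net, and the random data are assumptions, and a CUDA device is required because BaseNet moves batches with .cuda().

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from basenet.basenet import BaseNet      # import paths assumed from the repo layout
from basenet.lr import LRSchedule

class Net(BaseNet):
    def __init__(self):
        super(Net, self).__init__(loss_fn=F.cross_entropy, verbose=True)
        self.fc = nn.Linear(32, 10)
    
    def forward(self, x):
        return self.fc(x)

# Random toy data, just to make the sketch self-contained.
X, y = torch.randn(256, 32), torch.LongTensor(256).random_(0, 10)
dataloaders = {
    'train' : DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True),
    'val'   : DataLoader(TensorDataset(X, y), batch_size=32),
}

model = Net().cuda()
model.init_optimizer(
    opt=torch.optim.SGD,
    params=model.parameters(),
    lr_scheduler=LRSchedule.sgdr(lr_init=0.1, period_length=10),
    momentum=0.9,
)

for _ in range(3):
    train_res = model.train_epoch(dataloaders)   # {'acc', 'loss', 'debias_loss'}
    val_res   = model.eval_epoch(dataloaders)    # {'acc', 'loss'}
    print(train_res['debias_loss'], val_res['acc'])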