Commit: dawn

Ubuntu committed Apr 21, 2018
1 parent 15ce63c commit 49b2b61
Showing 7 changed files with 161 additions and 50 deletions.
34 changes: 17 additions & 17 deletions basenet/basenet.py
@@ -137,34 +137,34 @@ def train_epoch(self, dataloaders, mode='train', num_batches=np.inf):
if self.verbose:
gen = tqdm(gen, total=len(loader), desc='train_epoch:%s' % mode)

-        avg_mom = 0.98
-        avg_loss = 0.0
+        # avg_mom = 0.98
+        # avg_loss = 0.0

correct, total, loss_hist = 0, 0, []
for batch_idx, (data, target) in gen:
self.set_progress(self.epoch + batch_idx / len(loader))

output, loss = self.train_batch(data, target)
-            loss_hist.append(loss)
+            # loss_hist.append(loss)

-            avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
-            debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))
+            # avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
+            # debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))

-            correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
-            total += data.shape[0]
+            # correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
+            # total += data.shape[0]

-            if batch_idx > num_batches:
-                break
+            # if batch_idx > num_batches:
+            #     break

-            if self.verbose:
-                gen.set_postfix(acc=correct / total)
+            # if self.verbose:
+            #     gen.set_postfix(acc=correct / total)

self.epoch += 1
-        return {
-            "acc" : correct / total,
-            "loss" : np.hstack(loss_hist),
-            "debias_loss" : debias_loss,
-        }
+        # return {
+        #     "acc" : correct / total,
+        #     "loss" : np.hstack(loss_hist),
+        #     "debias_loss" : debias_loss,
+        # }

@@ -182,7 +182,7 @@ def eval_epoch(self, dataloaders, mode='val', num_batches=np.inf):
output, loss = self.eval_batch(data, target)
loss_hist.append(loss)

-            correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
+            correct += (to_numpy(output.float()).argmax(axis=1) == to_numpy(target)).sum()
total += data.shape[0]

if batch_idx > num_batches:
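For reference, the train_epoch bookkeeping commented out above implements a debiased exponential moving average of the batch loss. A minimal standalone sketch of that logic, not part of the commit, with illustrative loss values:

# Debiased EMA of the loss, as in the commented-out lines above.
# Starting the average at 0 biases early estimates low; dividing by
# (1 - avg_mom ** (t + 1)) removes that bias.
avg_mom, avg_loss = 0.98, 0.0
for t, loss in enumerate([2.3, 2.1, 1.9]):
    avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
    debias_loss = avg_loss / (1 - avg_mom ** (t + 1))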
32 changes: 26 additions & 6 deletions basenet/lr.py
@@ -69,6 +69,20 @@ def f(progress):

return f

+    @staticmethod
+    def linear_cycle(lr_init=0.1, epochs=10, low_lr=0.005, extra=5, **kwargs):
+        def f(progress):
+            if progress < epochs / 2:
+                return 2 * lr_init * (1 - float(epochs - progress) / epochs)
+            elif progress <= epochs:
+                return low_lr + 2 * lr_init * float(epochs - progress) / epochs
+            elif progress <= epochs + extra:
+                return low_lr * float(extra - (progress - epochs)) / extra
+            else:
+                return low_lr / 10
+
+        return f

@staticmethod
def cyclical(lr_init=0.1, lr_burn_in=0.05, epochs=10, **kwargs):
def f(progress):
@@ -214,6 +228,12 @@ def get_optimal_lr(lr_hist, loss_hist, c=10, burnin=5):
# _ = plt.plot(lrs[:,1])
# show_plot()

+# Linear cycle
+lr = LRSchedule.linear_cycle(epochs=30, lr_init=0.1, extra=10)
+lrs = np.vstack([lr(i) for i in np.linspace(0, 40, 1000)])
+_ = plt.plot(lrs)
+show_plot()

# # Cyclical
# lr = LRSchedule.cyclical(epochs=30, lr_init=np.array([1, 2]))
# lrs = np.vstack([lr(i) for i in np.linspace(0, 30, 1000)])
@@ -229,9 +249,9 @@ def get_optimal_lr(lr_hist, loss_hist, c=10, burnin=5):
# show_plot()

# exponential increase (for setting learning rates)
-lr = LRSchedule.exponential_increase(lr_init=np.array([1e-5, 1e-4]), lr_max=10, num_steps=100)
-lrs = np.vstack([lr(i) for i in np.linspace(0, 100, 1000)])
-_ = plt.plot(lrs[:,0])
-_ = plt.plot(lrs[:,1])
-_ = plt.yscale('log')
-show_plot()
+# lr = LRSchedule.exponential_increase(lr_init=np.array([1e-5, 1e-4]), lr_max=10, num_steps=100)
+# lrs = np.vstack([lr(i) for i in np.linspace(0, 100, 1000)])
+# _ = plt.plot(lrs[:,0])
+# _ = plt.plot(lrs[:,1])
+# _ = plt.yscale('log')
+# show_plot()
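A quick sketch (not part of the commit) of the shape linear_cycle produces with the settings used in this repo; the expected values follow directly from the branches in the hunk above:

from basenet.lr import LRSchedule

f = LRSchedule.linear_cycle(lr_init=0.1, epochs=30, extra=10)
for p in [0, 7.5, 15, 30, 35, 45]:
    print(p, f(p))
# 0.0, 0.05 (ramping up), 0.105 (peak at mid-training), 0.005 (low_lr),
# 0.0025 (annealing over `extra`), 0.0005 (the low_lr / 10 floor)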
56 changes: 37 additions & 19 deletions examples/cifar10.py
@@ -15,6 +15,8 @@
import json
import argparse
import numpy as np
+from time import time
+from PIL import Image

from basenet import BaseNet
from basenet.lr import LRSchedule
@@ -33,10 +35,15 @@

def parse_args():
parser = argparse.ArgumentParser()
-    parser.add_argument('--epochs', type=int, default=50)
-    parser.add_argument('--lr-schedule', type=str, default='linear')
+    parser.add_argument('--epochs', type=int, default=30)
+    parser.add_argument('--extra', type=int, default=5)
+    parser.add_argument('--burnout', type=int, default=5)
+    parser.add_argument('--lr-schedule', type=str, default='linear_cycle')
parser.add_argument('--lr-init', type=float, default=0.1)
-    parser.add_argument('--seed', type=int, default=123)
+    parser.add_argument('--weight-decay', type=float, default=5e-4)
+    parser.add_argument('--momentum', type=float, default=0.9)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--seed', type=int, default=789)
parser.add_argument('--download', action="store_true")
return parser.parse_args()

@@ -55,7 +62,11 @@
}

transform_train = transforms.Compose([
-    transforms.RandomCrop(32, padding=4),
+    transforms.Lambda(lambda x: np.asarray(x)),
+    transforms.Lambda(lambda x: np.pad(x, [(4, 4), (4, 4), (0, 0)], mode='reflect')),
+    transforms.Lambda(lambda x: Image.fromarray(x)),
+    transforms.RandomCrop(32),

transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
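The numpy round-trip above implements a reflect-padded random crop in place of the default zero padding. On newer torchvision releases the same transform can likely be written directly (assumption: RandomCrop's padding_mode argument was added after this commit; transforms and cifar10_stats are the script's own):

# Sketch, assuming a torchvision with padding_mode support:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
])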
Expand All @@ -74,17 +85,17 @@ def parse_args():

trainloader = torch.utils.data.DataLoader(
trainset,
-    batch_size=128,
+    batch_size=args.batch_size,
shuffle=True,
-    num_workers=8,
+    num_workers=16,
pin_memory=True,
)

testloader = torch.utils.data.DataLoader(
testset,
-    batch_size=256,
+    batch_size=512,
shuffle=False,
-    num_workers=8,
+    num_workers=16,
pin_memory=True,
)

@@ -136,7 +147,7 @@ def __init__(self, num_blocks=[2, 2, 2, 2], num_classes=10):
self._make_layer(64, 64, num_blocks[0], stride=1),
self._make_layer(64, 128, num_blocks[1], stride=2),
self._make_layer(128, 256, num_blocks[2], stride=2),
-            self._make_layer(256, 512, num_blocks[3], stride=2),
+            self._make_layer(256, 256, num_blocks[3], stride=2),
)

self.classifier = nn.Linear(512, num_classes)
@@ -152,12 +163,18 @@ def _make_layer(self, in_channels, out_channels, num_blocks, stride):
return nn.Sequential(*layers)

def forward(self, x):
-        x = self.prep(x)
+        x = self.prep(x.half())

x = self.layers(x)

-        x = F.adaptive_avg_pool2d(x, (1, 1))
-        x = x.view(x.size(0), -1)
+        x_avg = F.adaptive_avg_pool2d(x, (1, 1))
+        x_avg = x_avg.view(x_avg.size(0), -1)
+
+        x_max = F.adaptive_max_pool2d(x, (1, 1))
+        x_max = x_max.view(x_max.size(0), -1)
+
+        x = torch.cat([x_avg, x_max], dim=-1)

x = self.classifier(x)

return x
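Note how the two changes above interact: the last stage now emits 256 channels, and concatenating avg- and max-pooled features doubles that back to 512, so nn.Linear(512, num_classes) still matches. A standalone shape check (sketch, not from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 4, 4)                          # last-stage feature map
x_avg = F.adaptive_avg_pool2d(x, (1, 1)).view(2, -1)   # (2, 256)
x_max = F.adaptive_max_pool2d(x, (1, 1)).view(2, -1)   # (2, 256)
assert torch.cat([x_avg, x_max], dim=-1).shape == (2, 512)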
@@ -167,35 +184,36 @@ def forward(self, x):

print('cifar10.py: initializing model...', file=sys.stderr)

-model = ResNet18().cuda()
+model = ResNet18().cuda().half()
print(model, file=sys.stderr)
model.verbose = True

# --
# Initialize optimizer

print('cifar10.py: initializing optimizer...', file=sys.stderr)

-lr_scheduler = getattr(LRSchedule, args.lr_schedule)(lr_init=args.lr_init, epochs=args.epochs)
+lr_scheduler = getattr(LRSchedule, args.lr_schedule)(lr_init=args.lr_init, epochs=args.epochs, extra=args.extra)
model.init_optimizer(
opt=torch.optim.SGD,
params=model.parameters(),
lr_scheduler=lr_scheduler,
-    momentum=0.9,
-    weight_decay=5e-4,
+    momentum=args.momentum,
+    weight_decay=args.weight_decay,
nesterov=True,
)

# --
# Train

print('cifar10.py: training...', file=sys.stderr)
-for epoch in range(args.epochs):
+t = time()
+for epoch in range(args.epochs + args.extra + args.burnout):
train = model.train_epoch(dataloaders, mode='train')
test = model.eval_epoch(dataloaders, mode='test')
print(json.dumps({
"epoch" : int(epoch),
"lr" : model.lr,
"train_acc" : float(train['acc']),
"test_acc" : float(test['acc']),
"time" : time() - t,
}))
sys.stdout.flush()
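The .half() calls above switch the model and its inputs to fp16 end to end, which pairs with the output.float() upcast added in basenet.py (presumably because some operations on half tensors were unreliable in the PyTorch of this era). A minimal sketch of the pattern, assuming plain all-fp16 training with no fp32 master weights, reusing this script's ResNet18:

import torch
import torch.nn.functional as F

model = ResNet18().cuda().half()              # fp16 weights and activations
x = torch.randn(8, 3, 32, 32).cuda().half()   # inputs cast to match
y = torch.randint(0, 10, (8,)).cuda()
logits = model(x)                             # fp16 logits
loss = F.cross_entropy(logits.float(), y)     # upcast before the loss for stability
loss.backward()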
9 changes: 1 addition & 8 deletions examples/cifar10.sh
@@ -2,11 +2,4 @@

# cifar10.sh

-mkdir -p results
-
-python cifar10.py \
-    --epochs 50 \
-    --lr-schedule linear \
-    --lr-init 0.1 \
-    --download \
-    --seed 123 > results/cifar10-linear.jl
+time python cifar10.py > cifar10.jl
14 changes: 14 additions & 0 deletions examples/dawn/basenet.json
@@ -0,0 +1,14 @@
{
"version": "v1.0",
"author": "bkj",
"authorEmail": "ben@canfield.io",
"framework": "pytorch",
"codeURL": "https://github.com/bkj/basenet/tree/f1c30a95263346231cb7b0a6b77de4a3f44bc6b7/examples",
"model": "Resnet18 + minor modifications",
"hardware": "V100 (AWS p3.2xlarge)",
"costPerHour": 3.060,
"timestamp": "2018-04-20",
"misc": {
"comments" : "Hit 0.94 threshold in 4/5 runs. Reporting median run here."
}
}
41 changes: 41 additions & 0 deletions examples/dawn/basenet_dawn.tsv
@@ -0,0 +1,41 @@
epoch hours top1Accuracy
0 0.003508641587363349 55.55
1 0.006513937910397847 65.62
2 0.009536764687962003 74.83
3 0.012568767666816711 75.66000000000001
4 0.015597014096048143 79.17999999999999
5 0.018631390664312575 79.14999999999999
6 0.0216702620850669 83.65
7 0.024692272411452398 79.71000000000001
8 0.027740566465589735 83.78999999999999
9 0.030775602261225384 78.21000000000001
10 0.033793951206737095 85.15
11 0.03682645208305783 84.95
12 0.039851660993364124 81.04
13 0.04292301966084374 83.14
14 0.04596217632293701 82.11
15 0.04898762815528446 79.17999999999999
16 0.05204661183887058 84.98
17 0.055087420211897956 84.77
18 0.058125914070341324 85.28999999999999
19 0.061184588935640126 84.87
20 0.0642281475994322 85.64
21 0.06726788024107615 87.69
22 0.0703298607137468 89.58
23 0.07337968620989058 88.22
24 0.07640756865342459 88.81
25 0.07944716102547116 91.18
26 0.08250825656784905 91.39
27 0.08555183801386092 92.10000000000001
28 0.08858485188749102 93.14
29 0.09160958170890808 93.99
30 0.0946198488606347 94.08
31 0.09764599925941891 94.15
32 0.10066269311639997 94.19
33 0.1036933634016249 94.28
34 0.10669603122605217 94.3
35 0.10971194055345324 94.23
36 0.11277156300014919 94.31
37 0.115811897582478 94.3
38 0.11884934445222219 94.31
39 0.12188066297107272 94.34
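A small sketch for reading this table (assumes pandas is available): find the first epoch crossing the 94% DAWNBench threshold.

import pandas as pd

df = pd.read_csv('examples/dawn/basenet_dawn.tsv', sep='\t')
hit = df[df.top1Accuracy >= 94].iloc[0]
print(hit.epoch, hit.hours * 60)   # epoch 30, at roughly 5.7 minutes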
25 changes: 25 additions & 0 deletions examples/plot.py
@@ -0,0 +1,25 @@
#!/usr/bin/env python

import sys
import json
import numpy as np
from rsub import *
from matplotlib import pyplot as plt

def smart_json_loads(x):
try:
return json.loads(x)['test_acc']
except:
pass

all_data = []
for p in sys.argv[1:]:
data = list(filter(None, map(smart_json_loads, open(p))))
_ = plt.plot(data, alpha=0.75, label=p)

_ = plt.legend(loc='lower right')
_ = plt.axhline(0.9, c='grey', alpha=0.5)
_ = plt.axhline(0.94, c='grey', alpha=0.5)
_ = plt.ylim(0, 1)
_ = plt.xlim(0, 40)
show_plot()
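Usage sketch for the helper above (file name is an example; it accepts any number of the .jl logs that cifar10.py writes):

python plot.py cifar10.jl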
