Commit 00a615f: bugfixes
Ben Johnson committed Apr 25, 2018
1 parent 53f0369 commit 00a615f
Showing 6 changed files with 46 additions and 46 deletions.
basenet/basenet.py: 36 changes (18 additions, 18 deletions)
@@ -142,34 +142,34 @@ def train_epoch(self, dataloaders, mode='train', num_batches=np.inf):
if self.verbose:
gen = tqdm(gen, total=len(loader), desc='train_epoch:%s' % mode)

-# avg_mom = 0.98
-# avg_loss = 0.0
+avg_mom = 0.98
+avg_loss = 0.0

correct, total, loss_hist = 0, 0, []
for batch_idx, (data, target) in gen:
self.set_progress(self.epoch + batch_idx / len(loader))

output, loss = self.train_batch(data, target)
-# loss_hist.append(loss)
+loss_hist.append(loss)

-# avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
-# debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))
+avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
+debias_loss = avg_loss / (1 - avg_mom ** (batch_idx + 1))

-# correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
-# total += data.shape[0]
+correct += (to_numpy(output).argmax(axis=1) == to_numpy(target)).sum()
+total += data.shape[0]

-# if batch_idx > num_batches:
-# break
+if batch_idx > num_batches:
+break

-# if self.verbose:
-# gen.set_postfix(acc=correct / total)
+if self.verbose:
+gen.set_postfix(acc=correct / total)

self.epoch += 1
-# return {
-# "acc" : correct / total,
-# "loss" : np.hstack(loss_hist),
-# "debias_loss" : debias_loss,
-# }
+return {
+"acc" : float(correct / total),
+"loss" : list(map(float, loss_hist)),
+"debias_loss" : float(debias_loss),
+}

def eval_epoch(self, dataloaders, mode='val', num_batches=np.inf):

@@ -197,8 +197,8 @@ def eval_epoch(self, dataloaders, mode='val', num_batches=np.inf):
gen.set_postfix(acc=correct / total)

return {
"acc" : correct / total,
"loss" : np.hstack(loss_hist),
"acc" : float(correct / total),
"loss" : list(map(float, loss_hist)),
}

def predict(self, dataloaders, mode='val'):
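For context on the block re-enabled above: train_epoch now tracks an exponential moving average of the batch loss and corrects its startup bias, since an EMA initialized at zero underestimates the loss for the first few batches. A minimal standalone sketch of the same computation (the names mirror the diff; the synthetic loss values are illustrative only):

import numpy as np

def debiased_ema(losses, avg_mom=0.98):
    # Bias-corrected EMA, as in the re-enabled train_epoch lines.
    avg_loss = 0.0
    out = []
    for batch_idx, loss in enumerate(losses):
        avg_loss = avg_loss * avg_mom + loss * (1 - avg_mom)
        # Dividing by (1 - avg_mom ** (batch_idx + 1)) removes the pull
        # toward the zero initialization of avg_loss.
        out.append(avg_loss / (1 - avg_mom ** (batch_idx + 1)))
    return out

print(debiased_ema([2.0, 1.9, 1.8])[:2])  # ~[2.0, 1.95]: tracks from batch 0

The float(...) and list(map(float, ...)) casts in the new return block keep the epoch summary JSON-serializable, which the example scripts rely on when they json.dumps the results.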
basenet/hp_schedule.py: 8 changes (4 additions, 4 deletions)
@@ -78,12 +78,12 @@ def f(progress):
return f

@staticmethod
-def linear_cycle(hp_init=0.1, epochs=10, low_hp=0.005, extra=5, **kwargs):
+def linear_cycle(hp_max=0.1, epochs=10, low_hp=0.005, extra=5, **kwargs):
def f(progress):
if progress < epochs / 2:
-return 2 * hp_init * (1 - float(epochs - progress) / epochs)
+return 2 * hp_max * (1 - float(epochs - progress) / epochs)
elif progress <= epochs:
-return low_hp + 2 * hp_init * float(epochs - progress) / epochs
+return low_hp + 2 * hp_max * float(epochs - progress) / epochs
elif progress <= epochs + extra:
return low_hp * float(extra - (progress - epochs)) / extra
else:
@@ -266,7 +266,7 @@ def get_optimal_hp(hp_hist, loss_hist, c=10, burnin=5):

# Piecewise linear
hp = HPSchedule.piecewise_linear(breaks=[0, 5, 10, 15], hps=[0, 1, 0.25, 0])
-hps = np.vstack([hp(i) for i in np.linspace(-1, 16, 1000)])
+hps = np.vstack([hp(i) for i in np.linspace(-1, 40, 1000)])
_ = plt.plot(hps)
show_plot()

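The rename from hp_init to hp_max makes the parameter's role explicit: linear_cycle ramps the hyperparameter up to roughly hp_max at epochs / 2, back down to low_hp at epochs, then anneals to zero over the extra epochs. A sketch of the shape, lifted from the diff with the renamed argument (the plotting at the end is illustrative):

import numpy as np
import matplotlib.pyplot as plt

def linear_cycle(hp_max=0.1, epochs=10, low_hp=0.005, extra=5, **kwargs):
    def f(progress):
        if progress < epochs / 2:
            # Ramp up: 0 -> hp_max over the first half of training
            return 2 * hp_max * (1 - float(epochs - progress) / epochs)
        elif progress <= epochs:
            # Ramp down: ~hp_max -> low_hp over the second half
            return low_hp + 2 * hp_max * float(epochs - progress) / epochs
        elif progress <= epochs + extra:
            # Final anneal: low_hp -> 0 over the extra epochs
            return low_hp * float(extra - (progress - epochs)) / extra
        else:
            return 0.0
    return f

f = linear_cycle(hp_max=0.1, epochs=10, extra=5)
xs = np.linspace(0, 15, 1000)
_ = plt.plot(xs, [f(x) for x in xs])
plt.show()

The widened np.linspace(-1, 40, 1000) range in the other hunk simply plots the module's piecewise-linear demo schedule well past its last breakpoint.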
examples/cifar/cifar10.py: 12 changes (6 additions, 6 deletions)
@@ -19,7 +19,7 @@
from PIL import Image

from basenet import BaseNet
-from basenet.lr import LRSchedule
+from basenet.hp_schedule import HPSchedule
from basenet.helpers import to_numpy, set_seeds

import torch
@@ -163,7 +163,7 @@ def _make_layer(self, in_channels, out_channels, num_blocks, stride):
return nn.Sequential(*layers)

def forward(self, x):
-x = self.prep(x.half())
+x = self.prep(x)

x = self.layers(x)

@@ -184,19 +184,19 @@ def forward(self, x):

print('cifar10.py: initializing model...', file=sys.stderr)

-model = ResNet18().cuda().half()
+model = ResNet18().cuda()
print(model, file=sys.stderr)

# --
# Initialize optimizer

print('cifar10.py: initializing optimizer...', file=sys.stderr)

-lr_scheduler = getattr(LRSchedule, args.lr_schedule)(lr_init=args.lr_init, epochs=args.epochs, extra=args.extra)
+lr_scheduler = getattr(HPSchedule, args.lr_schedule)(lr_init=args.lr_init, epochs=args.epochs, extra=args.extra)
model.init_optimizer(
opt=torch.optim.SGD,
params=model.parameters(),
-lr_scheduler=lr_scheduler,
+hp_scheduler={"lr" : lr_scheduler},
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=True,
@@ -212,7 +212,7 @@ def forward(self, x):
test = model.eval_epoch(dataloaders, mode='test')
print(json.dumps({
"epoch" : int(epoch),
"lr" : model.lr,
"lr" : model.hp['lr'],
"test_acc" : float(test['acc']),
"time" : time() - t,
}))
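Taken together, the cifar10.py edits track a rename of the scheduling API from LRSchedule to the generic HPSchedule: schedulers are now passed to init_optimizer as a dict keyed by hyperparameter name, and the current value is read back from model.hp['lr'] rather than model.lr. The same commit also backs out half precision (ResNet18().cuda() instead of .cuda().half(), and self.prep(x) without the .half() cast). A condensed sketch of the new calling convention, assuming a BaseNet subclass like the ResNet18 above; the hyperparameter values here are illustrative:

import torch
from basenet.hp_schedule import HPSchedule

model = ResNet18().cuda()  # ResNet18 as defined in cifar10.py above

# Schedulers are now a dict keyed by hyperparameter name -- here only the LR.
lr_scheduler = HPSchedule.linear_cycle(hp_max=0.1, epochs=30, extra=5)

model.init_optimizer(
    opt=torch.optim.SGD,
    params=model.parameters(),
    hp_scheduler={"lr" : lr_scheduler},  # was: lr_scheduler=lr_scheduler
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)

print(model.hp['lr'])  # was: model.lr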
examples/cifar_opt/cifar_opt.py: 27 changes (12 additions, 15 deletions)
@@ -33,13 +33,9 @@
# Helpers

def dlib_find_max_global(f, bounds, **kwargs):
-print('dlib_find_max_global', file=sys.stderr)
-
varnames = f.__code__.co_varnames[:f.__code__.co_argcount]
bound1_, bound2_ = [], []
for varname in varnames:
-print(varname, bounds[varname][0], bounds[varname][1], file=sys.stderr)
-
bound1_.append(bounds[varname][0])
bound2_.append(bounds[varname][1])

@@ -118,7 +114,7 @@ def _make_layer(self, in_channels, out_channels, num_blocks, stride):
return nn.Sequential(*layers)

def forward(self, x):
-x = self.prep(x.half())
+x = self.prep(x)#.half())

x = self.layers(x)

@@ -165,8 +161,8 @@ def forward(self, x):
])

try:
-trainset = datasets.CIFAR10(root='../data', train=True, download=args.download, transform=transform_train)
-testset = datasets.CIFAR10(root='../data', train=False, download=args.download, transform=transform_test)
+trainset = datasets.CIFAR10(root='./data', train=True, download=args.download, transform=transform_train)
+testset = datasets.CIFAR10(root='./data', train=False, download=args.download, transform=transform_test)
except:
raise Exception('cifar10.py: error loading data -- try rerunning w/ `--download` flag')

@@ -193,8 +189,8 @@ def forward(self, x):

def run_one(break1, break2, val1, val2):

-try:
-set_seeds(args.seed) # Might have bad side effects
+# try:
+# set_seeds(args.seed) # Might have bad side effects

if (break1 >= break2):
return float(-1)
@@ -210,7 +206,7 @@ def run_one(break1, break2, val1, val2):
("weight_decay", args.weight_decay),
])

-model = ResNet18().cuda().half()
+model = ResNet18().cuda()#.half()

lr_scheduler = HPSchedule.piecewise_linear(
breaks=[0, break1, break2, args.epochs],
@@ -242,16 +238,17 @@
sys.stdout.flush()

return float(test['acc'])
-except:
-return float(-1)
+# except:
+# return float(-1)

print('cifar_opt.py: start', file=sys.stderr)
best_args, best_score = dlib_find_max_global(run_one, bounds={
"break1" : (0, 10),
"break2" : (0, 10),
"break1" : (0, args.epochs),
"break2" : (0, args.epochs),
"val1" : (-3, 0),
"val2" : (-3, 0),
}, num_function_calls=100, solver_epsilon=0.001)

print(best_args, file=sys.stderr)
print(best_score, file=sys.stderr)

print('cifar_opt.py: done', file=sys.stderr)
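
The widened bounds let dlib search break1 and break2 over the whole (0, args.epochs) range instead of a hard-coded (0, 10). For reference, the dlib_find_max_global helper at the top of the file reads the objective's argument names off its code object so that bounds can be passed as a dict; a minimal sketch of that dispatch, assuming dlib is installed (the quadratic objective here is illustrative only):

import dlib

def dlib_find_max_global(f, bounds, **kwargs):
    # Order the bounds to match f's positional arguments, as in the helper above.
    varnames = f.__code__.co_varnames[:f.__code__.co_argcount]
    lower = [bounds[v][0] for v in varnames]
    upper = [bounds[v][1] for v in varnames]
    return dlib.find_max_global(f, lower, upper, **kwargs)

def objective(x, y):
    return -((x - 1) ** 2 + (y + 2) ** 2)  # maximum of 0 at (1, -2)

best_args, best_score = dlib_find_max_global(objective, bounds={
    "x" : (-5.0, 5.0),
    "y" : (-5.0, 5.0),
}, num_function_calls=50)

print(best_args, best_score)  # best_args approaches [1, -2]

Note that the commit also comments out the blanket try/except around run_one, so runtime errors now surface instead of being silently scored as -1.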
examples/cifar_opt/run.sh: 2 changes (1 addition, 1 deletion)
@@ -2,6 +2,6 @@

# runs.h

-python cifar_opt.py --epochs 30 > cifar_opt-30.jl
+python cifar_opt.py --epochs 30 > results/cifar_opt-30.jl

# CUDA_VISIBLE_DEVICES=1 python cifar_opt2.py > cifar_opt2.jl
examples/plot.py: 7 changes (5 additions, 2 deletions)
@@ -24,10 +24,13 @@ def smart_json_loads(x):


_ = plt.legend(loc='lower right')
_ = plt.grid(alpha=0.25)
_ = plt.axhline(0.90, c='grey', alpha=0.25, lw=1)
_ = plt.axhline(0.91, c='grey', alpha=0.25, lw=1)
_ = plt.axhline(0.92, c='grey', alpha=0.25, lw=1)
_ = plt.axhline(0.94, c='grey', alpha=0.25, lw=1)
_ = plt.axhline(0.96, c='grey', alpha=0.25, lw=1)
_ = plt.axhline(0.98, c='grey', alpha=0.25, lw=1)
# _ = plt.axhline(0.94, c='grey', alpha=0.5)
_ = plt.ylim(0.5, 1.0)
# _ = plt.ylim(0.9, 1.0)
# _ = plt.xlim(0, 40)
show_plot()
