Add strong Wolfe line search for lbfgs (#8824)
Summary:
This pull request adds a strong Wolfe line search for lbfgs. "Strong Wolfe" is the default line search method in [minFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html), and it is also the method recommended in the [Numerical Optimization](https://www.springer.com/gp/book/9780387303031) book.

The implementation is based on four sources:
+ https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
+ https://www.springer.com/gp/book/9780387303031 (Algorithms 3.5 and 3.6, formula 3.59)
+ https://github.com/torch/optim/blob/master/lswolfe.lua
+ https://github.com/torch/optim/blob/master/polyinterp.lua

The Lua version is based on an old version of `minFunc` that was later updated in 2012. I made a couple of small changes based on the updated version. Because of that, the comparison test against the `.lua` version is not fully consistent (which is why I changed the learning rate in the test).
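
For illustration, here is a minimal sketch of how the new option is enabled from user code. The toy quadratic objective, tensor shapes, and loop length are made up for this example; only the `line_search_fn="strong_Wolfe"` argument is introduced by this PR, and `None` (the default) keeps the previous fixed-step behavior.

```python
import torch
import torch.optim as optim

# Toy least-squares problem f(w) = ||A w - b||^2 (made up for illustration).
A = torch.randn(10, 5)
b = torch.randn(10)
w = torch.zeros(5, requires_grad=True)

# line_search_fn="strong_Wolfe" enables the new line search added here.
optimizer = optim.LBFGS([w], lr=1., line_search_fn="strong_Wolfe")

def closure():
    # L-BFGS re-evaluates the objective, so it needs a closure.
    optimizer.zero_grad()
    loss = (A @ w - b).pow(2).sum()
    loss.backward()
    return loss

for _ in range(20):
    optimizer.step(closure)
```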
Pull Request resolved: #8824

Differential Revision: D15783067

Pulled By: vincentqb

fbshipit-source-id: 5316d9088233981120376d79c7869d5f97e51b69
fehiepsi authored and facebook-github-bot committed Jun 12, 2019
1 parent 2c91ba3 commit ad73ea2
Showing 2 changed files with 217 additions and 21 deletions.
4 changes: 4 additions & 0 deletions test/test_optim.py
@@ -447,6 +447,10 @@ def test_lbfgs(self):
            lambda weight, bias: optim.LBFGS([weight, bias]),
            ignore_multidevice=True
        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
+            ignore_multidevice=True
+        )

    @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
    def test_lbfgs_return_type(self):
234 changes: 213 additions & 21 deletions torch/optim/lbfgs.py
@@ -3,8 +3,171 @@
from .optimizer import Optimizer


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone()
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone()]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone()]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
                               bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone()
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone()
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone()
            bracket_gtd[low_pos] = gtd_new

        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

    # return stuff
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
-    """Implements L-BFGS algorithm.
+    """Implements L-BFGS algorithm, heavily inspired by `minFunc
+    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
@@ -30,6 +193,7 @@ class LBFGS(Optimizer):
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
+        line_search_fn (str): either 'strong_Wolfe' or None (default: None).
    """

    def __init__(self, params, lr=1, max_iter=20, max_eval=None,
@@ -58,11 +222,11 @@ def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
-                view = p.data.new(p.data.numel()).zero_()
-            elif p.grad.data.is_sparse:
-                view = p.grad.data.to_dense().view(-1)
+                view = p.new(p.numel()).zero_()
+            elif p.grad.is_sparse:
+                view = p.grad.to_dense().view(-1)
            else:
-                view = p.grad.data.view(-1)
+                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

@@ -75,6 +239,20 @@ def _add_grad(self, step_size, update):
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone() for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.data.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    def step(self, closure):
        """Performs a single optimization step.
@@ -106,16 +284,18 @@ def step(self, closure):
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
-        abs_grad_sum = flat_grad.abs().sum()
+        opt_cond = flat_grad.abs().max() <= tolerance_grad

-        if abs_grad_sum <= tolerance_grad:
+        # optimal condition
+        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
+        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')
@@ -134,6 +314,7 @@ def step(self, closure):
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
+                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
@@ -146,10 +327,12 @@ def step(self, closure):
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
+                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
+                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)
@@ -158,15 +341,10 @@ def step(self, closure):
            # multiplied by the gradient
            num_old = len(old_dirs)

-            if 'ro' not in state:
-                state['ro'] = [None] * history_size
            if 'al' not in state:
                state['al'] = [None] * history_size
-            ro = state['ro']
            al = state['al']

-            for i in range(num_old):
-                ro[i] = 1. / old_dirs[i].dot(old_stps[i])

            # iteration in L-BFGS loop collapsed to use just one buffer
            q = flat_grad.neg()
            for i in range(num_old - 1, -1, -1):
@@ -191,18 +369,32 @@ def step(self, closure):
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
-                t = min(1., 1. / abs_grad_sum) * lr
+                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

+            # directional derivative is below tolerance
+            if gtd > -tolerance_change:
+                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
-                # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                if line_search_fn != "strong_Wolfe":
+                    raise RuntimeError("only 'strong_Wolfe' is supported")
+                else:
+                    x_init = self._clone_param()

+                    def obj_func(x, t, d):
+                        return self._directional_evaluate(closure, x, t, d)
+                    loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
+                                                                      loss, flat_grad, gtd)
+                self._add_grad(t, d)
+                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
@@ -212,7 +404,7 @@ def step(self, closure):
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
-                    abs_grad_sum = flat_grad.abs().sum()
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
@@ -228,13 +420,12 @@ def step(self, closure):
            if current_evals >= max_eval:
                break

-            if abs_grad_sum <= tolerance_grad:
-                break

-            if gtd > -tolerance_change:
+            # optimal condition
+            if opt_cond:
                break

-            if d.mul(t).abs_().sum() <= tolerance_change:
+            # lack of progress
+            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
@@ -244,6 +435,7 @@ def step(self, closure):
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
+        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss
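
For reference, the two conditions that the bracketing and zoom loops in `_strong_Wolfe` enforce are the standard strong Wolfe conditions (sufficient decrease plus a curvature condition on the directional derivative). Restated as a small standalone check, this is an illustrative sketch only: the helper name is made up, and the `c1`/`c2` defaults mirror the ones used in the code above.

```python
def satisfies_strong_wolfe(f, f_new, gtd, gtd_new, t, c1=1e-4, c2=0.9):
    """Check the strong Wolfe conditions for a step of size t along d.

    f, gtd: objective value and directional derivative g(x)^T d at x.
    f_new, gtd_new: the same quantities at x + t * d.
    """
    # Sufficient decrease (Armijo): f(x + t*d) <= f(x) + c1 * t * g(x)^T d
    armijo = f_new <= f + c1 * t * gtd
    # Curvature: |g(x + t*d)^T d| <= c2 * |g(x)^T d|; gtd is negative for a
    # descent direction, which is why the line search writes it as -c2 * gtd.
    curvature = abs(gtd_new) <= -c2 * gtd
    return armijo and curvature
```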
