[pytorch] Add strong Wolfe line search for lbfgs #8824

Closed · wants to merge 9 commits
test/test_optim.py (4 additions, 0 deletions)
@@ -447,6 +447,10 @@ def test_lbfgs(self):
lambda weight, bias: optim.LBFGS([weight, bias]),
ignore_multidevice=True
)
self._test_basic_cases(
lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
ignore_multidevice=True
)

@unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
def test_lbfgs_return_type(self):
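
For context, a minimal usage sketch of the new option (the toy least-squares problem and variable names are illustrative, not from this PR; the option name is spelled "strong_Wolfe" as in this PR). With a line search enabled, step() picks the step size itself rather than taking lr directly:

import torch
from torch import optim

# Toy least-squares problem: recover w_true from noisy observations.
torch.manual_seed(0)
X = torch.randn(64, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])
y = X @ w_true + 0.01 * torch.randn(64)

w = torch.zeros(3, requires_grad=True)
opt = optim.LBFGS([w], line_search_fn="strong_Wolfe")

def closure():
    # L-BFGS re-evaluates the objective several times per step,
    # so the loss/gradient computation must live in a closure.
    opt.zero_grad()
    loss = (X @ w - y).pow(2).mean()
    loss.backward()
    return loss

opt.step(closure)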
torch/optim/lbfgs.py (213 additions, 21 deletions)
@@ -3,8 +3,171 @@
from .optimizer import Optimizer


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
# Compute bounds of interpolation area
if bounds is not None:
xmin_bound, xmax_bound = bounds
else:
xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

# Code for most common case: cubic interpolation of 2 points
# w/ function and derivative values for both
# Solution in this case (where x2 is the farthest point):
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square = d1 ** 2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
else:
min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
return min(max(min_pos, xmin_bound), xmax_bound)
else:
return (xmin_bound + xmax_bound) / 2.
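
As a sanity check on the closed form above, here is a hypothetical float-only helper (cubic_min is not part of the PR) mirroring the x1 <= x2 branch. For f(x) = x^3 - 3x, which has a local minimum at x = 1, interpolating between x = 0 and x = 2 recovers the minimizer exactly, since f is itself a cubic:

import math

def cubic_min(x1, f1, g1, x2, f2, g2):
    # Same closed form as _cubic_interpolate (x1 <= x2 branch), on plain floats.
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2 = math.sqrt(d1 ** 2 - g1 * g2)
    return x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))

f = lambda x: x ** 3 - 3 * x      # local minimum at x = 1
g = lambda x: 3 * x ** 2 - 3
print(cubic_min(0.0, f(0.0), g(0.0), 2.0, f(2.0), g(2.0)))  # -> 1.0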


def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25):
# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
d_norm = d.abs().max()
g = g.clone()
# evaluate objective and gradient using initial step
f_new, g_new = obj_func(x, t, d)
ls_func_evals = 1
gtd_new = g_new.dot(d)

# bracket an interval containing a point satisfying the Wolfe criteria
t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
done = False
ls_iter = 0
while ls_iter < max_ls:
# check conditions
if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone()]
bracket_gtd = [gtd_prev, gtd_new]
break

if abs(gtd_new) <= -c2 * gtd:
bracket = [t]
bracket_f = [f_new]
bracket_g = [g_new]
done = True
break

if gtd_new >= 0:
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone()]
bracket_gtd = [gtd_prev, gtd_new]
break

# interpolate
min_step = t + 0.01 * (t - t_prev)
max_step = t * 10
tmp = t
t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
bounds=(min_step, max_step))

# next step
t_prev = tmp
f_prev = f_new
g_prev = g_new.clone()
gtd_prev = gtd_new
f_new, g_new = obj_func(x, t, d)
ls_func_evals += 1
gtd_new = g_new.dot(d)
ls_iter += 1

# reached max number of iterations?
if ls_iter == max_ls:
bracket = [0, t]
bracket_f = [f, f_new]
bracket_g = [g, g_new]

# zoom phase: we now have a point satisfying the criteria, or
# a bracket around it. We refine the bracket until we find the
# exact point satisfying the criteria
insuf_progress = False
# find high and low points in bracket
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
while not done and ls_iter < max_ls:
# compute new trial value
t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
bracket[1], bracket_f[1], bracket_gtd[1])

# test that we are making sufficient progress:
# if `t` is too close to a boundary, we mark the step as making
# insufficient progress, and if
# + we made insufficient progress in the last step, or
# + `t` lies at one of the boundaries,
# we move `t` to a position `0.1 * len(bracket)` away from the
# nearest boundary point.
eps = 0.1 * (max(bracket) - min(bracket))
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
if abs(t - max(bracket)) < abs(t - min(bracket)):
t = max(bracket) - eps
else:
t = min(bracket) + eps
insuf_progress = False
else:
insuf_progress = True
else:
insuf_progress = False

# Evaluate new point
f_new, g_new = obj_func(x, t, d)
ls_func_evals += 1
gtd_new = g_new.dot(d)
ls_iter += 1

if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
bracket[high_pos] = t
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone()
bracket_gtd[high_pos] = gtd_new
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
done = True
elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
# old high becomes new low
bracket[high_pos] = bracket[low_pos]
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos]
bracket_gtd[high_pos] = bracket_gtd[low_pos]

# new point becomes new low
bracket[low_pos] = t
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone()
bracket_gtd[low_pos] = gtd_new

# line-search bracket has become too small to continue
if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
break

# return final value, gradient, step size, and evaluation count
t = bracket[low_pos]
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos]
return f_new, g_new, t, ls_func_evals
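
For reference, the two conditions this routine enforces, written as a standalone predicate (a sketch, not part of the PR): sufficient decrease (Armijo) with constant c1, and the strong curvature condition with constant c2. Note that gtd is negative for a descent direction, hence the -c2 * gtd bound, matching the checks in the bracketing and zoom phases above:

def satisfies_strong_wolfe(f0, gtd0, t, f_t, gtd_t, c1=1e-4, c2=0.9):
    # f0, gtd0: objective value and directional derivative at the start point;
    # f_t, gtd_t: the same quantities at the trial point x + t * d.
    armijo = f_t <= f0 + c1 * t * gtd0        # sufficient decrease
    curvature = abs(gtd_t) <= -c2 * gtd0      # strong curvature condition
    return armijo and curvature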


class LBFGS(Optimizer):
"""Implements L-BFGS algorithm.
"""Implements L-BFGS algorithm, heavily inspired by `minFunc
<https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

.. warning::
This optimizer doesn't support per-parameter options and parameter
@@ -30,6 +193,7 @@ class LBFGS(Optimizer):
tolerance_change (float): termination tolerance on function
value/parameter changes (default: 1e-9).
history_size (int): update history size (default: 100).
line_search_fn (str): either 'strong_Wolfe' or None (default: None).
"""

def __init__(self, params, lr=1, max_iter=20, max_eval=None,
@@ -58,11 +222,11 @@ def _gather_flat_grad(self):
views = []
for p in self._params:
if p.grad is None:
view = p.data.new(p.data.numel()).zero_()
elif p.grad.data.is_sparse:
view = p.grad.data.to_dense().view(-1)
view = p.new(p.numel()).zero_()
elif p.grad.is_sparse:
view = p.grad.to_dense().view(-1)
else:
view = p.grad.data.view(-1)
view = p.grad.view(-1)
views.append(view)
return torch.cat(views, 0)

@@ -75,6 +239,20 @@ def _add_grad(self, step_size, update):
offset += numel
assert offset == self._numel()

def _clone_param(self):
return [p.clone() for p in self._params]

def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
p.data.copy_(pdata)

def _directional_evaluate(self, closure, x, t, d):
self._add_grad(t, d)
loss = float(closure())
flat_grad = self._gather_flat_grad()
self._set_param(x)
return loss, flat_grad
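
_directional_evaluate implements a save/step/measure/restore pattern: parameters are moved to x + t*d, the closure is evaluated, and the original parameters are put back so the caller can probe several step sizes from the same point. A rough standalone equivalent (illustrative only; directional_value is a hypothetical helper, and d_flat is assumed to be a flat direction vector matching the total parameter count):

import torch

def directional_value(params, closure, t, d_flat):
    saved = [p.detach().clone() for p in params]
    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            # step along the search direction: p <- p + t * d
            p.add_(d_flat[offset:offset + n].view_as(p), alpha=t)
            offset += n
    loss = float(closure())
    with torch.no_grad():
        for p, s in zip(params, saved):
            p.copy_(s)  # roll back to the original point
    return loss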

def step(self, closure):
"""Performs a single optimization step.

@@ -106,16 +284,18 @@ def step(self, closure):
state['func_evals'] += 1

flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
opt_cond = flat_grad.abs().max() <= tolerance_grad

if abs_grad_sum <= tolerance_grad:
# optimal condition
if opt_cond:
return orig_loss

# tensors cached in state (for tracing)
d = state.get('d')
t = state.get('t')
old_dirs = state.get('old_dirs')
old_stps = state.get('old_stps')
ro = state.get('ro')
H_diag = state.get('H_diag')
prev_flat_grad = state.get('prev_flat_grad')
prev_loss = state.get('prev_loss')
@@ -134,6 +314,7 @@ def step(self, closure):
d = flat_grad.neg()
old_dirs = []
old_stps = []
ro = []
H_diag = 1
else:
# do lbfgs update (update memory)
@@ -146,10 +327,12 @@
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)
ro.pop(0)

# store new direction/step
old_dirs.append(y)
old_stps.append(s)
ro.append(1. / ys)

# update scale of initial Hessian approximation
H_diag = ys / y.dot(y) # (y*y)
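
(For reference: this is the standard initial inverse-Hessian scaling, H_0 = (s·y / y·y) I, so the two-loop recursion below starts from the steepest-descent direction rescaled by the most recent curvature estimate.)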
@@ -158,15 +341,10 @@
# multiplied by the gradient
num_old = len(old_dirs)

if 'ro' not in state:
state['ro'] = [None] * history_size
if 'al' not in state:
state['al'] = [None] * history_size
ro = state['ro']
al = state['al']

for i in range(num_old):
ro[i] = 1. / old_dirs[i].dot(old_stps[i])

# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old - 1, -1, -1):
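
The hunk above is the start of the classic two-loop recursion, folded into a single buffer q. A self-contained sketch of the full recursion under this PR's storage convention (old_dirs holds the y vectors, old_stps the s vectors, ro[i] = 1/(y_i·s_i); the helper name is illustrative, not from the PR):

import torch

def two_loop_direction(flat_grad, old_stps, old_dirs, ro, H_diag):
    # First loop (newest to oldest): fold the (s, y) history into q.
    num_old = len(old_dirs)
    al = [0.0] * num_old
    q = flat_grad.neg()
    for i in range(num_old - 1, -1, -1):
        al[i] = float(old_stps[i].dot(q) * ro[i])
        q.add_(old_dirs[i], alpha=-al[i])
    # Apply the initial inverse-Hessian scaling.
    d = q.mul(H_diag)
    # Second loop (oldest to newest): correct with the stored curvature pairs.
    for i in range(num_old):
        be_i = float(old_dirs[i].dot(d) * ro[i])
        d.add_(old_stps[i], alpha=al[i] - be_i)
    return d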
@@ -191,18 +369,32 @@ def step(self, closure):
############################################################
# reset initial guess for step size
if state['n_iter'] == 1:
t = min(1., 1. / abs_grad_sum) * lr
t = min(1., 1. / flat_grad.abs().sum()) * lr
else:
t = lr

# directional derivative
gtd = flat_grad.dot(d) # g * d

# directional derivative is below tolerance
if gtd > -tolerance_change:
break

# optional line search: user function
ls_func_evals = 0
if line_search_fn is not None:
# perform line search, using user function
raise RuntimeError("line search function is not supported yet")
if line_search_fn != "strong_Wolfe":
raise RuntimeError("only 'strong_Wolfe' is supported")
else:
x_init = self._clone_param()

def obj_func(x, t, d):
return self._directional_evaluate(closure, x, t, d)
loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd)
self._add_grad(t, d)
opt_cond = flat_grad.abs().max() <= tolerance_grad
else:
# no line search, simply move with fixed-step
self._add_grad(t, d)
@@ -212,7 +404,7 @@
# no use to re-evaluate that function here
loss = float(closure())
flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
opt_cond = flat_grad.abs().max() <= tolerance_grad
ls_func_evals = 1

# update func eval
@@ -228,13 +420,12 @@
if current_evals >= max_eval:
break

if abs_grad_sum <= tolerance_grad:
break

if gtd > -tolerance_change:
# optimal condition
if opt_cond:
break

if d.mul(t).abs_().sum() <= tolerance_change:
# lack of progress
if d.mul(t).abs().max() <= tolerance_change:
break

if abs(loss - prev_loss) < tolerance_change:
@@ -244,6 +435,7 @@
state['t'] = t
state['old_dirs'] = old_dirs
state['old_stps'] = old_stps
state['ro'] = ro
state['H_diag'] = H_diag
state['prev_flat_grad'] = prev_flat_grad
state['prev_loss'] = prev_loss