[pytorch] Add strong Wolfe line search for lbfgs #8824

Closed

wants to merge 9 commits

Changes from 1 commit
add strong Wolfe
fehiepsi committed Jun 23, 2018
commit 2f4a29e23302eefe6d38a9f9c5c62fb6bbba2208
2 changes: 1 addition & 1 deletion torch/optim/__init__.py
@@ -15,7 +15,7 @@
from .rprop import Rprop
from .rmsprop import RMSprop
from .optimizer import Optimizer
-from .lbfgs import LBFGS
+from .lbfgs import LBFGS, strong_Wolfe
from . import lr_scheduler

del adadelta
174 changes: 158 additions & 16 deletions torch/optim/lbfgs.py
@@ -4,20 +4,143 @@


def _interpolate():
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    pass


def _zoom():
    pass

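Since `_interpolate` is still a stub in this commit, here is a hedged sketch of what the port of polyinterp.lua might look like: a bounded cubic interpolation in the style of minFunc. The name `_cubic_interpolate`, the argument order, and the scalar float inputs (with x1 != x2) are assumptions for illustration, not part of this commit:

import math

def _cubic_interpolate(x1, f1, g1, x2, f2, g2, xmin_bound=None, xmax_bound=None):
    # minimizer of the cubic matching f and f' at x1 and x2 (assumed sketch)
    if xmin_bound is None or xmax_bound is None:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        d2 = math.sqrt(d2_square)
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        # clamp the cubic minimizer to the trust interval
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        # cubic has no real critical point: fall back to bisection
        return (xmin_bound + xmax_bound) / 2.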

-def _strong_Wolfe(phi, t, f, g, gtd, tolerance_change=1e-9, c1=1e-4,
-                  c2=0.9, max_ls=25):
+def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25):
    """Line search satisfying the strong Wolfe conditions."""
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    g = g.clone()
    d_abs_max = d.abs().max()
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    bracket, bracket_f, bracket_g = [None, None], [None, None], [None, None]
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 0 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone()]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone()]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
                         min_step, max_step)

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone()
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]
        bracket_gtd = [gtd, gtd_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    low_pos = 0
    while not done and ls_iter < max_ls:
        # find high and low points in bracket
        low_pos, high_pos = (0, 1) if bracket_f[0] < bracket_f[1] else (1, 0)

        # compute new trial value
        t = _interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                         bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone()
            bracket_gtd[high_pos] = gtd_new
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone()
            bracket_gtd[low_pos] = gtd_new

        # stop if the line-search bracket becomes too small
        if abs(bracket[1] - bracket[0]) * d_abs_max < tolerance_change:
            break

    # return the best point found, its gradient, the step size, and eval count
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals
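For reference, the two conditions this routine enforces on the step size t along direction d (with the defaults c1=1e-4, c2=0.9) are the sufficient-decrease (Armijo) condition f(x + t*d) <= f(x) + c1*t*(g·d) and the curvature condition |g(x + t*d)·d| <= c2*|g·d|. Since d is a descent direction, gtd = g·d is negative, so the code's check `abs(gtd_new) <= -c2 * gtd` is exactly the curvature condition.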


class LBFGS(Optimizer):
-    """Implements L-BFGS algorithm. This implementation is inspired by `minFunc
+    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`.

    .. warning::
@@ -68,15 +191,22 @@ def _numel(self):
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _clone_param(self):
        return [p.clone() for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
-                view = p.data.new(p.data.numel()).zero_()
-            elif p.grad.data.is_sparse:
-                view = p.grad.data.to_dense().view(-1)
+                view = p.new(p.numel()).zero_()
+            elif p.grad.is_sparse:
+                view = p.grad.to_dense().view(-1)
            else:
-                view = p.grad.data.view(-1)
+                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

@@ -89,6 +219,13 @@ def _add_grad(self, step_size, update):
            offset += numel
        assert offset == self._numel()

    def _directional_evaluate(self, closure, x, t, d):
        # evaluate f and its flat gradient at x + t * d, then restore params
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
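In effect this helper exposes the one-dimensional function the line search operates on, phi(t) = f(x + t*d) with phi'(t) = g(x + t*d)·d, while leaving the parameters untouched afterwards. A hedged usage sketch, as called from inside step() where x_init = self._clone_param() and the step size 0.5 is illustrative:

loss_t, grad_t = self._directional_evaluate(closure, x_init, 0.5, d)
gtd_t = grad_t.dot(d)  # phi'(0.5)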

    def step(self, closure):
        """Performs a single optimization step.

@@ -120,10 +257,10 @@ def step(self, closure):
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
-        abs_grad_max = flat_grad.abs().max()
+        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
-        if abs_grad_max <= tolerance_grad:
+        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
@@ -222,8 +359,13 @@ def step(self, closure):
            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                x_init = self._clone_param()
                # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                def obj_func(x, t, d):
+                    return self._directional_evaluate(closure, x, t, d)
+
+                loss, flat_grad, t, ls_func_evals = line_search_fn(
+                    obj_func, x_init, t, d, loss, flat_grad, gtd)
+                self._add_grad(t, d)
+                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
@@ -233,7 +375,7 @@ def step(self, closure):
                # no use to re-evaluate that function here
                loss = float(closure())
                flat_grad = self._gather_flat_grad()
-                abs_grad_max = flat_grad.abs().max()
+                opt_cond = flat_grad.abs().max() <= tolerance_grad
                ls_func_evals = 1

            # update func eval
@@ -250,7 +392,7 @@ def step(self, closure):
                break

            # optimal condition
-            if abs_grad_max <= tolerance_grad:
+            if opt_cond:
                break

            # lack of progress
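To round out the picture, a hedged sketch of how this PR's API would be exercised end to end once `_interpolate` is implemented: `strong_Wolfe` is passed as the `line_search_fn` callable that step() invokes above. The toy least-squares problem and seed are illustrative only:

import torch
from torch.optim import LBFGS
from torch.optim.lbfgs import strong_Wolfe

# toy least-squares problem: minimize mean squared residual of A w = b
torch.manual_seed(0)
A, b = torch.randn(32, 10), torch.randn(32)
w = torch.zeros(10, requires_grad=True)

opt = LBFGS([w], lr=1., max_iter=20, line_search_fn=strong_Wolfe)

def closure():
    opt.zero_grad()
    loss = (A.mv(w) - b).pow(2).mean()
    loss.backward()
    return loss

opt.step(closure)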