[pytorch] Add strong Wolfe line search for lbfgs #8824

Closed
wants to merge 9 commits
Changes from 1 commit
add test
fehiepsi committed Jun 23, 2018
commit 441c2c23d700ef8965421601574bfab89802cf2c
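For orientation, the option exercised by this commit's new test case is the `line_search_fn="strong_Wolfe"` argument on `optim.LBFGS`. The sketch below is illustrative only (the quadratic objective, tensor shapes, and iteration count are made up for the example) and assumes the API as it appears in this diff:

```python
# Illustrative sketch (not part of the diff): driving LBFGS with the new
# strong Wolfe line search via a closure, as the optimizer API expects.
import torch
from torch import optim

params = [torch.randn(10, 5, requires_grad=True)]
target = torch.zeros(10, 5)
optimizer = optim.LBFGS(params, lr=1, max_iter=20, line_search_fn="strong_Wolfe")

def closure():
    optimizer.zero_grad()
    loss = ((params[0] - target) ** 2).sum()
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)  # each step may trigger several closure evaluations
```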
8 changes: 6 additions & 2 deletions test/test_optim.py
@@ -467,13 +467,17 @@ def test_lbfgs(self):
wrap_old_fn(old_optim.lbfgs)
)
self._test_rosenbrock(
lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
lambda params: optim.LBFGS(params, lr=1, max_iter=5),
wrap_old_fn(old_optim.lbfgs, learningRate=1, maxIter=5)

)
self._test_basic_cases(
lambda weight, bias: optim.LBFGS([weight, bias]),
ignore_multidevice=True
)
self._test_basic_cases(
lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
ignore_multidevice=True
)

def test_lbfgs_return_type(self):
params = [torch.randn(10, 5), torch.randn(10)]
2 changes: 1 addition & 1 deletion torch/optim/__init__.py
@@ -15,7 +15,7 @@
from .rprop import Rprop
from .rmsprop import RMSprop
from .optimizer import Optimizer
from .lbfgs import LBFGS, strong_Wolfe
from .lbfgs import LBFGS
from . import lr_scheduler

del adadelta
104 changes: 63 additions & 41 deletions torch/optim/lbfgs.py
@@ -3,30 +3,51 @@
from .optimizer import Optimizer


def _interpolate():
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
pass
# Compute bounds of interpolation area
if bounds is not None:
xmin_bound, xmax_bound = bounds
else:
xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

# Code for most common case: cubic interpolation of 2 points
# w/ function and derivative values for both
# Solution in this case (where x2 is the farthest point):
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square = d1 ** 2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2*d2))
else:
min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2*d2))
return min(max(min_pos, xmin_bound), xmax_bound)
else:
return (xmin_bound + xmax_bound) / 2.


def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25):
"""
"""
def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
max_ls=25):
# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
d_norm = d.abs().max()
g = g.clone()
d_abs_max = d.abs().max()
# evaluate objective and gradient using initial step
f_new, g_new = obj_func(x, t, d)
ls_func_evals = 1
gtd_new = g_new.dot(d)

# bracket an interval containing a point satisfying the Wolfe criteria
bracket, bracket_f, bracket_g = [None, None], [None, None], [None, None]
t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
done = False
ls_iter = 0
while ls_iter < max_ls:
# check conditions
if f_new > (f + c1 * t * gtd) or (ls_iter > 0 and f_new >= f_prev):
if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone()]
@@ -40,7 +61,7 @@ def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change
done = True
break

if gtd_new >= 0 then
if gtd_new >= 0:
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone()]
Expand All @@ -51,7 +72,8 @@ def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change
min_step = t + 0.01 * (t - t_prev)
max_step = t * 10
tmp = t
t = _interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new, min_step, max_step)
t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
bounds=(min_step, max_step))

# next step
t_prev = tmp
@@ -68,20 +90,17 @@ def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change
bracket = [0, t]
bracket_f = [f, f_new]
bracket_g = [g, g_new]
bracket_gtd = [gtd, gtd_new]

# zoom phase: we now have a point satisfying the criteria, or
# a bracket around it. We refine the bracket until we find the
# exact point satisfying the criteria
insuf_progress = False
low_pos = 0
# find high and low points in bracket
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
while not done and ls_iter < max_ls:
# find high and low points in bracket
low_pos, high_pos = (0, 1) if bracket_f[0] < bracket_f[1] else (1, 0)

# compute new trial value
t = _interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
bracket[1], bracket_f[1], bracket_gtd[1])
t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
bracket[1], bracket_f[1], bracket_gtd[1])

# test that we are making sufficient progress
eps = 0.1 * (max(bracket) - min(bracket))
@@ -105,12 +124,13 @@ def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change
gtd_new = g_new.dot(d)
ls_iter += 1

if f_new > (f + c1 * t * gtd) or (f_new >= bracket_f[low_pos]):
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
bracket[high_pos] = t
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone()
bracket_gtd[high_pos] = gtd_new
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
@@ -129,14 +149,14 @@ def strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change
bracket_gtd[low_pos] = gtd_new

# line-search bracket is so small
if abs(bracket[1] - bracket[0]) * d_abs_max < tolerance_change:
if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
break

# return stuff
t = bracket[low_pos]
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos]
return f_new, g_new, t, ls_func_evals
return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
@@ -167,6 +187,7 @@ class LBFGS(Optimizer):
tolerance_change (float): termination tolerance on function
value/parameter changes (default: 1e-9).
history_size (int): update history size (default: 100).
line_search_fn (str): either 'strong_Wolfe' or None (default: None).
"""

def __init__(self, params, lr=1, max_iter=20, max_eval=None,
@@ -191,13 +212,6 @@ def _numel(self):
self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
return self._numel_cache

def _clone_param(self):
return [p.clone() for p in self._params]

def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
p.copy_(pdata)

def _gather_flat_grad(self):
views = []
for p in self._params:
@@ -219,7 +233,14 @@ def _add_grad(self, step_size, update):
offset += numel
assert offset == self._numel()

def _directional_evaluate(self, x, t, d):
def _clone_param(self):
return [p.clone() for p in self._params]

def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
p.data.copy_(pdata)

def _directional_evaluate(self, closure, x, t, d):
self._add_grad(t, d)
loss = float(closure())
flat_grad = self._gather_flat_grad()
@@ -268,6 +289,7 @@ def step(self, closure):
t = state.get('t')
old_dirs = state.get('old_dirs')
old_stps = state.get('old_stps')
ro = state.get('ro')
H_diag = state.get('H_diag')
prev_flat_grad = state.get('prev_flat_grad')
prev_loss = state.get('prev_loss')
@@ -286,7 +308,7 @@ def step(self, closure):
d = flat_grad.neg()
old_dirs = []
old_stps = []
old_ros = []
ro = []
H_diag = 1
else:
# do lbfgs update (update memory)
@@ -299,12 +321,12 @@ def step(self, closure):
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)
old_ros.pop(0)
ro.pop(0)

# store new direction/step
old_dirs.append(y)
old_stps.append(s)
old_ros.append(1. / ys)
ro.append(1. / ys)

# update scale of initial Hessian approximation
H_diag = ys / y.dot(y) # (y*y)
@@ -313,14 +335,10 @@ def step(self, closure):
# multiplied by the gradient
num_old = len(old_dirs)

if 'ro' not in state:
state['ro'] = [None] * history_size
if 'al' not in state:
state['al'] = [None] * history_size
ro = state['ro']
al = state['al']

ro[:num_old] = old_ros

# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old - 1, -1, -1):
@@ -359,11 +377,14 @@ def step(self, closure):
# optional line search: user function
ls_func_evals = 0
if line_search_fn is not None:
x_init = self._clone_param()
# perform line search, using user function
t, loss, flat_grad, ls_func_evals = line_search_fn(self._directional_evaluate,
x_init, t, d,
loss, flat_grad, gtd)
if line_search_fn != "strong_Wolfe":
raise RuntimeError("only 'strong_Wolfe' is supported")
else:
x_init = self._clone_param()
obj_func = lambda x, t, d: self._directional_evaluate(closure, x, t, d)
loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
loss, flat_grad, gtd)
self._add_grad(t, d)
opt_cond = flat_grad.abs().max() <= tolerance_grad
else:
@@ -406,6 +427,7 @@ def step(self, closure):
state['t'] = t
state['old_dirs'] = old_dirs
state['old_stps'] = old_stps
state['ro'] = ro
state['H_diag'] = H_diag
state['prev_flat_grad'] = prev_flat_grad
state['prev_loss'] = prev_loss
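For reference, the closed form used in `_cubic_interpolate` above is the standard minimizer of the cubic Hermite interpolant through `(x1, f1, g1)` and `(x2, f2, g2)`, clamped to the given bounds, with the bracket midpoint as a fallback when `d1**2 - g1*g2` is negative. A small self-contained sketch (plain floats, a standalone re-implementation written only for illustration) that sanity-checks the formula on a cubic whose local minimizer is known:

```python
# Illustrative sketch (not part of the diff): verify the cubic-interpolation
# formula on f(x) = x**3 - 3*x, which has a local minimum at x = 1.
import math

def cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        d2 = math.sqrt(d2_square)
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    # no real minimizer of the cubic model: fall back to the midpoint
    return (xmin_bound + xmax_bound) / 2.

f = lambda x: x ** 3 - 3 * x          # objective
g = lambda x: 3 * x ** 2 - 3          # derivative
x = cubic_interpolate(0.0, f(0.0), g(0.0), 2.0, f(2.0), g(2.0))
print(x)  # 1.0 -- exact here, because f itself is a cubic
```

Since the Hermite data above comes from a cubic, the interpolant reproduces the objective exactly and the formula returns the true minimizer; on general objectives it is only a model step, which is why `_strong_Wolfe` clamps the result to a bracket and re-evaluates the objective before accepting it.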