[pytorch] Add strong Wolfe line search for lbfgs #8824

Closed
wants to merge 9 commits
Changes from 1 commit
init line search
fehiepsi committed Jun 22, 2018
commit fef459a953f49157357ff3b2e10f6445d91ea09f
44 changes: 32 additions & 12 deletions torch/optim/lbfgs.py
@@ -3,8 +3,22 @@
from .optimizer import Optimizer


def _interpolate():
pass


def _zoom():
pass


def _strong_Wolfe(phi, t, f, g, gtd, tolerance_change=1e-9, c1=1e-4,
c2=0.9, max_ls=25):
pass


class LBFGS(Optimizer):
"""Implements L-BFGS algorithm.
"""Implements L-BFGS algorithm. This implementation is inspired by `minFunc
<https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

.. warning::
This optimizer doesn't support per-parameter options and parameter
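
For reference, the _strong_Wolfe stub above takes the Armijo constant c1 and the curvature constant c2: a trial step size t along the search direction satisfies the strong Wolfe conditions when it yields sufficient decrease, f(t) <= f(0) + c1 * t * f'(0), and its directional derivative is small enough in magnitude, |f'(t)| <= c2 * |f'(0)|. A minimal illustrative predicate (a sketch only, not part of this commit), assuming phi(t) returns the objective value and directional derivative at step t:

def satisfies_strong_wolfe(phi, t, f0, gtd0, c1=1e-4, c2=0.9):
    # f0 and gtd0 are the objective value and directional derivative at t = 0
    f_t, gtd_t = phi(t)
    armijo = f_t <= f0 + c1 * t * gtd0           # sufficient decrease
    curvature = abs(gtd_t) <= c2 * abs(gtd0)     # strong curvature condition
    return armijo and curvature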
@@ -106,9 +120,10 @@ def step(self, closure):
state['func_evals'] += 1

flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
abs_grad_max = flat_grad.abs().max()

if abs_grad_sum <= tolerance_grad:
# optimal condition
if abs_grad_max <= tolerance_grad:
return orig_loss

# tensors cached in state (for tracing)
@@ -134,6 +149,7 @@ def step(self, closure):
d = flat_grad.neg()
old_dirs = []
old_stps = []
old_ros = []
H_diag = 1
else:
# do lbfgs update (update memory)
@@ -146,10 +162,12 @@ def step(self, closure):
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)
old_ros.pop(0)

# store new direction/step
old_dirs.append(y)
old_stps.append(s)
old_ros.append(1. / ys)

# update scale of initial Hessian approximation
H_diag = ys / y.dot(y) # (y*y)
@@ -164,8 +182,7 @@ def step(self, closure):
ro = state['ro']
al = state['al']

for i in range(num_old):
ro[i] = 1. / old_dirs[i].dot(old_stps[i])
ro[:num_old] = old_ros

# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
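
Note that old_ros caches 1. / ys at the time each curvature pair (y, s) is stored, so ro no longer has to be recomputed from old_dirs and old_stps on every call before the two-loop recursion. For context, a self-contained sketch of that standard two-loop recursion using the cached values (illustrative only, not the PR's exact code; old_dirs holds the y vectors, old_stps holds the s vectors, and ro[i] equals 1 / (y_i . s_i)):

def two_loop_direction(flat_grad, old_dirs, old_stps, ro, H_diag):
    # first loop: newest to oldest history entries, building the alpha_i
    num_old = len(old_dirs)
    al = [0.0] * num_old
    q = flat_grad.neg()  # start from -gradient so the result is a descent direction
    for i in range(num_old - 1, -1, -1):
        al[i] = float(old_stps[i].dot(q)) * float(ro[i])
        q.add_(old_dirs[i], alpha=-al[i])
    # scale by the initial Hessian approximation, then run the second loop
    d = q.mul(H_diag)
    for i in range(num_old):
        be_i = float(old_dirs[i].dot(d)) * float(ro[i])
        d.add_(old_stps[i], alpha=al[i] - be_i)
    return d  # approximates -H * gradient, i.e. the L-BFGS search direction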
@@ -191,13 +208,17 @@ def step(self, closure):
############################################################
# reset initial guess for step size
if state['n_iter'] == 1:
t = min(1., 1. / abs_grad_sum) * lr
t = min(1., 1. / flat_grad.abs().sum()) * lr
else:
t = lr

# directional derivative
gtd = flat_grad.dot(d) # g * d

# directional derivative is below tolerance
if gtd > -tolerance_change:
break

# optional line search: user function
ls_func_evals = 0
if line_search_fn is not None:
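
This branch is where the strong Wolfe search would presumably hook in. One plausible way to build the phi argument expected by the _strong_Wolfe stub from pieces already available in step() is sketched below; the factory name and the add_grad / gather_flat_grad callables (standing in for the optimizer's own parameter-update and gradient-gathering helpers) are assumptions for illustration, not part of this commit:

def _make_phi(closure, add_grad, gather_flat_grad, d):
    # phi(step) evaluates the loss and the directional derivative g.dot(d)
    # at params + step * d, then rolls the trial step back
    def phi(step):
        add_grad(step, d)               # move parameters along d
        loss = float(closure())         # closure() re-evaluates loss and gradients
        flat_grad = gather_flat_grad()
        add_grad(-step, d)              # restore the original parameters
        return loss, float(flat_grad.dot(d))
    return phi

Given the stub's signature, a call along the lines of _strong_Wolfe(phi, t, loss, flat_grad, gtd) then looks like the intended usage, though this commit leaves that wiring for later.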
@@ -212,7 +233,7 @@ def step(self, closure):
# no use to re-evaluate that function here
loss = float(closure())
flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
abs_grad_max = flat_grad.abs().max()
ls_func_evals = 1

# update func eval
@@ -228,13 +249,12 @@ def step(self, closure):
if current_evals >= max_eval:
break

if abs_grad_sum <= tolerance_grad:
break

if gtd > -tolerance_change:
# optimal condition
if abs_grad_max <= tolerance_grad:
break

if d.mul(t).abs_().sum() <= tolerance_change:
# lack of progress
if d.mul(t).abs().max() <= tolerance_change:
break

if abs(loss - prev_loss) < tolerance_change: