# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
|
|
from typing import Any, Callable, NamedTuple, Optional, Union
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax._src.scipy.optimize.line_search import line_search

# Dot product pinned to the highest available precision: the curvature
# quantities computed below (<y, s>, <y, y>) are too sensitive for the
# reduced default matmul precision used on some backends.
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)

# Loose alias for jax array-like values; no static array type is declared here.
Array = Any
|
|
|
|
class LBFGSResults(NamedTuple):
  """Results from L-BFGS optimization.

  Parameters:
    converged: True if minimization converged
    failed: True if non-zero status and not converged
    k: integer number of iterations of the main loop (optimisation steps)
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of jacobian evaluations
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged the l2-norm of this tensor should be below the
      tolerance.
    s_history: (maxcor, d) array of the most recent displacement vectors
      s_k = x_{k+1} - x_k, stored FIFO with the newest entry in the last row.
    y_history: (maxcor, d) array of the most recent gradient differences
      y_k = g_{k+1} - g_k, same ordering as `s_history`.
    rho_history: (maxcor,) array of 1 / Re<y_k, s_k> for the stored pairs.
    gamma: scaling applied to the initial inverse-Hessian estimate in the
      two-loop recursion.
    status: integer describing the status:
      0 = nominal  ,  1 = max iters reached  ,  2 = max fun evals reached
      3 = max grad evals reached  ,  4 = insufficient progress (ftol)
      5 = line search failed
    ls_status: integer describing the end status of the last line search
  """
  converged: Union[bool, Array]
  failed: Union[bool, Array]
  k: Union[int, Array]
  nfev: Union[int, Array]
  ngev: Union[int, Array]
  x_k: Array
  f_k: Array
  g_k: Array
  s_history: Array
  y_history: Array
  rho_history: Array
  gamma: Union[float, Array]
  status: Union[int, Array]
  ls_status: Union[int, Array]
|
|
|
|
|
|
def _minimize_lbfgs(
    fun: Callable,
    x0: Array,
    maxiter: Optional[float] = None,
    norm=jnp.inf,
    maxcor: int = 10,
    ftol: float = 2.220446049250313e-09,
    gtol: float = 1e-05,
    maxfun: Optional[float] = None,
    maxgrad: Optional[float] = None,
    maxls: int = 20,
):
  """
  Minimize a function using L-BFGS

  Implements the L-BFGS algorithm from
    Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
  And generalizes to complex variables from
    Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
    "Unconstrained optimization of real functions in complex variables"
    SIAM Journal on Optimization, 22(3), pp.879-898.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real scalar.
      The function should be composed of operations with vjp defined.
    x0: initial guess
    maxiter: maximum number of iterations
    norm: order of norm for convergence check. Default inf.
    maxcor: maximum number of metric corrections ("history size")
    ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
    gtol: terminates the minimization when `|g_k|_norm < gtol`
    maxfun: maximum number of function evaluations
    maxgrad: maximum number of gradient evaluations
    maxls: maximum number of line search steps (per iteration)

  Returns:
    Optimization results (an `LBFGSResults` namedtuple).
  """
  d = len(x0)
  dtype = jnp.dtype(x0)

  # ensure there is at least one termination condition
  if (maxiter is None) and (maxfun is None) and (maxgrad is None):
    maxiter = d * 200

  # set others to inf, such that >= is supported
  if maxiter is None:
    maxiter = jnp.inf
  if maxfun is None:
    maxfun = jnp.inf
  if maxgrad is None:
    maxgrad = jnp.inf

  # initial evaluation
  f_0, g_0 = jax.value_and_grad(fun)(x0)
  state_initial = LBFGSResults(
      converged=False,
      failed=False,
      k=0,
      nfev=1,
      ngev=1,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      # Zero-filled FIFO history buffers; rows are filled newest-last as the
      # iteration proceeds (see _update_history_vectors).
      s_history=jnp.zeros((maxcor, d), dtype=dtype),
      y_history=jnp.zeros((maxcor, d), dtype=dtype),
      rho_history=jnp.zeros((maxcor,), dtype=dtype),
      gamma=1.,
      status=0,
      ls_status=0,
  )

  def cond_fun(state: LBFGSResults):
    # Iterate until either convergence or any failure/limit status is set.
    return (~state.converged) & (~state.failed)

  def body_fun(state: LBFGSResults):
    # find search direction
    p_k = _two_loop_recursion(state)

    # line search
    ls_results = line_search(
        f=fun,
        xk=state.x_k,
        pk=p_k,
        old_fval=state.f_k,
        gfk=state.g_k,
        maxiter=maxls,
    )

    # evaluate at next iterate
    s_k = ls_results.a_k.astype(p_k.dtype) * p_k  # accepted step a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = ls_results.f_k
    g_kp1 = ls_results.g_k
    y_k = g_kp1 - state.g_k
    # Curvature <y_k, s_k>; the real part keeps it real-valued when the
    # variables are complex.
    rho_k_inv = jnp.real(_dot(y_k, s_k))
    rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
    # Scaling for the initial inverse-Hessian guess: <y, s> / <y, y>.
    gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

    # replacements for next iteration
    # NOTE: later jnp.where calls take precedence, so a failed line search (5)
    # dominates, then the iteration/eval limits (1-3), then the ftol test (4).
    status = jnp.array(0)
    status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
    status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
    status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
    status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
    status = jnp.where(ls_results.failed, 5, status)

    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    state = state._replace(
        converged=converged,
        failed=(status > 0) & (~converged),
        k=state.k + 1,
        nfev=state.nfev + ls_results.nfev,
        ngev=state.ngev + ls_results.ngev,
        # Casts keep the loop carry dtypes fixed, as lax.while_loop requires.
        x_k=x_kp1.astype(state.x_k.dtype),
        f_k=f_kp1.astype(state.f_k.dtype),
        g_k=g_kp1.astype(state.g_k.dtype),
        s_history=_update_history_vectors(history=state.s_history, new=s_k),
        y_history=_update_history_vectors(history=state.y_history, new=y_k),
        rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
        gamma=gamma.astype(state.g_k.dtype),
        status=jnp.where(converged, 0, status),
        ls_status=ls_results.status,
    )

    return state

  return lax.while_loop(cond_fun, body_fun, state_initial)
|
|
|
|
|
|
def _two_loop_recursion(state: LBFGSResults):
|
|
dtype = state.rho_history.dtype
|
|
his_size = len(state.rho_history)
|
|
curr_size = jnp.where(state.k < his_size, state.k, his_size)
|
|
q = -jnp.conj(state.g_k)
|
|
a_his = jnp.zeros_like(state.rho_history)
|
|
|
|
def body_fun1(j, carry):
|
|
i = his_size - 1 - j
|
|
_q, _a_his = carry
|
|
a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
|
|
_a_his = _a_his.at[i].set(a_i)
|
|
_q = _q - a_i * jnp.conj(state.y_history[i])
|
|
return _q, _a_his
|
|
|
|
q, a_his = lax.fori_loop(0, curr_size, body_fun1, (q, a_his))
|
|
q = state.gamma * q
|
|
|
|
def body_fun2(j, _q):
|
|
i = his_size - curr_size + j
|
|
b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
|
|
_q = _q + (a_his[i] - b_i) * state.s_history[i]
|
|
return _q
|
|
|
|
q = lax.fori_loop(0, curr_size, body_fun2, q)
|
|
return q
|
|
|
|
|
|
def _update_history_vectors(history, new):
|
|
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
|
return jnp.roll(history, -1, axis=0).at[-1, :].set(new)
|
|
|
|
|
|
def _update_history_scalars(history, new):
|
|
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
|
return jnp.roll(history, -1, axis=0).at[-1].set(new)
|