# jax/_src/scipy/optimize/_lbfgs.py
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""

from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
from jax import lax
from jax._src.scipy.optimize.line_search import line_search

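# dot products at the highest available precision, so the curvature
# quantities below are not degraded by low-precision matmul defaults
# on accelerators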
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)

Array = Any


class LBFGSResults(NamedTuple):
"""Results from L-BFGS optimization
Parameters:
converged: True if minimization converged
failed: True if non-zero status and not converged
k: integer number of iterations of the main loop (optimisation steps)
nfev: integer total number of objective evaluations performed.
ngev: integer total number of jacobian evaluations
x_k: array containing the last argument value found during the search. If
the search converged, then this value is the argmin of the objective
function.
f_k: array containing the value of the objective function at `x_k`. If the
search converged, then this is the (local) minimum of the objective
function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged, the norm of this array (of the order given by the
      `norm` argument) should be below `gtol`.
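    s_history: array of shape `(maxcor, d)`, where d is the problem dimension,
      containing the most recent displacements `s_k = x_{k+1} - x_k`, one per row
    y_history: array of shape `(maxcor, d)` containing the most recent gradient
      differences `y_k = g_{k+1} - g_k`, one per row
    rho_history: array of the corresponding curvature terms
      `rho_k = 1 / (y_k^T s_k)`
    gamma: scaling of the initial inverse-Hessian approximation `H_k^0 = gamma * I`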
    status: integer describing the status:
      0 = nominal,
      1 = max iterations reached,
      2 = max function evaluations reached,
      3 = max gradient evaluations reached,
      4 = insufficient progress (ftol),
      5 = line search failed
    ls_status: integer describing the end status of the last line search
"""
converged: Union[bool, Array]
failed: Union[bool, Array]
k: Union[int, Array]
nfev: Union[int, Array]
ngev: Union[int, Array]
x_k: Array
f_k: Array
g_k: Array
s_history: Array
y_history: Array
rho_history: Array
gamma: Union[float, Array]
status: Union[int, Array]
  ls_status: Union[int, Array]


def _minimize_lbfgs(
fun: Callable,
x0: Array,
maxiter: Optional[float] = None,
norm=jnp.inf,
maxcor: int = 10,
ftol: float = 2.220446049250313e-09,
gtol: float = 1e-05,
maxfun: Optional[float] = None,
maxgrad: Optional[float] = None,
maxls: int = 20,
):
"""
Minimize a function using L-BFGS
Implements the L-BFGS algorithm from
Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
And generalizes to complex variables from
Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
"Unconstrained optimization of real functions in complex variables"
SIAM Journal on Optimization, 22(3), pp.879-898.
Args:
    fun: function of the form f(x), where x is a flat ndarray, returning a real
      scalar. The function should be composed of operations with a defined VJP
      (vector-Jacobian product).
x0: initial guess
maxiter: maximum number of iterations
norm: order of norm for convergence check. Default inf.
maxcor: maximum number of metric corrections ("history size")
ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
gtol: terminates the minimization when `|g_k|_norm < gtol`
maxfun: maximum number of function evaluations
maxgrad: maximum number of gradient evaluations
    maxls: maximum number of line search steps (per iteration)

  Returns:
Optimization results.
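
  Example (an illustrative sketch; the quadratic objective is a stand-in for
  any smooth real-valued function)::

    fun = lambda x: jnp.sum((x - 2.) ** 2)
    results = _minimize_lbfgs(fun, jnp.zeros(3))
    # on convergence, results.x_k is approximately [2., 2., 2.]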
"""
d = len(x0)
dtype = jnp.dtype(x0)
# ensure there is at least one termination condition
if (maxiter is None) and (maxfun is None) and (maxgrad is None):
maxiter = d * 200
# set others to inf, such that >= is supported
if maxiter is None:
maxiter = jnp.inf
if maxfun is None:
maxfun = jnp.inf
if maxgrad is None:
maxgrad = jnp.inf
# initial evaluation
f_0, g_0 = jax.value_and_grad(fun)(x0)
state_initial = LBFGSResults(
converged=False,
failed=False,
k=0,
nfev=1,
ngev=1,
x_k=x0,
f_k=f_0,
g_k=g_0,
s_history=jnp.zeros((maxcor, d), dtype=dtype),
y_history=jnp.zeros((maxcor, d), dtype=dtype),
rho_history=jnp.zeros((maxcor,), dtype=dtype),
gamma=1.,
status=0,
ls_status=0,
)
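
  # Main optimization loop: cond_fun stops once the run has converged or
  # failed; body_fun performs one L-BFGS step (search direction, line search,
  # history and state update).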
def cond_fun(state: LBFGSResults):
    return (~state.converged) & (~state.failed)

  def body_fun(state: LBFGSResults):
# find search direction
p_k = _two_loop_recursion(state)
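    # the two-loop recursion implicitly applies the inverse-Hessian
    # approximation to the (conjugate) negative gradient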
# line search
ls_results = line_search(
f=fun,
xk=state.x_k,
pk=p_k,
old_fval=state.f_k,
gfk=state.g_k,
maxiter=maxls,
)
# evaluate at next iterate
s_k = ls_results.a_k.astype(p_k.dtype) * p_k
x_kp1 = state.x_k + s_k
f_kp1 = ls_results.f_k
g_kp1 = ls_results.g_k
y_k = g_kp1 - state.g_k
rho_k_inv = jnp.real(_dot(y_k, s_k))
rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))
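    # gamma = s_k^T y_k / y_k^T y_k scales the initial inverse-Hessian
    # approximation H_k^0 = gamma * I (Nocedal and Wright, eq. 7.20)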
# replacements for next iteration
status = jnp.array(0)
status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
status = jnp.where(state.ngev >= maxgrad, 3, status) # type: ignore
status = jnp.where(state.nfev >= maxfun, 2, status) # type: ignore
status = jnp.where(state.k >= maxiter, 1, status) # type: ignore
status = jnp.where(ls_results.failed, 5, status)
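    # later writes take precedence: a failed line search (5) overrides the
    # iteration/evaluation budget checks, which override the ftol check (4)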
converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol
# TODO(jakevdp): use a fixed-point procedure rather than type-casting?
state = state._replace(
converged=converged,
failed=(status > 0) & (~converged),
k=state.k + 1,
nfev=state.nfev + ls_results.nfev,
ngev=state.ngev + ls_results.ngev,
x_k=x_kp1.astype(state.x_k.dtype),
f_k=f_kp1.astype(state.f_k.dtype),
g_k=g_kp1.astype(state.g_k.dtype),
s_history=_update_history_vectors(history=state.s_history, new=s_k),
y_history=_update_history_vectors(history=state.y_history, new=y_k),
rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
gamma=gamma.astype(state.g_k.dtype),
status=jnp.where(converged, 0, status),
ls_status=ls_results.status,
)
    return state

  return lax.while_loop(cond_fun, body_fun, state_initial)


def _two_loop_recursion(state: LBFGSResults):
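  # L-BFGS two-loop recursion (Nocedal and Wright, Algorithm 7.4): builds the
  # search direction by applying the implicit inverse-Hessian approximation,
  # reconstructed from the stored (s, y, rho) history, to -conj(g_k).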
dtype = state.rho_history.dtype
his_size = len(state.rho_history)
curr_size = jnp.where(state.k < his_size, state.k, his_size)
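  # only the most recent curr_size slots hold valid (s, y, rho) pairs; during
  # the first iterations the older rows are still the zero initializer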
q = -jnp.conj(state.g_k)
a_his = jnp.zeros_like(state.rho_history)
def body_fun1(j, carry):
i = his_size - 1 - j
_q, _a_his = carry
a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
_a_his = _a_his.at[i].set(a_i)
_q = _q - a_i * jnp.conj(state.y_history[i])
return _q, _a_his
q, a_his = lax.fori_loop(0, curr_size, body_fun1, (q, a_his))
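  # between the two loops, scale by the initial inverse-Hessian approximation
  # H_k^0 = gamma * I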
q = state.gamma * q
def body_fun2(j, _q):
i = his_size - curr_size + j
b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
_q = _q + (a_his[i] - b_i) * state.s_history[i]
return _q
q = lax.fori_loop(0, curr_size, body_fun2, q)
  return q


def _update_history_vectors(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
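  # shift the buffer up by one row and write the newest vector into the last row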
  return jnp.roll(history, -1, axis=0).at[-1, :].set(new)


def _update_history_scalars(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
return jnp.roll(history, -1, axis=0).at[-1].set(new)
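

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library module): minimize
# the Rosenbrock function by calling the private driver directly. The
# objective below is a stand-in; any smooth real-valued function of a flat
# array would do.
if __name__ == "__main__":
  def rosenbrock(x):
    # classic non-convex test function with global minimum at x = (1, ..., 1)
    return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

  results = _minimize_lbfgs(rosenbrock, jnp.zeros(2))
  print("converged:", bool(results.converged))
  print("iterations:", int(results.k))
  print("x_k:", results.x_k)  # expect a point near [1., 1.]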