# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import NamedTuple, Union

import jax
import jax.numpy as jnp
from jax import lax
from jax._src.numpy.util import promote_dtypes_inexact

_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)


def _cubicmin(a, fa, fpa, b, fb, c, fc):
  """Minimizes a cubic fit through (a, fa) with slope fpa at a, (b, fb), (c, fc)."""
  dtype = jnp.result_type(a, fa, fpa, b, fb, c, fc)
  C = fpa
  db = b - a
  dc = c - a
  denom = (db * dc) ** 2 * (db - dc)
  d1 = jnp.array([[dc ** 2, -db ** 2],
                  [-dc ** 3, db ** 3]], dtype=dtype)
  d2 = jnp.array([fb - fa - C * db, fc - fa - C * dc], dtype=dtype)
  A, B = _dot(d1, d2) / denom
  radical = B * B - 3. * A * C
  xmin = a + (-B + jnp.sqrt(radical)) / (3. * A)
  return xmin
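
# How `_cubicmin` works: write the interpolant as p(a + t) = A*t**3 + B*t**2
# + C*t + D with D = fa and C = fpa, so p matches the value and slope at `a`
# by construction. Matching p(b) = fb and p(c) = fc gives the 2x2 linear
# system solved through `d1`, `d2`, and `denom`, and p'(t) = 3*A*t**2 + 2*B*t
# + C = 0 has root t* = (-B + sqrt(B*B - 3*A*C)) / (3*A), hence `xmin = a + t*`.
# A quick sanity check with hypothetical inputs:
#
#   >>> float(_cubicmin(0., 1., -1., 2., 1., 3., 4.))  # doctest: +SKIP
#   1.1196...
#
# When `radical` is negative the result is nan, which the interval checks in
# `_zoom` below reject.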


def _quadmin(a, fa, fpa, b, fb):
  """Minimizes a quadratic fit through (a, fa) with slope fpa at a and (b, fb)."""
  D = fa
  C = fpa
  db = b - a
  B = (fb - D - C * db) / (db ** 2)
  xmin = a - C / (2. * B)
  return xmin
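
# How `_quadmin` works: the quadratic q(a + t) = B*t**2 + C*t + D matches the
# value and slope at `a` (D = fa, C = fpa), and B is chosen so that q(b) = fb.
# Setting q'(t) = 2*B*t + C = 0 gives t* = -C / (2*B), i.e. the returned
# `xmin = a - fpa / (2*B)`. As with `_cubicmin`, a non-convex fit (B <= 0) can
# place the root outside (a, b); `_zoom` guards against this with its interval
# checks.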


def _binary_replace(replace_bit, original_dict, new_dict, keys=None):
  """Replaces values in `original_dict` with those in `new_dict` where
  `replace_bit` is True, entrywise via `jnp.where`."""
  if keys is None:
    keys = new_dict.keys()
  out = dict()
  for key in keys:
    out[key] = jnp.where(replace_bit, new_dict[key], original_dict[key])
  return out
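
# A minimal sketch of how `_binary_replace` drives the branchless state
# updates below (hypothetical values): both branches of a decision are
# computed and `jnp.where` selects per entry, so no `lax.cond` is needed and
# the loop bodies stay GPU/TPU friendly.
#
#   >>> old = dict(a_lo=jnp.float32(0.), phi_lo=jnp.float32(1.))
#   >>> new = dict(a_lo=jnp.float32(.5), phi_lo=jnp.float32(.3))
#   >>> _binary_replace(True, old, new)  # doctest: +SKIP
#   {'a_lo': Array(0.5, dtype=float32), 'phi_lo': Array(0.3, dtype=float32)}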


class _ZoomState(NamedTuple):
  done: Union[bool, jax.Array]
  failed: Union[bool, jax.Array]
  j: Union[int, jax.Array]
  a_lo: Union[float, jax.Array]
  phi_lo: Union[float, jax.Array]
  dphi_lo: Union[float, jax.Array]
  a_hi: Union[float, jax.Array]
  phi_hi: Union[float, jax.Array]
  dphi_hi: Union[float, jax.Array]
  a_rec: Union[float, jax.Array]
  phi_rec: Union[float, jax.Array]
  a_star: Union[float, jax.Array]
  phi_star: Union[float, jax.Array]
  dphi_star: Union[float, jax.Array]
  g_star: Union[float, jax.Array]
  nfev: Union[int, jax.Array]
  ngev: Union[int, jax.Array]


def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
          dphi_lo, a_hi, phi_hi, dphi_hi, g_0, pass_through):
  """
  Implementation of zoom. Algorithm 3.6 from Nocedal and Wright, 'Numerical
  Optimization', 1999, pg. 59-61. Tries cubic, quadratic, and bisection methods
  of zooming.
  """
  state = _ZoomState(
      done=False,
      failed=False,
      j=0,
      a_lo=a_lo,
      phi_lo=phi_lo,
      dphi_lo=dphi_lo,
      a_hi=a_hi,
      phi_hi=phi_hi,
      dphi_hi=dphi_hi,
      a_rec=(a_lo + a_hi) / 2.,
      phi_rec=(phi_lo + phi_hi) / 2.,
      a_star=1.,
      phi_star=phi_lo,
      dphi_star=dphi_lo,
      g_star=g_0,
      nfev=0,
      ngev=0,
  )
  delta1 = 0.2  # safeguard margin for accepting the cubic candidate
  delta2 = 0.1  # safeguard margin for accepting the quadratic candidate

  def body(state):
    # Body of the zoom algorithm. We use boolean arithmetic to avoid using
    # lax.cond so that it works on GPU/TPU.
    dalpha = (state.a_hi - state.a_lo)
    a = jnp.minimum(state.a_hi, state.a_lo)
    b = jnp.maximum(state.a_hi, state.a_lo)
    cchk = delta1 * dalpha
    qchk = delta2 * dalpha

    # This will cause the line search to stop, and since the Wolfe conditions
    # are not satisfied the minimization should stop too.
    threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10)
    state = state._replace(failed=state.failed | (dalpha <= threshold))

    # _cubicmin is sometimes nan, though in that case the bounds check below
    # will fail and we fall through to quadratic or bisection.
    a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
                          state.phi_hi, state.a_rec, state.phi_rec)
    use_cubic = (state.j > 0) & (a_j_cubic > a + cchk) & (a_j_cubic < b - cchk)
    a_j_quad = _quadmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
                        state.phi_hi)
    use_quad = (~use_cubic) & (a_j_quad > a + qchk) & (a_j_quad < b - qchk)
    a_j_bisection = (state.a_lo + state.a_hi) / 2.
    use_bisection = (~use_cubic) & (~use_quad)

    # Prefer the cubic candidate, then the quadratic, then bisection.
    a_j = jnp.where(use_cubic, a_j_cubic, state.a_rec)
    a_j = jnp.where(use_quad, a_j_quad, a_j)
    a_j = jnp.where(use_bisection, a_j_bisection, a_j)

    # TODO(jakevdp): should we use some sort of fixed-point approach here
    # instead?
    phi_j, dphi_j, g_j = restricted_func_and_grad(a_j)
    phi_j = phi_j.astype(state.phi_lo.dtype)
    dphi_j = dphi_j.astype(state.dphi_lo.dtype)
    g_j = g_j.astype(state.g_star.dtype)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    hi_to_j = wolfe_one(a_j, phi_j) | (phi_j >= state.phi_lo)
    star_to_j = wolfe_two(dphi_j) & (~hi_to_j)
    hi_to_lo = (dphi_j * (state.a_hi - state.a_lo) >= 0.) & (~hi_to_j) & (~star_to_j)
    lo_to_j = (~hi_to_j) & (~star_to_j)

    state = state._replace(
        **_binary_replace(
            hi_to_j,
            state._asdict(),
            dict(
                a_hi=a_j,
                phi_hi=phi_j,
                dphi_hi=dphi_j,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )

    # for termination
    state = state._replace(
        done=star_to_j | state.done,
        **_binary_replace(
            star_to_j,
            state._asdict(),
            dict(
                a_star=a_j,
                phi_star=phi_j,
                dphi_star=dphi_j,
                g_star=g_j,
            )
        ),
    )
    state = state._replace(
        **_binary_replace(
            hi_to_lo,
            state._asdict(),
            dict(
                a_hi=state.a_lo,
                phi_hi=state.phi_lo,
                dphi_hi=state.dphi_lo,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )
    state = state._replace(
        **_binary_replace(
            lo_to_j,
            state._asdict(),
            dict(
                a_lo=a_j,
                phi_lo=phi_j,
                dphi_lo=dphi_j,
                a_rec=state.a_lo,
                phi_rec=state.phi_lo,
            ),
        ),
    )
    state = state._replace(j=state.j + 1)
    # Choose a higher cutoff for maxiter than SciPy, as JAX takes longer to
    # find the same value - possibly floating point issues?
    state = state._replace(failed=state.failed | (state.j >= 30))

    return state
  state = lax.while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed),
                         body,
                         state)

  return state


class _LineSearchState(NamedTuple):
  done: Union[bool, jax.Array]
  failed: Union[bool, jax.Array]
  i: Union[int, jax.Array]
  a_i1: Union[float, jax.Array]
  phi_i1: Union[float, jax.Array]
  dphi_i1: Union[float, jax.Array]
  nfev: Union[int, jax.Array]
  ngev: Union[int, jax.Array]
  a_star: Union[float, jax.Array]
  phi_star: Union[float, jax.Array]
  dphi_star: Union[float, jax.Array]
  g_star: jax.Array


class _LineSearchResults(NamedTuple):
  """Results of line search.

  Parameters:
    failed: True if the strong Wolfe criteria were not satisfied
    nit: integer number of iterations
    nfev: integer number of function evaluations
    ngev: integer number of gradient evaluations
    k: integer iteration counter (equal to nit + 1)
    a_k: final step size
    f_k: final function value
    g_k: final gradient value
    status: integer end status
  """
  failed: Union[bool, jax.Array]
  nit: Union[int, jax.Array]
  nfev: Union[int, jax.Array]
  ngev: Union[int, jax.Array]
  k: Union[int, jax.Array]
  a_k: Union[float, jax.Array]
  f_k: jax.Array
  g_k: jax.Array
  status: Union[bool, jax.Array]


def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
                c2=0.9, maxiter=20):
  """Inexact line search that satisfies strong Wolfe conditions.

  Algorithm 3.5 from Nocedal and Wright, 'Numerical Optimization', 1999,
  pg. 59-61.

  Args:
    f: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    xk: initial guess.
    pk: direction to search in. Assumes the direction is a descent direction.
    old_fval, gfk: initial value and gradient of f at position xk, if already
      computed.
    old_old_fval: function value at the point preceding xk; if given, it is
      used to pick the initial trial step as in SciPy, otherwise the search
      starts from a step of 1.
    maxiter: maximum number of iterations to search.
    c1, c2: Wolfe criteria constants, see ref.

  Returns: LineSearchResults
  """
  xk, pk = promote_dtypes_inexact(xk, pk)

  def restricted_func_and_grad(t):
    t = jnp.array(t, dtype=pk.dtype)
    phi, g = jax.value_and_grad(f)(xk + t * pk)
    dphi = jnp.real(_dot(g, pk))
    return phi, dphi, g

  if old_fval is None or gfk is None:
    phi_0, dphi_0, gfk = restricted_func_and_grad(0)
  else:
    phi_0 = old_fval
    dphi_0 = jnp.real(_dot(gfk, pk))

  if old_old_fval is not None:
    candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
    start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
  else:
    start_value = 1

  def wolfe_one(a_i, phi_i):
    # actually negation of W1
    return phi_i > phi_0 + c1 * a_i * dphi_0

  def wolfe_two(dphi_i):
    return jnp.abs(dphi_i) <= -c2 * dphi_0
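
  # For reference, the strong Wolfe conditions on phi(a) = f(xk + a * pk) are
  #   (W1)  phi(a) <= phi(0) + c1 * a * phi'(0)   (sufficient decrease)
  #   (W2)  |phi'(a)| <= c2 * |phi'(0)|           (curvature)
  # with 0 < c1 < c2 < 1. `wolfe_one` tests the negation of (W1), and
  # `wolfe_two` tests (W2) directly (note -c2 * dphi_0 = c2 * |phi'(0)| for a
  # descent direction).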
  state = _LineSearchState(
      done=False,
      failed=False,
      # algorithm begins at 1 as per Nocedal and Wright, however SciPy has a
      # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157
      i=1,
      a_i1=0.,
      phi_i1=phi_0,
      dphi_i1=dphi_0,
      nfev=1 if (old_fval is None or gfk is None) else 0,
      ngev=1 if (old_fval is None or gfk is None) else 0,
      a_star=0.,
      phi_star=phi_0,
      dphi_star=dphi_0,
      g_star=gfk,
  )

  def body(state):
    # No amax in this version; we just double the trial step as SciPy does.
    # Unlike the original algorithm, we choose the next step at the start of
    # this loop.
    a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)

    phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) & (state.i > 1))
    star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1)
    star_to_zoom2 = (dphi_i >= 0.) & (~star_to_zoom1) & (~star_to_i)

    zoom1 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  a_i,
                  phi_i,
                  dphi_i,
                  gfk,
                  ~star_to_zoom1)

    state = state._replace(nfev=state.nfev + zoom1.nfev,
                           ngev=state.ngev + zoom1.ngev)

    zoom2 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  a_i,
                  phi_i,
                  dphi_i,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  gfk,
                  ~star_to_zoom2)

    state = state._replace(nfev=state.nfev + zoom2.nfev,
                           ngev=state.ngev + zoom2.ngev)

    state = state._replace(
        done=star_to_zoom1 | state.done,
        failed=(star_to_zoom1 & zoom1.failed) | state.failed,
        **_binary_replace(
            star_to_zoom1,
            state._asdict(),
            zoom1._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(
        done=star_to_i | state.done,
        **_binary_replace(
            star_to_i,
            state._asdict(),
            dict(
                a_star=a_i,
                phi_star=phi_i,
                dphi_star=dphi_i,
                g_star=g_i,
            ),
        ),
    )
    state = state._replace(
        done=star_to_zoom2 | state.done,
        failed=(star_to_zoom2 & zoom2.failed) | state.failed,
        **_binary_replace(
            star_to_zoom2,
            state._asdict(),
            zoom2._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(i=state.i + 1, a_i1=a_i, phi_i1=phi_i, dphi_i1=dphi_i)

    return state
  state = lax.while_loop(lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
                         body,
                         state)

  status = jnp.where(
      state.failed,
      jnp.array(1),  # zoom failed
      jnp.where(
          state.i > maxiter,
          jnp.array(3),  # maxiter reached
          jnp.array(0),  # passed (should be)
      ),
  )
  # Step sizes which are too small cause the optimizer to get stuck with a
  # direction of zero in <64 bit mode - avoid this with a floor on the minimum
  # step size.
  alpha_k = state.a_star
  alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
                      & (jnp.abs(alpha_k) < 1e-8),
                      jnp.sign(alpha_k) * 1e-8,
                      alpha_k)
  results = _LineSearchResults(
      failed=state.failed | (~state.done),
      nit=state.i - 1,  # because iterations started at 1
      nfev=state.nfev,
      ngev=state.ngev,
      k=state.i,
      a_k=alpha_k,
      f_k=state.phi_star,
      g_k=state.g_star,
      status=status,
  )
  return results
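

# A minimal usage sketch with a hypothetical objective (the public entry point
# is jax.scipy.optimize.minimize; `line_search` itself is internal):
#
#   >>> import jax
#   >>> import jax.numpy as jnp
#   >>> f = lambda x: jnp.sum((x - 1.) ** 2)   # quadratic bowl, minimum at 1
#   >>> xk = jnp.zeros(2)
#   >>> pk = -jax.grad(f)(xk)                  # steepest-descent direction
#   >>> res = line_search(f, xk, pk)           # doctest: +SKIP
#   >>> bool(res.failed), float(res.a_k)       # doctest: +SKIP
#
# For this objective, phi(a) = f(xk + a * pk) is minimized at a = 0.5, and a
# successful search returns a step satisfying both strong Wolfe conditions.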