import pytest
import numpy as np
from numpy.testing import assert_allclose, assert_equal, assert_array_less

from scipy import stats
import scipy._lib._elementwise_iterative_method as eim

from scipy.optimize._chandrupatla import (_chandrupatla_minimize,
                                           _chandrupatla as _chandrupatla_root)
from scipy.optimize._tstutils import _CHANDRUPATLA_TESTS

from itertools import permutations
from .test_zeros import TestScalarRootFinders


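# Objective functions f1-f5 and the `cases` table further below are the test
# problems used by `TestChandrupatlaMinimize.test_nit_expected`, which checks
# iteration counts against those reported in Chandrupatla's original paper.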
def f1(x):
    return 100*(1 - x**3.)**2 + (1-x**2.) + 2*(1-x)**2.


def f2(x):
    return 5 + (x - 2.)**6


def f3(x):
    return np.exp(x) - 5*x


def f4(x):
    return x**5. - 5*x**3. - 20.*x + 5.


def f5(x):
    return 8*x**3 - 2*x**2 - 7*x + 3


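# Helper used by `test_nit_expected`: expand the initial interval by the
# golden ratio until the function stops decreasing, returning a three-point
# bracket (x1, x2, x3) with f2 <= f1 and f2 <= f3.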
def _bracket_minimum(func, x1, x2):
    phi = 1.61803398875
    maxiter = 100
    f1 = func(x1)
    f2 = func(x2)
    step = x2 - x1
    x1, x2, f1, f2, step = ((x2, x1, f2, f1, -step) if f2 > f1
                            else (x1, x2, f1, f2, step))

    for i in range(maxiter):
        step *= phi
        x3 = x2 + step
        f3 = func(x3)
        if f3 < f2:
            x1, x2, f1, f2 = x2, x3, f2, f3
        else:
            break
    return x1, x2, x3, f1, f2, f3


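# Each case is a tuple (objective function, starting point x1, expected number
# of iterations). `test_nit_expected` brackets the minimum starting from
# (x1, x1 + 0.2) and checks that `_chandrupatla_minimize` performs the
# expected number of iterations.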
cases = [
    (f1, -1, 11),
    (f1, -2, 13),
    (f1, -4, 13),
    (f1, -8, 15),
    (f1, -16, 16),
    (f1, -32, 19),
    (f1, -64, 20),
    (f1, -128, 21),
    (f1, -256, 21),
    (f1, -512, 19),
    (f1, -1024, 24),
    (f2, -1, 8),
    (f2, -2, 6),
    (f2, -4, 6),
    (f2, -8, 7),
    (f2, -16, 8),
    (f2, -32, 8),
    (f2, -64, 9),
    (f2, -128, 11),
    (f2, -256, 13),
    (f2, -512, 12),
    (f2, -1024, 13),
    (f3, -1, 11),
    (f3, -2, 11),
    (f3, -4, 11),
    (f3, -8, 10),
    (f3, -16, 14),
    (f3, -32, 12),
    (f3, -64, 15),
    (f3, -128, 18),
    (f3, -256, 18),
    (f3, -512, 19),
    (f3, -1024, 19),
    (f4, -0.05, 9),
    (f4, -0.10, 11),
    (f4, -0.15, 11),
    (f4, -0.20, 11),
    (f4, -0.25, 11),
    (f4, -0.30, 9),
    (f4, -0.35, 9),
    (f4, -0.40, 9),
    (f4, -0.45, 10),
    (f4, -0.50, 10),
    (f4, -0.55, 10),
    (f5, -0.05, 6),
    (f5, -0.10, 7),
    (f5, -0.15, 8),
    (f5, -0.20, 10),
    (f5, -0.25, 9),
    (f5, -0.30, 8),
    (f5, -0.35, 7),
    (f5, -0.40, 7),
    (f5, -0.45, 9),
    (f5, -0.50, 9),
    (f5, -0.55, 8)
]


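# Tests for `_chandrupatla_minimize`, the vectorized, bracketed scalar
# minimizer.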
class TestChandrupatlaMinimize:

    def f(self, x, loc):
        dist = stats.norm()
        return -dist.pdf(x - loc)

    @pytest.mark.parametrize('loc', [0.6, np.linspace(-1.05, 1.05, 10)])
    def test_basic(self, loc):
        # Find mode of normal distribution. Compare mode against location
        # parameter and value of pdf at mode against expected pdf.
        res = _chandrupatla_minimize(self.f, -5, 0, 5, args=(loc,))
        ref = loc
        np.testing.assert_allclose(res.x, ref, rtol=1e-6)
        np.testing.assert_allclose(res.fun, -stats.norm.pdf(0), atol=0, rtol=0)
        assert res.x.shape == np.shape(ref)

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        loc = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
        args = (loc,)

        @np.vectorize
        def chandrupatla_single(loc_single):
            return _chandrupatla_minimize(self.f, -5, 0, 5, args=(loc_single,))

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        res = _chandrupatla_minimize(f, -5, 0, 5, args=args)
        refs = chandrupatla_single(loc).ravel()

        ref_x = [ref.x for ref in refs]
        assert_allclose(res.x.ravel(), ref_x)
        assert_equal(res.x.shape, shape)

        ref_fun = [ref.fun for ref in refs]
        assert_allclose(res.fun.ravel(), ref_fun)
        assert_equal(res.fun.shape, shape)
        assert_equal(res.fun, self.f(res.x, *args))

        ref_success = [ref.success for ref in refs]
        assert_equal(res.success.ravel(), ref_success)
        assert_equal(res.success.shape, shape)
        assert np.issubdtype(res.success.dtype, np.bool_)

        ref_flag = [ref.status for ref in refs]
        assert_equal(res.status.ravel(), ref_flag)
        assert_equal(res.status.shape, shape)
        assert np.issubdtype(res.status.dtype, np.integer)

        ref_nfev = [ref.nfev for ref in refs]
        assert_equal(res.nfev.ravel(), ref_nfev)
        assert_equal(np.max(res.nfev), f.f_evals)
        assert_equal(res.nfev.shape, res.fun.shape)
        assert np.issubdtype(res.nfev.dtype, np.integer)

        ref_nit = [ref.nit for ref in refs]
        assert_equal(res.nit.ravel(), ref_nit)
        assert_equal(np.max(res.nit), f.f_evals-3)
        assert_equal(res.nit.shape, res.fun.shape)
        assert np.issubdtype(res.nit.dtype, np.integer)

        ref_xl = [ref.xl for ref in refs]
        assert_allclose(res.xl.ravel(), ref_xl)
        assert_equal(res.xl.shape, shape)

        ref_xm = [ref.xm for ref in refs]
        assert_allclose(res.xm.ravel(), ref_xm)
        assert_equal(res.xm.shape, shape)

        ref_xr = [ref.xr for ref in refs]
        assert_allclose(res.xr.ravel(), ref_xr)
        assert_equal(res.xr.shape, shape)

        ref_fl = [ref.fl for ref in refs]
        assert_allclose(res.fl.ravel(), ref_fl)
        assert_equal(res.fl.shape, shape)
        assert_allclose(res.fl, self.f(res.xl, *args))

        ref_fm = [ref.fm for ref in refs]
        assert_allclose(res.fm.ravel(), ref_fm)
        assert_equal(res.fm.shape, shape)
        assert_allclose(res.fm, self.f(res.xm, *args))

        ref_fr = [ref.fr for ref in refs]
        assert_allclose(res.fr.ravel(), ref_fr)
        assert_equal(res.fr.shape, shape)
        assert_allclose(res.fr, self.f(res.xr, *args))

    def test_flags(self):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        def f(xs, js):
            funcs = [lambda x: (x - 2.5) ** 2,
                     lambda x: x - 10,
                     lambda x: (x - 2.5) ** 4,
                     lambda x: np.nan]

            return [funcs[j](x) for x, j in zip(xs, js)]

        args = (np.arange(4, dtype=np.int64),)

        res = _chandrupatla_minimize(f, [0]*4, [2]*4, [np.pi]*4, args=args,
                                     maxiter=10)

        ref_flags = np.array([eim._ECONVERGED,
                              eim._ESIGNERR,
                              eim._ECONVERR,
                              eim._EVALUEERR])
        assert_equal(res.status, ref_flags)

    def test_convergence(self):
        # Test that the convergence tolerances behave as expected
        rng = np.random.default_rng(2585255913088665241)
        p = rng.random(size=3)
        bracket = (-5, 0, 5)
        args = (p,)
        kwargs0 = dict(args=args, xatol=0, xrtol=0, fatol=0, frtol=0)

        kwargs = kwargs0.copy()
        kwargs['xatol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j1 = abs(res1.xr - res1.xl)
        assert_array_less(j1, 4*kwargs['xatol'])
        kwargs['xatol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j2 = abs(res2.xr - res2.xl)
        assert_array_less(j2, 4*kwargs['xatol'])
        assert_array_less(j2, j1)

        kwargs = kwargs0.copy()
        kwargs['xrtol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j1 = abs(res1.xr - res1.xl)
        assert_array_less(j1, 4*kwargs['xrtol']*abs(res1.x))
        kwargs['xrtol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        j2 = abs(res2.xr - res2.xl)
        assert_array_less(j2, 4*kwargs['xrtol']*abs(res2.x))
        assert_array_less(j2, j1)

        kwargs = kwargs0.copy()
        kwargs['fatol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h1 = abs(res1.fl - 2 * res1.fm + res1.fr)
        assert_array_less(h1, 2*kwargs['fatol'])
        kwargs['fatol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h2 = abs(res2.fl - 2 * res2.fm + res2.fr)
        assert_array_less(h2, 2*kwargs['fatol'])
        assert_array_less(h2, h1)

        kwargs = kwargs0.copy()
        kwargs['frtol'] = 1e-3
        res1 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h1 = abs(res1.fl - 2 * res1.fm + res1.fr)
        assert_array_less(h1, 2*kwargs['frtol']*abs(res1.fun))
        kwargs['frtol'] = 1e-6
        res2 = _chandrupatla_minimize(self.f, *bracket, **kwargs)
        h2 = abs(res2.fl - 2 * res2.fm + res2.fr)
        assert_array_less(h2, 2*kwargs['frtol']*abs(res2.fun))
        assert_array_less(h2, h1)

    def test_maxiter_callback(self):
        # Test behavior of `maxiter` parameter and `callback` interface
        loc = 0.612814
        bracket = (-5, 0, 5)
        maxiter = 5

        res = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
                                     maxiter=maxiter)
        assert not np.any(res.success)
        assert np.all(res.nfev == maxiter+3)
        assert np.all(res.nit == maxiter)

        def callback(res):
            callback.iter += 1
            callback.res = res
            assert hasattr(res, 'x')
            if callback.iter == 0:
                # callback is called once with initial bracket
                assert (res.xl, res.xm, res.xr) == bracket
            else:
                changed_xr = (res.xl == callback.xl) & (res.xr != callback.xr)
                changed_xl = (res.xl != callback.xl) & (res.xr == callback.xr)
                assert np.all(changed_xr | changed_xl)

            callback.xl = res.xl
            callback.xr = res.xr
            assert res.status == eim._EINPROGRESS
            assert_equal(self.f(res.xl, loc), res.fl)
            assert_equal(self.f(res.xm, loc), res.fm)
            assert_equal(self.f(res.xr, loc), res.fr)
            assert_equal(self.f(res.x, loc), res.fun)
            if callback.iter == maxiter:
                raise StopIteration

        callback.xl = np.nan
        callback.xr = np.nan
        callback.iter = -1  # callback called once before first iteration
        callback.res = None

        res2 = _chandrupatla_minimize(self.f, *bracket, args=(loc,),
                                      callback=callback)

        # terminating with callback is identical to terminating due to maxiter
        # (except for `status`)
        for key in res.keys():
            if key == 'status':
                assert res[key] == eim._ECONVERR
                assert callback.res[key] == eim._EINPROGRESS
                assert res2[key] == eim._ECALLBACK
            else:
                assert res2[key] == callback.res[key] == res[key]

    @pytest.mark.parametrize('case', cases)
    def test_nit_expected(self, case):
        # Test that `_chandrupatla` implements Chandrupatla's algorithm:
        # in all 55 test cases, the number of iterations performed
        # matches the number reported in the original paper.
        func, x1, nit = case

        # Find bracket using the algorithm in the paper
        step = 0.2
        x2 = x1 + step
        x1, x2, x3, f1, f2, f3 = _bracket_minimum(func, x1, x2)

        # Use tolerances from original paper
        xatol = 0.0001
        fatol = 0.000001
        xrtol = 1e-16
        frtol = 1e-16

        res = _chandrupatla_minimize(func, x1, x2, x3, xatol=xatol,
                                     fatol=fatol, xrtol=xrtol, frtol=frtol)
        assert_equal(res.nit, nit)

    @pytest.mark.parametrize("loc", (0.65, [0.65, 0.7]))
    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
    def test_dtype(self, loc, dtype):
        # Test that dtypes are preserved

        loc = dtype(loc)

        def f(x, loc):
            assert x.dtype == dtype
            return ((x - loc) ** 2).astype(dtype)

        res = _chandrupatla_minimize(f, dtype(-3), dtype(1), dtype(5),
                                     args=(loc,))
        assert res.x.dtype == dtype
        assert_allclose(res.x, loc, rtol=np.sqrt(np.finfo(dtype).eps))

    def test_input_validation(self):
        # Test input validation for appropriate error messages

        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(None, -4, 0, 4)

        message = 'Abscissae and function output must be real numbers.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4+1j, 0, 4)

        message = "shape mismatch: objects cannot be broadcast"
        # raised by `np.broadcast`, but the traceback is readable IMO
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, [-2, -3], [0, 0], [3, 4, 5])

        message = "The shape of the array returned by `func` must be the same"
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: [x[0], x[1], x[1]], [-3, -3],
                                   [0, 0], [5, 5])

        message = 'Tolerances must be non-negative scalars.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, xatol=-1)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, xrtol=np.nan)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, fatol='ekki')
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, frtol=np.nan)

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, maxiter=-1)

        message = '`callback` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_minimize(lambda x: x, -4, 0, 4, callback='shrubbery')

    def test_bracket_order(self):
        # Confirm that order of points in bracket doesn't matter
        loc = np.linspace(-1, 1, 6)[:, np.newaxis]
        brackets = np.array(list(permutations([-5, 0, 5]))).T
        res = _chandrupatla_minimize(self.f, *brackets, args=(loc,))
        assert np.all(np.isclose(res.x, loc) | (res.fun == self.f(loc, loc)))
        ref = res.x[:, 0]  # all columns should be the same
        assert_allclose(*np.broadcast_arrays(res.x.T, ref), rtol=1e-15)

    def test_special_cases(self):
        # Test edge cases and other special cases

        # Test that integers are not passed to `f`
        # (otherwise this would overflow)
        def f(x):
            assert np.issubdtype(x.dtype, np.floating)
            return (x-1) ** 100

        with np.errstate(invalid='ignore'):
            res = _chandrupatla_minimize(f, -7, 0, 8, fatol=0, frtol=0)
        assert res.success
        assert_allclose(res.x, 1, rtol=1e-3)
        assert_equal(res.fun, 0)

        # Test that if all elements of bracket equal minimizer, algorithm
        # reports convergence
        def f(x):
            return (x-1)**2

        res = _chandrupatla_minimize(f, 1, 1, 1)
        assert res.success
        assert_equal(res.x, 1)

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return (x-1)**2

        bracket = (-3, 1.1, 5)
        res = _chandrupatla_minimize(f, *bracket, maxiter=0)
        assert (res.xl, res.xm, res.xr) == bracket
        assert res.nit == 0
        assert res.nfev == 3
        assert res.status == -2
        assert res.x == 1.1  # best so far

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return (x-c)**2 - 1

        res = _chandrupatla_minimize(f, -1, 0, 1, args=1/3)
        assert_allclose(res.x, 1/3)

        # Test zero tolerances
        # TODO: fatol/frtol = 0?
        def f(x):
            return -np.sin(x)

        res = _chandrupatla_minimize(f, 0, 1, np.pi, xatol=0, xrtol=0,
                                     fatol=0, frtol=0)
        assert res.success
        # found a minimum exactly (according to floating point arithmetic)
        assert res.xl < res.xm < res.xr
        assert f(res.xl) == f(res.xm) == f(res.xr)


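# Tests for `_chandrupatla` (imported as `_chandrupatla_root`), the
# vectorized, bracketed scalar root finder.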
class TestChandrupatla(TestScalarRootFinders):

    def f(self, q, p):
        return stats.norm.cdf(q) - p

    @pytest.mark.parametrize('p', [0.6, np.linspace(-0.05, 1.05, 10)])
    def test_basic(self, p):
        # Invert distribution CDF and compare against distribution `ppf`
        res = _chandrupatla_root(self.f, -5, 5, args=(p,))
        ref = stats.norm().ppf(p)
        np.testing.assert_allclose(res.x, ref)
        assert res.x.shape == ref.shape

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
        args = (p,)

        @np.vectorize
        def chandrupatla_single(p):
            return _chandrupatla_root(self.f, -5, 5, args=(p,))

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        res = _chandrupatla_root(f, -5, 5, args=args)
        refs = chandrupatla_single(p).ravel()

        ref_x = [ref.x for ref in refs]
        assert_allclose(res.x.ravel(), ref_x)
        assert_equal(res.x.shape, shape)

        ref_fun = [ref.fun for ref in refs]
        assert_allclose(res.fun.ravel(), ref_fun)
        assert_equal(res.fun.shape, shape)
        assert_equal(res.fun, self.f(res.x, *args))

        ref_success = [ref.success for ref in refs]
        assert_equal(res.success.ravel(), ref_success)
        assert_equal(res.success.shape, shape)
        assert np.issubdtype(res.success.dtype, np.bool_)

        ref_flag = [ref.status for ref in refs]
        assert_equal(res.status.ravel(), ref_flag)
        assert_equal(res.status.shape, shape)
        assert np.issubdtype(res.status.dtype, np.integer)

        ref_nfev = [ref.nfev for ref in refs]
        assert_equal(res.nfev.ravel(), ref_nfev)
        assert_equal(np.max(res.nfev), f.f_evals)
        assert_equal(res.nfev.shape, res.fun.shape)
        assert np.issubdtype(res.nfev.dtype, np.integer)

        ref_nit = [ref.nit for ref in refs]
        assert_equal(res.nit.ravel(), ref_nit)
        assert_equal(np.max(res.nit), f.f_evals-2)
        assert_equal(res.nit.shape, res.fun.shape)
        assert np.issubdtype(res.nit.dtype, np.integer)

        ref_xl = [ref.xl for ref in refs]
        assert_allclose(res.xl.ravel(), ref_xl)
        assert_equal(res.xl.shape, shape)

        ref_xr = [ref.xr for ref in refs]
        assert_allclose(res.xr.ravel(), ref_xr)
        assert_equal(res.xr.shape, shape)

        assert_array_less(res.xl, res.xr)
        finite = np.isfinite(res.x)
        assert np.all((res.x[finite] == res.xl[finite])
                      | (res.x[finite] == res.xr[finite]))

        ref_fl = [ref.fl for ref in refs]
        assert_allclose(res.fl.ravel(), ref_fl)
        assert_equal(res.fl.shape, shape)
        assert_allclose(res.fl, self.f(res.xl, *args))

        ref_fr = [ref.fr for ref in refs]
        assert_allclose(res.fr.ravel(), ref_fr)
        assert_equal(res.fr.shape, shape)
        assert_allclose(res.fr, self.f(res.xr, *args))

        assert np.all(np.abs(res.fun[finite]) ==
                      np.minimum(np.abs(res.fl[finite]),
                                 np.abs(res.fr[finite])))

    def test_flags(self):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        def f(xs, js):
            funcs = [lambda x: x - 2.5,
                     lambda x: x - 10,
                     lambda x: (x - 0.1)**3,
                     lambda x: np.nan]
            return [funcs[j](x) for x, j in zip(xs, js)]

        args = (np.arange(4, dtype=np.int64),)
        res = _chandrupatla_root(f, [0]*4, [np.pi]*4, args=args, maxiter=2)

        ref_flags = np.array([eim._ECONVERGED,
                              eim._ESIGNERR,
                              eim._ECONVERR,
                              eim._EVALUEERR])
        assert_equal(res.status, ref_flags)

    def test_convergence(self):
        # Test that the convergence tolerances behave as expected
        rng = np.random.default_rng(2585255913088665241)
        p = rng.random(size=3)
        bracket = (-5, 5)
        args = (p,)
        kwargs0 = dict(args=args, xatol=0, xrtol=0, fatol=0, frtol=0)

        kwargs = kwargs0.copy()
        kwargs['xatol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(res1.xr - res1.xl, 1e-3)
        kwargs['xatol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(res2.xr - res2.xl, 1e-6)
        assert_array_less(res2.xr - res2.xl, res1.xr - res1.xl)

        kwargs = kwargs0.copy()
        kwargs['xrtol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(res1.xr - res1.xl, 1e-3 * np.abs(res1.x))
        kwargs['xrtol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(res2.xr - res2.xl, 1e-6 * np.abs(res2.x))
        assert_array_less(res2.xr - res2.xl, res1.xr - res1.xl)

        kwargs = kwargs0.copy()
        kwargs['fatol'] = 1e-3
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(np.abs(res1.fun), 1e-3)
        kwargs['fatol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(np.abs(res2.fun), 1e-6)
        assert_array_less(np.abs(res2.fun), np.abs(res1.fun))

        kwargs = kwargs0.copy()
        kwargs['frtol'] = 1e-3
        x1, x2 = bracket
        f0 = np.minimum(abs(self.f(x1, *args)), abs(self.f(x2, *args)))
        res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(np.abs(res1.fun), 1e-3*f0)
        kwargs['frtol'] = 1e-6
        res2 = _chandrupatla_root(self.f, *bracket, **kwargs)
        assert_array_less(np.abs(res2.fun), 1e-6*f0)
        assert_array_less(np.abs(res2.fun), np.abs(res1.fun))

    def test_maxiter_callback(self):
        # Test behavior of `maxiter` parameter and `callback` interface
        p = 0.612814
        bracket = (-5, 5)
        maxiter = 5

        def f(q, p):
            res = stats.norm().cdf(q) - p
            f.x = q
            f.fun = res
            return res
        f.x = None
        f.fun = None

        res = _chandrupatla_root(f, *bracket, args=(p,),
                                 maxiter=maxiter)
        assert not np.any(res.success)
        assert np.all(res.nfev == maxiter+2)
        assert np.all(res.nit == maxiter)

        def callback(res):
            callback.iter += 1
            callback.res = res
            assert hasattr(res, 'x')
            if callback.iter == 0:
                # callback is called once with initial bracket
                assert (res.xl, res.xr) == bracket
            else:
                changed = (((res.xl == callback.xl) & (res.xr != callback.xr))
                           | ((res.xl != callback.xl) & (res.xr == callback.xr)))
                assert np.all(changed)

            callback.xl = res.xl
            callback.xr = res.xr
            assert res.status == eim._EINPROGRESS
            assert_equal(self.f(res.xl, p), res.fl)
            assert_equal(self.f(res.xr, p), res.fr)
            assert_equal(self.f(res.x, p), res.fun)
            if callback.iter == maxiter:
                raise StopIteration

        callback.iter = -1  # callback called once before first iteration
        callback.res = None
        callback.xl = None
        callback.xr = None

        res2 = _chandrupatla_root(f, *bracket, args=(p,),
                                  callback=callback)

        # terminating with callback is identical to terminating due to maxiter
        # (except for `status`)
        for key in res.keys():
            if key == 'status':
                assert res[key] == eim._ECONVERR
                assert callback.res[key] == eim._EINPROGRESS
                assert res2[key] == eim._ECALLBACK
            else:
                assert res2[key] == callback.res[key] == res[key]

    @pytest.mark.parametrize('case', _CHANDRUPATLA_TESTS)
    def test_nit_expected(self, case):
        # Test that `_chandrupatla` implements Chandrupatla's algorithm:
        # in all 40 test cases, the number of iterations performed
        # matches the number reported in the original paper.
        f, bracket, root, nfeval, id = case
        # Chandrupatla's criterion is equivalent to
        # abs(x2-x1) < 4*abs(xmin)*xrtol + xatol, but we use the more standard
        # abs(x2-x1) < abs(xmin)*xrtol + xatol. Therefore, set xrtol to 4x
        # that used by Chandrupatla in tests.
        res = _chandrupatla_root(f, *bracket, xrtol=4e-10, xatol=1e-5)
        assert_allclose(res.fun, f(root), rtol=1e-8, atol=2e-3)
        assert_equal(res.nfev, nfeval)

    @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
    def test_dtype(self, root, dtype):
        # Test that dtypes are preserved

        root = dtype(root)

        def f(x, root):
            return ((x - root) ** 3).astype(dtype)

        res = _chandrupatla_root(f, dtype(-3), dtype(5),
                                 args=(root,), xatol=1e-3)
        assert res.x.dtype == dtype
        assert np.allclose(res.x, root, atol=1e-3) or np.all(res.fun == 0)

    def test_input_validation(self):
        # Test input validation for appropriate error messages

        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(None, -4, 4)

        message = 'Abscissae and function output must be real numbers.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4+1j, 4)

        message = "shape mismatch: objects cannot be broadcast"
        # raised by `np.broadcast`, but the traceback is readable IMO
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, [-2, -3], [3, 4, 5])

        message = "The shape of the array returned by `func`..."
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: [x[0], x[1], x[1]], [-3, -3], [5, 5])

        message = 'Tolerances must be non-negative scalars.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, xatol=-1)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, xrtol=np.nan)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, fatol='ekki')
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, frtol=np.nan)

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, maxiter=-1)

        message = '`callback` must be callable.'
        with pytest.raises(ValueError, match=message):
            _chandrupatla_root(lambda x: x, -4, 4, callback='shrubbery')

    def test_special_cases(self):
        # Test edge cases and other special cases

        # Test that integers are not passed to `f`
        # (otherwise this would overflow)
        def f(x):
            assert np.issubdtype(x.dtype, np.floating)
            return x ** 99 - 1

        res = _chandrupatla_root(f, -7, 5)
        assert res.success
        assert_allclose(res.x, 1)

        # Test that if both ends of bracket equal root, algorithm reports
        # convergence
        def f(x):
            return x**2 - 1

        res = _chandrupatla_root(f, 1, 1)
        assert res.success
        assert_equal(res.x, 1)

        def f(x):
            return 1/x

        with np.errstate(invalid='ignore'):
            res = _chandrupatla_root(f, np.inf, np.inf)
        assert res.success
        assert_equal(res.x, np.inf)

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return x**3 - 1

        bracket = (-3, 5)
        res = _chandrupatla_root(f, *bracket, maxiter=0)
        assert (res.xl, res.xr) == bracket
        assert res.nit == 0
        assert res.nfev == 2
        assert res.status == -2
        assert res.x == -3  # best so far

        # Test maxiter = 1
        res = _chandrupatla_root(f, *bracket, maxiter=1)
        assert res.success
        assert res.status == 0
        assert res.nit == 1
        assert res.nfev == 3
        assert_allclose(res.x, 1)

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return c*x - 1

        res = _chandrupatla_root(f, -1, 1, args=3)
        assert_allclose(res.x, 1/3)

        # # TODO: Test zero tolerance
        # # ~~What's going on here - why are iterations repeated?~~
        # # tl goes to zero when xatol=xrtol=0. When function is nearly linear,
        # # this causes convergence issues.
        # def f(x):
        #     return np.cos(x)
        #
        # res = _chandrupatla_root(f, 0, np.pi, xatol=0, xrtol=0)
        # assert res.nit < 100
        # xp = np.nextafter(res.x, np.inf)
        # xm = np.nextafter(res.x, -np.inf)
        # assert np.abs(res.fun) < np.abs(f(xp))
        # assert np.abs(res.fun) < np.abs(f(xm))