import pytest

import numpy as np
from numpy.testing import assert_array_less, assert_allclose, assert_equal

from scipy.optimize._bracket import _bracket_root, _bracket_minimum, _ELIMITS
import scipy._lib._elementwise_iterative_method as eim
from scipy import stats

class TestBracketRoot:
    @pytest.mark.parametrize("seed", (615655101, 3141866013, 238075752))
    @pytest.mark.parametrize("use_xmin", (False, True))
    @pytest.mark.parametrize("other_side", (False, True))
    @pytest.mark.parametrize("fix_one_side", (False, True))
    def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side):
        """Check `nfev`, `nit` and the bracket against closed-form predictions.

        `f(x) = x` (root at 0) is bracketed starting from a random initial
        interval that does not contain the root. The expected number of
        expansions `n` and the expected final bracket `(l, u)` are computed
        analytically and compared with what `_bracket_root` reports.

        Parameters
        ----------
        seed : int
            Seed of the generator that draws `xl0`, `d` and `factor`.
        use_xmin : bool
            If True, pass a limit beyond the root (`xmin`, or `xmax` when the
            problem is mirrored).
        other_side : bool
            If True, mirror the problem about 0 so that `xl0 < xr0 < root`.
        fix_one_side : bool
            If True, also pass the initial endpoint farthest from the root as
            a limit.
        """
        # Property-based test to confirm that _bracket_root is behaving as
        # expected. The basic case is when root < xl0 < xr0.
        # The number of times bracket expands (per side) can be found by
        # setting the expression for the left endpoint of the bracket to the
        # root of f (x=0), solving for i, and rounding up. The corresponding
        # lower and upper ends of the bracket are found by plugging this back
        # into the expression for the ends of the bracket.
        # `other_side=True` is the case that xl0 < xr0 < root
        # Special cases like xl0 < root < xr0 are tested separately

        rng = np.random.default_rng(seed)
        xl0, d, factor = rng.random(size=3) * [1e5, 10, 5]
        factor = 1 + factor  # factor must be greater than 1
        xr0 = xl0 + d  # xr0 must be greater than xl0 in basic case

        def f(x):
            # Count calls so that `nfev` can be checked against reality.
            f.count += 1
            return x  # root is 0

        if use_xmin:
            # With a limit, the expected bracket is expressed relative to
            # `xmin` using negative powers of `factor`.
            xmin = -rng.random()
            n = np.ceil(np.log(-(xl0 - xmin) / xmin) / np.log(factor))
            l, u = xmin + (xl0 - xmin)*factor**-n, xmin + (xl0 - xmin)*factor**-(n - 1)
            kwargs = dict(xl0=xl0, xr0=xr0, factor=factor, xmin=xmin)
        else:
            # Without a limit, the expected bracket is expressed relative to
            # `xr0` using positive powers of `factor`.
            n = np.ceil(np.log(xr0/d) / np.log(factor))
            l, u = xr0 - d*factor**n, xr0 - d*factor**(n-1)
            kwargs = dict(xl0=xl0, xr0=xr0, factor=factor)

        if other_side:
            # Mirror the initial bracket, the expected bracket, and the limit
            # (if any) about the root at 0.
            kwargs['xl0'], kwargs['xr0'] = -kwargs['xr0'], -kwargs['xl0']
            l, u = -u, -l
            if 'xmin' in kwargs:
                kwargs['xmax'] = -kwargs.pop('xmin')

        if fix_one_side:
            # Make the limit on the side away from the root coincide with the
            # initial endpoint on that side.
            if other_side:
                kwargs['xmin'] = -xr0
            else:
                kwargs['xmax'] = xr0

        f.count = 0
        res = _bracket_root(f, **kwargs)

        # Compare reported number of function evaluations `nfev` against
        # reported `nit`, actual function call count `f.count`, and theoretical
        # number of expansions `n`.
        # When both sides are free, these get multiplied by 2 because function
        # is evaluated on the left and the right each iteration.
        # When one side is fixed, however, we add one: on the fixed side, the
        # function gets evaluated once at the initial endpoint.
        # Add 1 to `n` and `res.nit` because function evaluations occur at
        # iterations *0*, 1, ..., `n`. Subtract 1 from `f.count` because
        # function is called separately for left and right in iteration 0.
        if not fix_one_side:
            assert res.nfev == 2*(res.nit+1) == 2*(f.count-1) == 2*(n + 1)
        else:
            assert res.nfev == (res.nit+1)+1 == (f.count-1)+1 == (n+1)+1

        # Compare reported bracket to theoretical bracket and reported function
        # values to function evaluated at bracket.
        bracket = np.asarray([res.xl, res.xr])
        assert_allclose(bracket, (l, u))
        f_bracket = np.asarray([res.fl, res.fr])
        assert_allclose(f_bracket, f(bracket))

        # Check that bracket is valid and that status and success are correct
        assert res.xr > res.xl
        signs = np.sign(f_bracket)
        assert signs[0] == -signs[1]
        assert res.status == 0
        assert res.success

    def f(self, q, p):
        return stats.norm.cdf(q) - p

    @pytest.mark.parametrize('p', [0.6, np.linspace(0.05, 0.95, 10)])
    @pytest.mark.parametrize('xmin', [-5, None])
    @pytest.mark.parametrize('xmax', [5, None])
    @pytest.mark.parametrize('factor', [1.2, 2])
    def test_basic(self, p, xmin, xmax, factor):
        # Test basic functionality to bracket root (distribution PPF)
        res = _bracket_root(self.f, -0.01, 0.01, xmin=xmin, xmax=xmax,
                            factor=factor, args=(p,))
        assert_equal(-np.sign(res.fl), np.sign(res.fr))

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape):
        # Test for correct functionality, output shapes, and dtypes for various
        # input shapes.
        p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
        args = (p,)
        maxiter = 10

        @np.vectorize
        def bracket_root_single(xl0, xr0, xmin, xmax, factor, p):
            return _bracket_root(self.f, xl0, xr0, xmin=xmin, xmax=xmax,
                                 factor=factor, args=(p,),
                                 maxiter=maxiter)

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        rng = np.random.default_rng(2348234)
        xl0 = -rng.random(size=shape)
        xr0 = rng.random(size=shape)
        xmin, xmax = 1e3*xl0, 1e3*xr0
        if shape:  # make some elements un
            i = rng.random(size=shape) > 0.5
            xmin[i], xmax[i] = -np.inf, np.inf
        factor = rng.random(size=shape) + 1.5
        res = _bracket_root(f, xl0, xr0, xmin=xmin, xmax=xmax, factor=factor,
                            args=args, maxiter=maxiter)
        refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel()

        attrs = ['xl', 'xr', 'fl', 'fr', 'success', 'nfev', 'nit']
        for attr in attrs:
            ref_attr = [getattr(ref, attr) for ref in refs]
            res_attr = getattr(res, attr)
            assert_allclose(res_attr.ravel(), ref_attr)
            assert_equal(res_attr.shape, shape)

        assert np.issubdtype(res.success.dtype, np.bool_)
        if shape:
            assert np.all(res.success[1:-1])
        assert np.issubdtype(res.status.dtype, np.integer)
        assert np.issubdtype(res.nfev.dtype, np.integer)
        assert np.issubdtype(res.nit.dtype, np.integer)
        assert_equal(np.max(res.nit), f.f_evals - 2)
        assert_array_less(res.xl, res.xr)
        assert_allclose(res.fl, self.f(res.xl, *args))
        assert_allclose(res.fr, self.f(res.xr, *args))

    def test_flags(self):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously.
        def f(xs, js):
            funcs = [lambda x: x - 1.5,
                     lambda x: x - 1000,
                     lambda x: x - 1000,
                     lambda x: np.nan]

            return [funcs[j](x) for x, j in zip(xs, js)]

        args = (np.arange(4, dtype=np.int64),)
        res = _bracket_root(f, xl0=[-1, -1, -1, -1], xr0=[1, 1, 1, 1],
                            xmin=[-np.inf, -1, -np.inf, -np.inf],
                            xmax=[np.inf, 1, np.inf, np.inf],
                            args=args, maxiter=3)

        ref_flags = np.array([eim._ECONVERGED,
                              _ELIMITS,
                              eim._ECONVERR,
                              eim._EVALUEERR])
        assert_equal(res.status, ref_flags)

    @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize('xmin', [-5, None])
    @pytest.mark.parametrize('xmax', [5, None])
    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
    def test_dtype(self, root, xmin, xmax, dtype):
        # Test that dtypes are preserved

        xmin = xmin if xmin is None else dtype(xmin)
        xmax = xmax if xmax is None else dtype(xmax)
        root = dtype(root)
        def f(x, root):
            return ((x - root) ** 3).astype(dtype)

        bracket = np.asarray([-0.01, 0.01], dtype=dtype)
        res = _bracket_root(f, *bracket, xmin=xmin, xmax=xmax, args=(root,))
        assert np.all(res.success)
        assert res.xl.dtype == res.xr.dtype == dtype
        assert res.fl.dtype == res.fr.dtype == dtype

    def test_input_validation(self):
        # Test input validation for appropriate error messages

        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(None, -4, 4)

        message = '...must be numeric and real.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4+1j, 4)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 'hello')
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmin=np)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmax=object())
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, factor=sum)

        message = "All elements of `factor` must be greater than 1."
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, factor=0.5)

        message = '`xmin <= xl0 < xr0 <= xmax` must be True'
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, 4, -4)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmax=np.nan)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmin=10)

        message = "shape mismatch: objects cannot be broadcast"
        # raised by `np.broadcast, but the traceback is readable IMO
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, [-2, -3], [3, 4, 5])
        # Consider making this give a more readable error message
        # with pytest.raises(ValueError, match=message):
        #     _bracket_root(lambda x: [x[0], x[1], x[1]], [-3, -3], [5, 5])

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, maxiter=-1)

    def test_special_cases(self):
        # Test edge cases and other special cases

        # Test that integers are not passed to `f`
        # (otherwise this would overflow)
        def f(x):
            assert np.issubdtype(x.dtype, np.floating)
            return x ** 99 - 1

        res = _bracket_root(f, -7, 5)
        assert res.success

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return x - 10

        bracket = (-3, 5)
        res = _bracket_root(f, *bracket, maxiter=0)
        assert res.xl, res.xr == bracket
        assert res.nit == 0
        assert res.nfev == 2
        assert res.status == -2

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return c*x - 1

        res = _bracket_root(f, -1, 1, args=3)
        assert res.success
        assert_allclose(res.fl, f(res.xl, 3))

        # Test other edge cases

        def f(x):
            f.count += 1
            return x

        # 1. root lies within guess of bracket
        f.count = 0
        _bracket_root(f, -10, 20)
        assert_equal(f.count, 2)

        # 2. bracket endpoint hits root exactly
        f.count = 0
        res = _bracket_root(f, 5, 10, factor=2)
        bracket = (res.xl, res.xr)
        assert_equal(res.nfev, 4)
        assert_allclose(bracket, (0, 5), atol=1e-15)

        # 3. bracket limit hits root exactly
        with np.errstate(over='ignore'):
            res = _bracket_root(f, 5, 10, xmin=0)
        bracket = (res.xl, res.xr)
        assert_allclose(bracket[0], 0, atol=1e-15)
        with np.errstate(over='ignore'):
            res = _bracket_root(f, -10, -5, xmax=0)
        bracket = (res.xl, res.xr)
        assert_allclose(bracket[1], 0, atol=1e-15)

        # 4. bracket not within min, max
        with np.errstate(over='ignore'):
            res = _bracket_root(f, 5, 10, xmin=1)
        assert not res.success


class TestBracketMinimum:
    def init_f(self):
        def f(x, a, b):
            f.count += 1
            return (x - a)**2 + b
        f.count = 0
        return f

    def assert_valid_bracket(self, result):
        assert np.all(
            (result.xl < result.xm) & (result.xm < result.xr)
        )
        assert np.all(
            (result.fl >= result.fm) & (result.fr > result.fm)
            | (result.fl > result.fm) & (result.fr > result.fm)
        )

    def get_kwargs(
            self, *, xl0=None, xr0=None, factor=None, xmin=None, xmax=None, args=()
    ):
        names = ("xl0", "xr0", "xmin", "xmax", "factor", "args")
        return {
            name: val for name, val in zip(names, (xl0, xr0, xmin, xmax, factor, args))
            if isinstance(val, np.ndarray) or np.isscalar(val)
            or val not in [None, ()]
        }

    @pytest.mark.parametrize(
        "seed",
        (
            307448016549685229886351382450158984917,
            11650702770735516532954347931959000479,
            113767103358505514764278732330028568336,
        )
    )
    @pytest.mark.parametrize("use_xmin", (False, True))
    @pytest.mark.parametrize("other_side", (False, True))
    def test_nfev_expected(self, seed, use_xmin, other_side):
        """Check `nfev`, `nit` and the bracket against closed-form predictions.

        ``f(x) = x**2`` is bracketed from a random initial bracket that lies
        entirely on one side of the minimum at 0. The expected number of
        iterations `n` and the expected final bracket
        ``(lower, middle, upper)`` are computed analytically and compared
        with what `_bracket_minimum` reports.

        Parameters
        ----------
        seed : int
            Seed of the generator that draws the initial bracket and `factor`.
        use_xmin : bool
            If True, pass a limit beyond the minimum (`xmin`, or `xmax` when
            the problem is mirrored).
        other_side : bool
            If True, mirror the problem about 0 so the initial bracket lies
            to the left of the minimum.
        """
        rng = np.random.default_rng(seed)
        args = (0, 0)  # f(x) = x^2 with minimum at 0
        # xl0, xm0, xr0 are chosen such that the initial bracket is to
        # the right of the minimum, and the bracket will expand
        # downhill towards zero.
        xl0, d1, d2, factor = rng.random(size=4) * [1e5, 10, 10, 5]
        xm0 = xl0 + d1
        xr0 = xm0 + d2
        # Factor should be greater than one.
        factor += 1

        if use_xmin:
            # With a limit, expected points are expressed relative to `xmin`
            # using negative powers of `factor`.
            xmin = -rng.random() * 5
            n = int(np.ceil(np.log(-(xl0 - xmin) / xmin) / np.log(factor)))
            lower = xmin + (xl0 - xmin)*factor**-n
            middle = xmin + (xl0 - xmin)*factor**-(n-1)
            upper = xmin + (xl0 - xmin)*factor**-(n-2) if n > 1 else xm0
            # It may be the case the lower is below the minimum, but we still
            # don't have a valid bracket.
            if middle**2 > lower**2:
                n += 1
                lower, middle, upper = (
                    xmin + (xl0 - xmin)*factor**-n, lower, middle
                )
        else:
            # Without a limit, expected points are expressed relative to
            # `xl0` using positive powers of `factor`.
            xmin = None
            n = int(np.ceil(np.log(xl0 / d1) / np.log(factor)))
            lower = xl0 - d1*factor**n
            middle = xl0 - d1*factor**(n-1) if n > 1 else xl0
            upper = xl0 - d1*factor**(n-2) if n > 1 else xm0
            # It may be the case the lower is below the minimum, but we still
            # don't have a valid bracket.
            if middle**2 > lower**2:
                n += 1
                lower, middle, upper = (
                    xl0 - d1*factor**n, lower, middle
                )
        f = self.init_f()

        xmax = None
        if other_side:
            # Mirror the initial bracket, the limit (if any), and the expected
            # bracket about the minimum at 0.
            xl0, xm0, xr0 = -xr0, -xm0, -xl0
            xmin, xmax = None, -xmin if xmin is not None else None
            lower, middle, upper = -upper, -middle, -lower

        kwargs = self.get_kwargs(
            xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, factor=factor, args=args
        )
        result = _bracket_minimum(f, xm0, **kwargs)

        # Check that `nfev` and `nit` have the correct relationship
        assert result.nfev == result.nit + 3
        # Check that `nfev` reports the correct number of function evaluations.
        assert result.nfev == f.count
        # Check that the number of iterations matches the theoretical value.
        assert result.nit == n

        # Compare reported bracket to theoretical bracket and reported function
        # values to function evaluated at bracket.
        bracket = np.asarray([result.xl, result.xm, result.xr])
        assert_allclose(bracket, (lower, middle, upper))
        f_bracket = np.asarray([result.fl, result.fm, result.fr])
        assert_allclose(f_bracket, f(bracket, *args))

        self.assert_valid_bracket(result)
        assert result.status == 0
        assert result.success

    def test_flags(self):
        # Test cases that should produce different status flags; show that all
        # can be produced simultaneously
        def f(xs, js):
            funcs = [lambda x: (x - 1.5)**2,
                     lambda x: x,
                     lambda x: x,
                     lambda x: np.nan]
            return [funcs[j](x) for x, j in zip(xs, js)]

        args = (np.arange(4, dtype=np.int64),)
        xl0, xm0, xr0 = np.full(4, -1.0), np.full(4, 0.0), np.full(4, 1.0)
        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0,
                                  xmin=[-np.inf, -1.0, -np.inf, -np.inf],
                                  args=args, maxiter=3)

        reference_flags = np.array([eim._ECONVERGED, _ELIMITS,
                                    eim._ECONVERR, eim._EVALUEERR])
        assert_equal(result.status, reference_flags)

    @pytest.mark.parametrize("minimum", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
    @pytest.mark.parametrize("xmin", [-5, None])
    @pytest.mark.parametrize("xmax", [5, None])
    def test_dtypes(self, minimum, xmin, xmax, dtype):
        xmin = xmin if xmin is None else dtype(xmin)
        xmax = xmax if xmax is None else dtype(xmax)
        minimum = dtype(minimum)

        def f(x, minimum):
            return ((x - minimum)**2).astype(dtype)

        xl0, xm0, xr0 = np.array([-0.01, 0.0, 0.01], dtype=dtype)
        result = _bracket_minimum(
            f, xm0, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=(minimum, )
        )
        assert np.all(result.success)
        assert result.xl.dtype == result.xm.dtype == result.xr.dtype == dtype
        assert result.fl.dtype == result.fm.dtype == result.fr.dtype == dtype

    def test_input_validation(self):
        # Test input validation for appropriate error messages

        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(None, -4, xl0=4)

        message = '...must be numeric and real.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, 4+1j)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xl0='hello')
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xmin=np)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xmax=object())
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, factor=sum)

        message = "All elements of `factor` must be greater than 1."
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x, -4, factor=0.5)

        message = '`xmin <= xl0 < xm0 < xr0 <= xmax` must be True'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, 4, xl0=6)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xr0=-6)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xl0=-3, xr0=-2)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xl0=-6, xr0=-5)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xl0=-np.nan)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xr0=np.nan)

        message = "shape mismatch: objects cannot be broadcast"
        # raised by `np.broadcast, but the traceback is readable IMO
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, [-2, -3], xl0=[-3, -4, -5])

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xr0=4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, -4, xr0=4, maxiter=-1)

    @pytest.mark.parametrize("xl0", [0.0, None])
    @pytest.mark.parametrize("xm0", (0.05, 0.1, 0.15))
    @pytest.mark.parametrize("xr0", (0.2, 0.4, 0.6, None))
    # Minimum is ``a`` for each tuple ``(a, b)`` below. Tests cases where the
    # minimum is within, or at varying distances to the left or right of, the
    # initial bracket.
    @pytest.mark.parametrize(
        "args",
        (
            (1.2, 0), (-0.5, 0), (0.1, 0), (0.2, 0), (3.6, 0), (21.4, 0),
            (121.6, 0), (5764.1, 0), (-6.4, 0), (-12.9, 0), (-146.2, 0)
        )
    )
    def test_scalar_no_limits(self, xl0, xm0, xr0, args):
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, args=args)
        result = _bracket_minimum(f, xm0, **kwargs)
        self.assert_valid_bracket(result)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        # xmin is set at 0.0 in all cases.
        "xl0,xm0,xr0,xmin",
        (
            # Initial bracket at varying distances from the xmin.
            (0.5, 0.75, 1.0, 0.0),
            (1.0, 2.5, 4.0, 0.0),
            (2.0, 4.0, 6.0, 0.0),
            (12.0, 16.0, 20.0, 0.0),
            # Test default initial left endpoint selection. It should not
            # be below xmin.
            (None, 0.75, 1.0, 0.0),
            (None, 2.5, 4.0, 0.0),
            (None, 4.0, 6.0, 0.0),
            (None, 16.0, 20.0, 0.0),
        )
    )
    @pytest.mark.parametrize(
        "args", (
            (0.0, 0.0), # Minimum is directly at xmin.
            (1e-300, 0.0), # Minimum is extremely close to xmin.
            (1e-20, 0.0), # Minimum is very close to xmin.
            # Minimum at varying distances from xmin.
            (0.1, 0.0),
            (0.2, 0.0),
            (0.4, 0.0)
        )
    )
    def test_scalar_with_limit_left(self, xl0, xm0, xr0, xmin, args):
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmin=xmin, args=args)
        result = _bracket_minimum(f, xm0, **kwargs)
        self.assert_valid_bracket(result)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        # xmax is set to 1.0 in all cases.
        "xl0,xm0,xr0,xmax",
        (
            # Bracket at varying distances from xmax.
            (0.2, 0.3, 0.4, 1.0),
            (0.05, 0.075, 0.1, 1.0),
            (-0.2, -0.1, 0.0, 1.0),
            (-21.2, -17.7, -14.2, 1.0),
            # Test default right endpoint selection. It should not exceed xmax.
            (0.2, 0.3, None, 1.0),
            (0.05, 0.075, None, 1.0),
            (-0.2, -0.1, None, 1.0),
            (-21.2, -17.7, None, 1.0),
        )
    )
    @pytest.mark.parametrize(
        "args", (
            (0.9999999999999999, 0.0), # Minimum very close to xmax.
            # Minimum at varying distances from xmax.
            (0.9, 0.0),
            (0.7, 0.0),
            (0.5, 0.0)
        )
    )
    def test_scalar_with_limit_right(self, xl0, xm0, xr0, xmax, args):
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmax=xmax, args=args)
        result = _bracket_minimum(f, xm0, **kwargs)
        self.assert_valid_bracket(result)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        "xl0,xm0,xr0,xmin,xmax,args",
        (
            (   # Case 1:
                # Initial bracket.
                0.2, 
                0.3,
                0.4,
                # Function slopes down to the right from the bracket to a minimum
                # at 1.0. xmax is also at 1.0
                None, 
                1.0,
                (1.0, 0.0)
            ),
            (   # Case 2:
                # Initial bracket.
                1.4,
                1.95,
                2.5,
                # Function slopes down to the left from the bracket to a minimum at
                # 0.3 with xmin set to 0.3.
                0.3,
                None,
                (0.3, 0.0)
            ),
            (
                # Case 3:
                # Initial bracket.
                2.6,
                3.25,
                3.9,
                # Function slopes down and to the right to a minimum at 99.4 with xmax
                # at 99.4. Tests case where minimum is at xmax relatively further from
                # the bracket.
                None,
                99.4,
                (99.4, 0)
            ),
            (
                # Case 4:
                # Initial bracket.
                4,
                4.5,
                5,
                # Function slopes down and to the left away from the bracket with a
                # minimum at -26.3 with xmin set to -26.3. Tests case where minimum is
                # at xmin relatively far from the bracket.
                -26.3,
                None,
                (-26.3, 0)
            ),
            (
                # Case 5:
                # Similar to Case 1 above, but tests default values of xl0 and xr0.
                None,
                0.3,
                None,
                None,
                1.0,
                (1.0, 0.0)
            ),
            (   # Case 6:
                # Similar to Case 2 above, but tests default values of xl0 and xr0.
                None,
                1.95,
                None,
                0.3,
                None,
                (0.3, 0.0)
            ),
            (
                # Case 7:
                # Similar to Case 3 above, but tests default values of xl0 and xr0.
                None,
                3.25,
                None,
                None,
                99.4,
                (99.4, 0)
            ),
            (
                # Case 8:
                # Similar to Case 4 above, but tests default values of xl0 and xr0.
                None,
                4.5,
                None,
                -26.3,
                None,
                (-26.3, 0)
            ),
        )
    )
    def test_minimum_at_boundary_point(self, xl0, xm0, xr0, xmin, xmax, args):
        f = self.init_f()
        kwargs = self.get_kwargs(xr0=xr0, xmin=xmin, xmax=xmax, args=args)
        result = _bracket_minimum(f, xm0, **kwargs)
        assert result.status == -1
        assert args[0] in (result.xl, result.xr)
        assert result.nfev == f.count

    @pytest.mark.parametrize('shape', [tuple(), (12, ), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape):
        # Test for correct functionality, output shapes, and dtypes for
        # various input shapes.
        a = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
        args = (a, 0.0)
        maxiter = 10

        @np.vectorize
        def bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a):
            return _bracket_minimum(self.init_f(), xm0, xl0=xl0, xr0=xr0, xmin=xmin,
                                    xmax=xmax, factor=factor, maxiter=maxiter,
                                    args=(a, 0.0))

        f = self.init_f()

        rng = np.random.default_rng(2348234)
        xl0 = -rng.random(size=shape)
        xr0 = rng.random(size=shape)
        xm0 = xl0 + rng.random(size=shape) * (xr0 - xl0)
        xmin, xmax = 1e3*xl0, 1e3*xr0
        if shape:  # make some elements un
            i = rng.random(size=shape) > 0.5
            xmin[i], xmax[i] = -np.inf, np.inf
        factor = rng.random(size=shape) + 1.5
        res = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax,
                               factor=factor, args=args, maxiter=maxiter)
        refs = bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a).ravel()

        attrs = ['xl', 'xm', 'xr', 'fl', 'fm', 'fr', 'success', 'nfev', 'nit']
        for attr in attrs:
            ref_attr = [getattr(ref, attr) for ref in refs]
            res_attr = getattr(res, attr)
            assert_allclose(res_attr.ravel(), ref_attr)
            assert_equal(res_attr.shape, shape)

        assert np.issubdtype(res.success.dtype, np.bool_)
        if shape:
            assert np.all(res.success[1:-1])
        assert np.issubdtype(res.status.dtype, np.integer)
        assert np.issubdtype(res.nfev.dtype, np.integer)
        assert np.issubdtype(res.nit.dtype, np.integer)
        assert_equal(np.max(res.nit), f.count - 3)
        self.assert_valid_bracket(res)
        assert_allclose(res.fl, f(res.xl, *args))
        assert_allclose(res.fm, f(res.xm, *args))
        assert_allclose(res.fr, f(res.xr, *args))

    def test_special_cases(self):
        # Test edge cases and other special cases.

        # Test that integers are not passed to `f`
        # (otherwise this would overflow)
        def f(x):
            assert np.issubdtype(x.dtype, np.floating)
            return x ** 98 - 1

        result = _bracket_minimum(f, -7, xr0=5)
        assert result.success

        # Test maxiter = 0. Should do nothing to bracket.
        def f(x):
            return x**2 - 10

        xl0, xm0, xr0 = -3, -1, 2
        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, maxiter=0)
        assert_equal([result.xl, result.xm, result.xr], [xl0, xm0, xr0])

        # Test scalar `args` (not in tuple)
        def f(x, c):
            return c*x**2 - 1

        result = _bracket_minimum(f, -1, args=3)
        assert result.success
        assert_allclose(result.fl, f(result.xl, 3))

        # Initial bracket is valid.
        f = self.init_f()
        xl0, xm0, xr0 = [-1.0, -0.2, 1.0]
        args = (0, 0)
        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, args=args)
        assert f.count == 3

        assert_equal(
            [result.xl, result.xm, result.xr],
            [xl0, xm0, xr0],
        )
        assert_equal(
            [result.fl, result.fm, result.fr],
            [f(xl0, *args), f(xm0, *args), f(xr0, *args)],
        )