Inzynierka/Lib/site-packages/scipy/special/tests/test_sf_error.py

123 lines
3.4 KiB
Python
Raw Normal View History

2023-06-02 12:51:02 +02:00
import sys
import warnings
from numpy.testing import assert_, assert_equal, IS_PYPY
import pytest
from pytest import raises as assert_raises
import scipy.special as sc
from scipy.special._ufuncs import _sf_error_test_function
# Maps every special-function error category name to the integer code that
# the tests in this file hand to _sf_error_test_function.  The categories
# are numbered consecutively in the order listed; numbering starts at 1
# because code 0 ('ok') is intentionally skipped.
_sf_error_code_map = {
    category: code
    for code, category in enumerate(
        (
            'singular',
            'underflow',
            'overflow',
            'slow',
            'loss',
            'no_result',
            'domain',
            'arg',
            'other',
        ),
        start=1,
    )
}
# The error-handling actions the tests in this file cycle through when
# calling sc.seterr / sc.errstate.
_sf_error_actions = ['ignore', 'warn', 'raise']
def _check_action(fun, args, action):
    """Call ``fun(*args)`` and check that it reacts as ``action`` demands.

    Parameters
    ----------
    fun : callable
        Function under test.
    args : tuple
        Positional arguments unpacked into ``fun``.
    action : {'ignore', 'warn', 'raise'}
        Expected reaction: ``'warn'`` requires a ``sc.SpecialFunctionWarning``,
        ``'raise'`` requires a ``sc.SpecialFunctionError``, and ``'ignore'``
        requires the call to finish without any warning or exception.

    Raises
    ------
    ValueError
        If ``action`` is not one of the three names above.  Previously any
        unrecognised value was silently checked as ``'ignore'``, which
        would let a misspelled action pass unnoticed.
    """
    if action == 'warn':
        with pytest.warns(sc.SpecialFunctionWarning):
            fun(*args)
    elif action == 'raise':
        with assert_raises(sc.SpecialFunctionError):
            fun(*args)
    elif action == 'ignore':
        # Escalate every warning to an exception inside this scope so that
        # any warning emitted by the call makes the check fail; the previous
        # warning filters are restored when the context exits.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            fun(*args)
    else:
        raise ValueError(
            "unknown action {!r}; expected 'ignore', 'warn' or 'raise'"
            .format(action)
        )
def test_geterr():
    # Every category reported by sc.geterr() must be a known category name,
    # and the action attached to it must be one of the known actions.
    current = sc.geterr()
    for category in current:
        assert_(category in _sf_error_code_map)
        assert_(current[category] in _sf_error_actions)
def test_seterr():
    # For every (category, action) pair: sc.seterr must return the settings
    # that were active before the call, change only the requested category,
    # and make the test function react according to the new action.
    saved_settings = sc.geterr()
    try:
        for category, error_code in _sf_error_code_map.items():
            for action in _sf_error_actions:
                before = sc.geterr()
                returned = sc.seterr(**{category: action})
                assert_(before == returned)

                after = sc.geterr()
                assert_(after[category] == action)

                # Drop the category that was just set; everything that is
                # left must be identical before and after the call.
                del before[category]
                del after[category]
                assert_(before == after)

                _check_action(_sf_error_test_function, (error_code,), action)
    finally:
        # Put back whatever was configured on entry so other tests are not
        # affected by the settings changed above.
        sc.seterr(**saved_settings)
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_sf_error_special_refcount():
    # Regression test for gh-16233.
    # Raising a SpecialFunctionError must leave the reference count of the
    # scipy.special module exactly where it was.
    before = sys.getrefcount(sc)
    with sc.errstate(all='raise'), \
            pytest.raises(sc.SpecialFunctionError, match='domain error'):
        sc.ndtri(2.0)
    after = sys.getrefcount(sc)
    assert after == before
def test_errstate_pyx_basic():
    # With singular='raise' in effect, sc.loggamma(0) must raise; once the
    # errstate block is left, the error settings must match the ones seen
    # before it was entered.
    settings_before = sc.geterr()
    expected = sc.SpecialFunctionError
    with sc.errstate(singular='raise'), assert_raises(expected):
        sc.loggamma(0)
    assert_equal(settings_before, sc.geterr())
def test_errstate_c_basic():
    # With domain='raise' in effect, sc.spence(-1) must raise; once the
    # errstate block is left, the error settings must match the ones seen
    # before it was entered.
    settings_before = sc.geterr()
    expected = sc.SpecialFunctionError
    with sc.errstate(domain='raise'), assert_raises(expected):
        sc.spence(-1)
    assert_equal(settings_before, sc.geterr())
def test_errstate_cpp_basic():
    # With underflow='raise' in effect, sc.wrightomega(-1000) must raise;
    # once the errstate block is left, the error settings must match the
    # ones seen before it was entered.
    settings_before = sc.geterr()
    expected = sc.SpecialFunctionError
    with sc.errstate(underflow='raise'), assert_raises(expected):
        sc.wrightomega(-1000)
    assert_equal(settings_before, sc.geterr())
def test_errstate():
    # For every (category, action) pair: inside sc.errstate the test
    # function must react according to the action, and leaving the block
    # must bring back the settings that were active beforehand.
    for category, error_code in _sf_error_code_map.items():
        call_args = (error_code,)
        for action in _sf_error_actions:
            settings_before = sc.geterr()
            override = {category: action}
            with sc.errstate(**override):
                _check_action(_sf_error_test_function, call_args, action)
            assert_equal(settings_before, sc.geterr())
def test_errstate_all_but_one():
    # all='raise' combined with singular='ignore': sc.gammaln(0) has to get
    # through without raising while sc.spence(-1.0) still has to raise.
    # Afterwards the error settings must be back to their previous values.
    settings_before = sc.geterr()
    expected = sc.SpecialFunctionError
    with sc.errstate(all='raise', singular='ignore'):
        # Must not raise: an exception here would propagate and fail the test.
        sc.gammaln(0)
        with assert_raises(expected):
            sc.spence(-1.0)
    assert_equal(settings_before, sc.geterr())