Inzynierka/Lib/site-packages/scipy/stats/tests/common_tests.py
2023-06-02 12:51:02 +02:00

451 lines
15 KiB
Python

import pickle
import re
import numpy as np
import numpy.testing as npt
from numpy.testing import assert_allclose, assert_equal
from pytest import raises as assert_raises
import numpy.ma.testutils as ma_npt
from scipy._lib._util import getfullargspec_no_self as _getfullargspec
from scipy import stats
def check_named_results(res, attributes, ma=False):
    """Assert that positional and named access to `res` agree.

    For each index ``i`` and name ``attr`` produced by enumerating
    `attributes`, ``res[i]`` must compare equal to ``getattr(res, attr)``.

    Parameters
    ----------
    res : indexable object with named attributes
        The result object under test.
    attributes : iterable of str
        Attribute names, in the order of the matching positions.
    ma : bool, optional
        If true, compare with ``numpy.ma.testutils.assert_equal``;
        otherwise with ``numpy.testing.assert_equal``.
    """
    compare = ma_npt.assert_equal if ma else npt.assert_equal
    for position, name in enumerate(attributes):
        compare(res[position], getattr(res, name))
def check_normalization(distfn, args, distname):
    """Assert that the distribution integrates to one in three ways.

    Checks that the zeroth moment, the expectation of the constant 1 and
    the cdf at the upper end of the support are all (close to) 1.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``moment``, ``expect``, ``support`` and ``cdf``.
    args : sequence
        Shape parameters forwarded to the distribution methods.
    distname : str
        Used as the failure message of the ``expect`` check; the name
        ``"rv_histogram_instance"`` selects a different tolerance pair.
    """
    npt.assert_allclose(distfn.moment(0, *args), 1.0)

    is_histogram = distname == "rv_histogram_instance"
    atol, rtol = (1e-5, 0) if is_histogram else (1e-7, 1e-7)

    total_mass = distfn.expect(lambda x: 1, args=args)
    npt.assert_allclose(total_mass, 1.0, atol=atol, rtol=rtol,
                        err_msg=distname, verbose=True)

    _lower, upper = distfn.support(*args)
    npt.assert_allclose(distfn.cdf(upper, *args), 1.0)
def check_moment(distfn, arg, m, v, msg):
    """Compare the first two non-central moments with a mean and variance.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``moment(order, *arg)``.
    arg : sequence
        Shape parameters forwarded to ``distfn.moment``.
    m, v : float
        Expected mean and variance.
    msg : str
        Prefix for the assertion failure messages.

    Notes
    -----
    When `m` (respectively `v`) is infinite, only the infiniteness of the
    first (respectively second) moment is asserted.
    """
    m1 = distfn.moment(1, *arg)
    m2 = distfn.moment(2, *arg)

    if np.isinf(m):
        npt.assert_(np.isinf(m1),
                    f'{msg} - 1st moment -infinite, m1={m1!s}')
    else:
        npt.assert_almost_equal(m1, m, decimal=10,
                                err_msg=f'{msg} - 1st moment')

    if np.isinf(v):
        npt.assert_(np.isinf(m2),
                    f'{msg} - 2nd moment -infinite, m2={m2!s}')
    else:
        # The variance is recovered from the raw moments as m2 - m1**2.
        # The message previously read "2ndt moment" (typo).
        npt.assert_almost_equal(m2 - m1 * m1, v, decimal=10,
                                err_msg=f'{msg} - 2nd moment')
def check_mean_expect(distfn, arg, m, msg):
    """Assert that ``distfn.expect`` of the identity reproduces the mean.

    Nothing is checked when `m` is not finite.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``expect(func, args)``.
    arg : sequence
        Shape parameters passed as the second argument of ``expect``.
    m : float
        Expected mean.
    msg : str
        Prefix for the assertion failure message.
    """
    if not np.isfinite(m):
        return
    mean_from_expect = distfn.expect(lambda x: x, arg)
    npt.assert_almost_equal(mean_from_expect, m, decimal=5,
                            err_msg=msg + ' - 1st moment (expect)')
def check_var_expect(distfn, arg, m, v, msg):
    """Assert that ``distfn.expect`` of ``x*x`` equals ``v + m*m``.

    Nothing is checked when `v` is not finite.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``expect(func, args)``.
    arg : sequence
        Shape parameters passed as the second argument of ``expect``.
    m, v : float
        Expected mean and variance.
    msg : str
        When equal to ``"rv_histogram_instance"`` the comparison uses
        ``rtol=5e-6``; otherwise the ``assert_allclose`` defaults apply.
    """
    tolerances = {}
    if msg == "rv_histogram_instance":
        tolerances['rtol'] = 5e-6

    if not np.isfinite(v):
        return
    second_moment = distfn.expect(lambda x: x*x, arg)
    npt.assert_allclose(second_moment, v + m*m, **tolerances)
def check_skew_expect(distfn, arg, m, v, s, msg):
    """Assert that the third central moment matches the skewness `s`.

    The expectation of ``(x - m)**3`` is compared with ``s * v**1.5``.
    When `s` is not finite it is required to be NaN.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``expect(func, args)``.
    arg : sequence
        Shape parameters passed as the second argument of ``expect``.
    m, v, s : float
        Expected mean, variance and skewness.
    msg : str
        Prefix for the assertion failure message.
    """
    if not np.isfinite(s):
        npt.assert_(np.isnan(s))
        return
    third_central = distfn.expect(lambda x: np.power(x-m, 3), arg)
    npt.assert_almost_equal(third_central, s * np.power(v, 1.5),
                            decimal=5, err_msg=msg + ' - skew')
def check_kurt_expect(distfn, arg, m, v, k, msg):
    """Assert that the fourth central moment matches the kurtosis `k`.

    The expectation of ``(x - m)**4`` is compared with ``(k + 3) * v**2``.
    A `k` of positive infinity is accepted without further checks; any
    other non-finite `k` is required to be NaN.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``expect(func, args)``.
    arg : sequence
        Shape parameters passed as the second argument of ``expect``.
    m, v, k : float
        Expected mean, variance and kurtosis.
    msg : str
        Prefix for the assertion failure message.
    """
    if not np.isfinite(k):
        if not np.isposinf(k):
            npt.assert_(np.isnan(k))
        return
    fourth_central = distfn.expect(lambda x: np.power(x-m, 4), arg)
    npt.assert_allclose(fourth_central, (k + 3.) * np.power(v, 2),
                        atol=1e-5, rtol=1e-5,
                        err_msg=msg + ' - kurtosis')
def check_entropy(distfn, arg, msg):
    """Assert that ``distfn.entropy(*arg)`` is not NaN.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``entropy``.
    arg : sequence
        Shape parameters forwarded to ``entropy``.
    msg : str
        Prefix for the assertion failure message.
    """
    entropy_is_nan = np.isnan(distfn.entropy(*arg))
    npt.assert_(not entropy_is_nan, msg + 'test Entropy is nan')
def check_private_entropy(distfn, args, superclass):
    """Compare the distribution's ``_entropy`` with `superclass`'s version.

    Parameters
    ----------
    distfn : distribution object
        Instance whose own ``_entropy`` implementation is under test.
    args : sequence
        Shape parameters forwarded to both implementations.
    superclass : type
        Class whose ``_entropy`` is invoked unbound, with `distfn` passed
        as ``self``, to produce the reference value.
    """
    specific = distfn._entropy(*args)
    generic = superclass._entropy(distfn, *args)
    npt.assert_allclose(specific, generic)
def check_entropy_vect_scale(distfn, arg):
    """Assert that ``entropy`` broadcasts over an array-like ``scale``.

    The result of one vectorized call is compared (``atol=1e-14``) with
    the results of calling ``entropy`` once per individual scale value.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``entropy(*arg, scale=...)``.
    arg : sequence
        Shape parameters forwarded to ``entropy``.
    """
    def assert_matches_scalar_calls(scale, individual_scales):
        vectorized = distfn.entropy(*arg, scale=scale)
        one_by_one = np.asarray(
            [distfn.entropy(*arg, scale=s) for s in individual_scales]
        )
        assert_allclose(vectorized, one_by_one.reshape(vectorized.shape),
                        atol=1e-14)

    # 2-d array of scales
    grid = np.asarray([[1, 2], [3, 4]])
    assert_matches_scalar_calls(grid, grid.ravel())

    # plain list (exercises the cast) containing an invalid, negative scale
    with_invalid = [1, 2, -3]
    assert_matches_scalar_calls(with_invalid, with_invalid)
def check_edge_support(distfn, args):
    """Assert correct behaviour at, and outside, the ends of the support.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``support``, ``cdf``, ``sf``, ``logcdf``, ``logsf``,
        ``ppf``, ``isf`` and a ``name`` attribute.
    args : sequence
        Shape parameters forwarded to the distribution methods.
    """
    lower, upper = distfn.support(*args)
    if isinstance(distfn, stats.rv_discrete):
        # for discrete distributions probe one step below the lower end
        lower = lower - 1
    edges = (lower, upper)

    npt.assert_equal(distfn.cdf(edges, *args), [0.0, 1.0])
    npt.assert_equal(distfn.sf(edges, *args), [1.0, 0.0])

    skip_log_checks = distfn.name in ('skellam', 'dlaplace')
    if not skip_log_checks:
        # with a = -inf, log(0) generates warnings
        npt.assert_equal(distfn.logcdf(edges, *args), [-np.inf, 0.0])
        npt.assert_equal(distfn.logsf(edges, *args), [0.0, -np.inf])

    unit_interval_ends = [0.0, 1.0]
    npt.assert_equal(distfn.ppf(unit_interval_ends, *args), edges)
    npt.assert_equal(distfn.isf(unit_interval_ends, *args), edges[::-1])

    # probabilities outside [0, 1] must map to NaN for both isf and ppf
    outside = [-1, 2]
    npt.assert_(np.isnan(distfn.isf(outside, *args)).all())
    npt.assert_(np.isnan(distfn.ppf(outside, *args)).all())
def check_named_args(distfn, x, shape_args, defaults, meths):
    """Check that shape parameters may be passed by keyword.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``_parse_args``, ``shapes``, ``numargs``, ``moment``
        and ``cdf``.
    x : array_like
        Evaluation point(s) passed as first argument to every method.
    shape_args : sequence
        Values of the shape parameters, in positional order.
    defaults : sequence
        Expected default values in the ``_parse_args`` signature.
    meths : iterable of callables
        Methods called as ``meth(x, *positional, **keywords)``.
    """
    # consistency of shapes, numargs and the _parse_args signature
    spec = _getfullargspec(distfn._parse_args)
    npt.assert_(spec.varargs is None)
    npt.assert_(spec.varkw is None)
    npt.assert_(not spec.kwonlyargs)
    npt.assert_(list(spec.defaults) == list(defaults))

    # e.g. for (a, b, loc=0, scale=1) the shape names are (a, b)
    shape_argnames = spec.args[:-len(defaults)]
    shapes_ = distfn.shapes.replace(',', ' ').split() if distfn.shapes else ''
    npt.assert_(len(shapes_) == distfn.numargs)
    npt.assert_(len(shapes_) == len(shape_argnames))

    # reference values: everything passed positionally
    shape_args = list(shape_args)
    reference = [meth(x, *shape_args) for meth in meths]
    npt.assert_(np.all(np.isfinite(reference)))

    # move the shape parameters, last one first, from positional to keyword
    pending_names = list(shape_argnames)
    positional = list(shape_args)
    keywords = {}
    while pending_names:
        name = pending_names.pop()
        keywords[name] = positional.pop()
        mixed = [meth(x, *positional, **keywords) for meth in meths]
        npt.assert_array_equal(reference, mixed)
        if 'n' not in keywords:
            # `n` is first parameter of moment(), so can't be used as named arg
            npt.assert_equal(distfn.moment(1, *positional, **keywords),
                             distfn.moment(1, *shape_args))

    # unknown arguments should not go through:
    keywords['kaboom'] = 42
    assert_raises(TypeError, distfn.cdf, x, **keywords)
def check_random_state_property(distfn, args):
    """Check the ``random_state`` attribute of a distribution *instance*.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``rvs`` and a settable ``random_state`` attribute.
    args : sequence
        Shape parameters forwarded to ``rvs``.

    Notes
    -----
    This test fiddles with ``distfn.random_state``, which breaks other
    tests if left modified. The original value is therefore saved up front
    and restored in a ``finally`` clause, so that it is put back even when
    one of the assertions below fails. The global NumPy seed is also reset
    via ``np.random.seed(1234)``.
    """
    rndm = distfn.random_state
    try:
        # baseline: this relies on the global state
        np.random.seed(1234)
        distfn.random_state = None
        r0 = distfn.rvs(*args, size=8)

        # use an explicit instance-level random_state
        distfn.random_state = 1234
        r1 = distfn.rvs(*args, size=8)
        npt.assert_equal(r0, r1)

        distfn.random_state = np.random.RandomState(1234)
        r2 = distfn.rvs(*args, size=8)
        npt.assert_equal(r0, r2)

        # check that np.random.Generator can be used (numpy >= 1.17)
        if hasattr(np.random, 'default_rng'):
            # obtain a np.random.Generator object
            rng = np.random.default_rng(1234)
            distfn.rvs(*args, size=1, random_state=rng)

        # can override the instance-level random_state for an individual
        # .rvs call
        distfn.random_state = 2
        orig_state = distfn.random_state.get_state()

        r3 = distfn.rvs(*args, size=8,
                        random_state=np.random.RandomState(1234))
        npt.assert_equal(r0, r3)

        # ... and that does not alter the instance-level random_state!
        npt.assert_equal(distfn.random_state.get_state(), orig_state)
    finally:
        # restore the random_state, whether or not the checks passed
        distfn.random_state = rndm
def check_meth_dtype(distfn, arg, meths):
    """Assert that each method returns float64 for several input dtypes.

    The quartiles of the distribution are cast to ``np.int_``,
    ``np.float16``, ``np.float32`` and ``np.float64``; every callable in
    `meths` must return an array of dtype ``np.float64`` for each cast.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``ppf``, ``_argcheck`` and the attributes ``a``/``b``.
    arg : sequence
        Shape parameters forwarded to the distribution methods.
    meths : iterable of callables
        Methods called as ``meth(x, *arg)``.
    """
    quartiles = [0.25, 0.5, 0.75]
    x0 = distfn.ppf(quartiles, *arg)
    input_dtypes = (np.int_, np.float16, np.float32, np.float64)
    x_cast = [x0.astype(tp) for tp in input_dtypes]

    for x in x_cast:
        # casting may have clipped the values, exclude those
        # NOTE(review): the _argcheck call presumably refreshes
        # distfn.a / distfn.b for these shape parameters - confirm.
        distfn._argcheck(*arg)
        x = x[(distfn.a < x) & (x < distfn.b)]
        for meth in meths:
            val = meth(x, *arg)
            # np.float64 rather than the alias np.float_, which was
            # removed in NumPy 2.0
            npt.assert_(val.dtype == np.float64)
def check_ppf_dtype(distfn, arg):
    """Assert that ``ppf`` and ``isf`` return float64 for float inputs.

    The probabilities 0.25, 0.5 and 0.75 are cast to ``np.float16``,
    ``np.float32`` and ``np.float64`` before being passed to each method.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``ppf`` and ``isf``.
    arg : sequence
        Shape parameters forwarded to the distribution methods.
    """
    q0 = np.asarray([0.25, 0.5, 0.75])
    for tp in (np.float16, np.float32, np.float64):
        q = q0.astype(tp)
        for meth in (distfn.ppf, distfn.isf):
            val = meth(q, *arg)
            # np.float64 rather than the alias np.float_, which was
            # removed in NumPy 2.0
            npt.assert_(val.dtype == np.float64)
def check_cmplx_deriv(distfn, arg):
    """Check distribution methods against complex-step derivatives.

    Distributions allow complex arguments, so the derivative of ``f`` at
    ``x`` can be estimated as ``imag(f(x + 1j*h)) / h``. The derivatives
    of ``cdf``, ``logcdf``, ``sf``, ``logsf`` and ``logpdf`` obtained this
    way are compared (``rtol=1e-5``) with expressions built from ``pdf``,
    ``cdf`` and ``sf``.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``ppf``, ``pdf``, ``cdf``, ``sf``, their ``log``
        variants, ``_argcheck`` and the attributes ``a``/``b``.
    arg : sequence
        Shape parameters forwarded to the distribution methods.
    """
    step = 1e-10

    def complex_step(func, points, *shape_params):
        points = np.asarray(points)
        return (func(points + step*1j, *shape_params)/step).imag

    x0 = distfn.ppf([0.25, 0.51, 0.75], *arg)
    input_dtypes = (np.int_, np.float16, np.float32, np.float64)
    x_cast = [x0.astype(tp) for tp in input_dtypes]

    for x in x_cast:
        # casting may have clipped the values, exclude those
        distfn._argcheck(*arg)
        inside = (distfn.a < x) & (x < distfn.b)
        x = x[inside]

        pdf = distfn.pdf(x, *arg)
        cdf = distfn.cdf(x, *arg)
        sf = distfn.sf(x, *arg)

        assert_allclose(complex_step(distfn.cdf, x, *arg), pdf, rtol=1e-5)
        assert_allclose(complex_step(distfn.logcdf, x, *arg), pdf/cdf,
                        rtol=1e-5)

        assert_allclose(complex_step(distfn.sf, x, *arg), -pdf, rtol=1e-5)
        assert_allclose(complex_step(distfn.logsf, x, *arg), -pdf/sf,
                        rtol=1e-5)

        dlogpdf = complex_step(distfn.logpdf, x, *arg)
        dpdf_over_pdf = (complex_step(distfn.pdf, x, *arg)
                         / distfn.pdf(x, *arg))
        assert_allclose(dlogpdf, dpdf_over_pdf, rtol=1e-5)
def check_pickling(distfn, args):
    """Check that a distribution instance pickles and unpickles.

    Special attention is paid to the ``random_state`` property: random
    variates drawn after the round trip must match those drawn from the
    original object.

    Parameters
    ----------
    distfn : distribution object
        Must be callable (to freeze) and provide ``rvs``, ``ppf``, ``cdf``
        and a settable ``random_state``; ``fit`` is checked if present.
    args : sequence
        Shape parameters forwarded to the distribution methods.

    Notes
    -----
    ``distfn.random_state`` is modified during the check. It is restored
    in a ``finally`` clause so the instance is left untouched even when
    pickling or one of the assertions fails. ``pickle.loads`` is only ever
    applied to bytes produced inside this function.
    """
    # save the random_state (restored in the finally clause)
    rndm = distfn.random_state
    try:
        # check unfrozen
        distfn.random_state = 1234
        distfn.rvs(*args, size=8)
        s = pickle.dumps(distfn)
        r0 = distfn.rvs(*args, size=8)

        unpickled = pickle.loads(s)
        r1 = unpickled.rvs(*args, size=8)
        npt.assert_equal(r0, r1)

        # also smoke test some methods
        medians = [distfn.ppf(0.5, *args), unpickled.ppf(0.5, *args)]
        npt.assert_equal(medians[0], medians[1])
        npt.assert_equal(distfn.cdf(medians[0], *args),
                         unpickled.cdf(medians[1], *args))

        # check frozen pickling/unpickling with rvs
        frozen_dist = distfn(*args)
        pkl = pickle.dumps(frozen_dist)
        unpickled = pickle.loads(pkl)

        r0 = frozen_dist.rvs(size=8)
        r1 = unpickled.rvs(size=8)
        npt.assert_equal(r0, r1)

        # check pickling/unpickling of .fit method
        if hasattr(distfn, "fit"):
            fit_function = distfn.fit
            pickled_fit_function = pickle.dumps(fit_function)
            unpickled_fit_function = pickle.loads(pickled_fit_function)
            assert (fit_function.__name__
                    == unpickled_fit_function.__name__
                    == "fit")
    finally:
        # restore the random_state, whether or not the checks passed
        distfn.random_state = rndm
def check_freezing(distfn, args):
    """Regression test for gh-11089: freezing with loc and/or scale.

    A distribution frozen with ``loc=1`` (and ``scale=2`` for instances
    of ``stats.rv_continuous``) must report the same ``a`` and ``b``
    attributes as one frozen with the shape parameters only.

    Parameters
    ----------
    distfn : distribution object
        Callable returning a frozen distribution with ``a``/``b``.
    args : sequence
        Shape parameters used for freezing.
    """
    locscale = {'loc': 1}
    if isinstance(distfn, stats.rv_continuous):
        locscale['scale'] = 2

    rv = distfn(*args, **locscale)
    for endpoint in ('a', 'b'):
        assert getattr(rv, endpoint) == getattr(distfn(*args), endpoint)
def check_rvs_broadcast(distfunc, distname, allargs, shape, shape_only, otype):
    """Check that ``rvs`` broadcasts its arguments.

    Parameters
    ----------
    distfunc : distribution object
        Must provide ``rvs``.
    distname : str
        Used in the failure message of the shape check.
    allargs : sequence
        Positional arguments passed to ``rvs``.
    shape : tuple
        Expected shape of the generated sample.
    shape_only : bool
        If true, only the shape is checked. Otherwise the sample is also
        compared (``rtol=1e-13``) with the output of an ``np.vectorize``
        wrapper around ``rvs`` run from the same global seed.
    otype : str or list of dtypes
        Passed as ``otypes`` to ``np.vectorize``.
    """
    seed = 123

    np.random.seed(seed)
    sample = distfunc.rvs(*allargs)
    assert_equal(sample.shape, shape, "%s: rvs failed to broadcast" % distname)
    if shape_only:
        return

    def draw(*params):
        return distfunc.rvs(*params)

    elementwise_rvs = np.vectorize(draw, otypes=otype)
    np.random.seed(seed)
    expected = elementwise_rvs(*allargs)
    assert_allclose(sample, expected, rtol=1e-13)
def check_deprecation_warning_gh5982_moment(distfn, arg, distname):
    """Check the gh-5982 keyword handling of ``distfn.moment``.

    Two families of cases are exercised, depending on whether the
    distribution has a shape parameter called ``n`` (cases A1-A4) or not
    (cases B1-B4). See the description of the cases in the definition of
    ``scipy.stats.rv_generic.moment``.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``shapes``, ``mean`` and ``moment``.
    arg : sequence or None
        Shape parameter values, in positional order.
    distname : str
        Not used by this check.
    """
    shapes = distfn.shapes.split(", ") if distfn.shapes is not None else []
    kwd_shapes = dict(zip(shapes, arg or []))  # dictionary of shape kwds
    n = kwd_shapes.pop('n', None)

    message1 = "moment() missing 1 required positional argument"
    message2 = "_parse_args() missing 1 required positional argument: 'n'"
    message3 = "moment() got multiple values for first argument"

    def raises_type_error(message):
        # the messages contain parentheses, hence the regex escaping
        return assert_raises(TypeError, match=re.escape(message))

    if 'n' not in shapes:
        expected = distfn.mean(**kwd_shapes)

        # B1
        with raises_type_error(message3):
            res = distfn.moment(1, n=1, **kwd_shapes)

        # B2
        with np.testing.assert_warns(DeprecationWarning):
            res = distfn.moment(n=1, **kwd_shapes)
            assert_allclose(res, expected)

        # B3
        res = distfn.moment(1, *arg)
        assert_allclose(res, expected)

        # B4
        with raises_type_error(message1):
            distfn.moment(**kwd_shapes)
        return

    expected = distfn.mean(n=n, **kwd_shapes)

    # A1
    res = distfn.moment(1, n=n, **kwd_shapes)
    assert_allclose(res, expected)

    # A2
    with raises_type_error(message1):
        distfn.moment(n=n, **kwd_shapes)

    # A3
    # if `n` is not provided at all
    with raises_type_error(message2):
        distfn.moment(1, **kwd_shapes)
    # if `n` is provided as a positional argument
    res = distfn.moment(1, *arg)
    assert_allclose(res, expected)

    # A4
    with raises_type_error(message1):
        distfn.moment(**kwd_shapes)
def check_deprecation_warning_gh5982_interval(distfn, arg, distname):
    """Check the gh-5982 keyword handling of ``distfn.interval``.

    Two families of cases are exercised, depending on whether the
    distribution has a shape parameter called ``alpha`` (cases A1-A4) or
    not (cases B1-B4). The reference interval is built from
    ``distfn.ppf`` evaluated at 0.25 and 0.75.

    Parameters
    ----------
    distfn : distribution object
        Must provide ``shapes``, ``ppf`` and ``interval``.
    arg : sequence or None
        Shape parameter values, in positional order.
    distname : str
        Not used by this check.
    """
    shapes = distfn.shapes.split(", ") if distfn.shapes is not None else []
    kwd_shapes = dict(zip(shapes, arg or []))  # dictionary of shape kwds
    alpha = kwd_shapes.pop('alpha', None)

    def my_interval(*args, **kwds):
        lower = distfn.ppf(0.25, *args, **kwds)
        upper = distfn.ppf(0.75, *args, **kwds)
        return (lower, upper)

    message1 = "interval() missing 1 required positional argument"
    message2 = "_parse_args() missing 1 required positional argument: 'alpha'"
    message3 = "interval() got multiple values for first argument"

    def raises_type_error(message):
        # the messages contain parentheses, hence the regex escaping
        return assert_raises(TypeError, match=re.escape(message))

    if 'alpha' not in shapes:
        expected = my_interval(**kwd_shapes)

        # B1
        with raises_type_error(message3):
            res = distfn.interval(0.5, alpha=1, **kwd_shapes)

        # B2
        with np.testing.assert_warns(DeprecationWarning):
            res = distfn.interval(alpha=0.5, **kwd_shapes)
            assert_allclose(res, expected)

        # B3
        res = distfn.interval(0.5, *arg)
        assert_allclose(res, expected)

        # B4
        with raises_type_error(message1):
            distfn.interval(**kwd_shapes)
        return

    expected = my_interval(alpha=alpha, **kwd_shapes)

    # A1
    res = distfn.interval(0.5, alpha=alpha, **kwd_shapes)
    assert_allclose(res, expected)

    # A2
    with raises_type_error(message1):
        distfn.interval(alpha=alpha, **kwd_shapes)

    # A3
    # if `alpha` is not provided at all
    with raises_type_error(message2):
        distfn.interval(0.5, **kwd_shapes)
    # if `alpha` is provided as a positional argument
    res = distfn.interval(0.5, *arg)
    assert_allclose(res, expected)

    # A4
    with raises_type_error(message1):
        distfn.interval(**kwd_shapes)