59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
# Authors: Gael Varoquaux <gael.varoquaux@normalesup.org>
|
|
# Justin Vincent
|
|
# Lars Buitinck
|
|
# License: BSD 3 clause
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import scipy.stats
|
|
|
|
from sklearn.utils._testing import assert_array_equal
|
|
|
|
from sklearn.utils.fixes import _object_dtype_isnan, delayed, loguniform
|
|
|
|
|
|
@pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1]))
|
|
def test_object_dtype_isnan(dtype, val):
|
|
X = np.array([[val, np.nan], [np.nan, val]], dtype=dtype)
|
|
|
|
expected_mask = np.array([[False, True], [True, False]])
|
|
|
|
mask = _object_dtype_isnan(X)
|
|
|
|
assert_array_equal(mask, expected_mask)
|
|
|
|
|
|
@pytest.mark.parametrize("low,high,base", [(-1, 0, 10), (0, 2, np.exp(1)), (-1, 1, 2)])
|
|
def test_loguniform(low, high, base):
|
|
rv = loguniform(base**low, base**high)
|
|
assert isinstance(rv, scipy.stats._distn_infrastructure.rv_frozen)
|
|
rvs = rv.rvs(size=2000, random_state=0)
|
|
|
|
# Test the basics; right bounds, right size
|
|
assert (base**low <= rvs).all() and (rvs <= base**high).all()
|
|
assert len(rvs) == 2000
|
|
|
|
# Test that it's actually (fairly) uniform
|
|
log_rvs = np.array([math.log(x, base) for x in rvs])
|
|
counts, _ = np.histogram(log_rvs)
|
|
assert counts.mean() == 200
|
|
assert np.abs(counts - counts.mean()).max() <= 40
|
|
|
|
# Test that random_state works
|
|
assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform(
|
|
base**low, base**high
|
|
).rvs(random_state=0)
|
|
|
|
|
|
def test_delayed_deprecation():
|
|
"""Check that we issue the FutureWarning regarding the deprecation of delayed."""
|
|
|
|
def func(x):
|
|
return x
|
|
|
|
warn_msg = "The function `delayed` has been moved from `sklearn.utils.fixes`"
|
|
with pytest.warns(FutureWarning, match=warn_msg):
|
|
delayed(func)
|