Inzynierka/Lib/site-packages/sklearn/utils/tests/test_testing.py

843 lines
24 KiB
Python
Raw Normal View History

2023-06-02 12:51:02 +02:00
import warnings
import unittest
import os
import atexit
import numpy as np
from scipy import sparse
import pytest
from sklearn.utils.deprecation import deprecated
from sklearn.utils.metaestimators import available_if, if_delegate_has_method
from sklearn.utils._readonly_array_wrapper import _test_sum
from sklearn.utils._testing import (
assert_raises,
assert_no_warnings,
set_random_state,
assert_raise_message,
ignore_warnings,
check_docstring_parameters,
assert_allclose_dense_sparse,
assert_raises_regex,
TempMemmap,
create_memmap_backed_data,
_delete_folder,
_convert_container,
raises,
assert_allclose,
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
def test_set_random_state():
    """Check `set_random_state` on estimators with and without a random state."""
    seed = 3

    # Linear Discriminant Analysis doesn't have a random state, so this call
    # is a smoke test only: it must simply not raise.
    set_random_state(LinearDiscriminantAnalysis(), seed)

    classifier = DecisionTreeClassifier()
    set_random_state(classifier, seed)
    assert classifier.random_state == seed
def test_assert_allclose_dense_sparse():
    """Check `assert_allclose_dense_sparse` on dense, sparse and mixed inputs."""
    dense = np.arange(9).reshape(3, 3)
    sparse_csc = sparse.csc_matrix(dense)
    tolerance_msg = "Not equal to tolerance "

    for container in (dense, sparse_csc):
        # Same container kind on both sides: differing values are reported,
        # identical values pass.
        with pytest.raises(AssertionError, match=tolerance_msg):
            assert_allclose_dense_sparse(container, container * 2)
        assert_allclose_dense_sparse(container, container)

    # One dense and one sparse operand is rejected outright.
    with pytest.raises(ValueError, match="Can only compare two sparse"):
        assert_allclose_dense_sparse(dense, sparse_csc)

    # Two sparse operands built with different shapes: (5, 5) versus (1, 5).
    square_csr = sparse.diags(np.ones(5), offsets=0).tocsr()
    row_csr = sparse.csr_matrix(np.ones((1, 5)))
    with pytest.raises(AssertionError, match="Arrays are not equal"):
        assert_allclose_dense_sparse(row_csr, square_csr)
def test_assert_raises_msg():
    """Check that the `msg` given to `assert_raises` shows up in its failure."""
    expected_text = "Hello world"
    inner_kwargs = {"msg": expected_text}

    # Nothing is raised in the body, so the inner manager has to fail; the
    # outer manager verifies that this failure is an AssertionError whose
    # text matches `expected_text`.
    with assert_raises_regex(AssertionError, expected_text), assert_raises(
        ValueError, **inner_kwargs
    ):
        pass
def test_assert_raise_message():
    """Check `assert_raise_message` on matching, mismatching and silent callables."""

    def raise_value_error(text):
        raise ValueError(text)

    def do_nothing():
        pass

    # Expected exception type, and the message contains the expected text.
    assert_raise_message(ValueError, "test", raise_value_error, "test")

    # Expected exception type, but the expected text is not in the message.
    with assert_raises(AssertionError):
        assert_raise_message(ValueError, "something else", raise_value_error, "test")

    # An exception of another type than the expected one reaches the caller.
    with assert_raises(ValueError):
        assert_raise_message(TypeError, "something else", raise_value_error, "test")

    # Nothing is raised at all, with a single expected exception type ...
    with assert_raises(AssertionError):
        assert_raise_message(ValueError, "test", do_nothing)

    # ... and with multiple exceptions in a tuple.
    with assert_raises(AssertionError):
        assert_raise_message((ValueError, AttributeError), "test", do_nothing)
def test_ignore_warning():
    """Check `ignore_warnings` as a wrapper, a decorator and a context manager."""

    def warn_deprecation():
        warnings.warn("deprecation warning", DeprecationWarning)

    def warn_deprecation_then_user():
        warnings.warn("deprecation warning", DeprecationWarning)
        # No category is passed here, so this one is a UserWarning.
        warnings.warn("deprecation warning")

    both_categories = (DeprecationWarning, UserWarning)

    # ---- ignore_warnings applied directly to a function ----
    assert_no_warnings(ignore_warnings(warn_deprecation))
    assert_no_warnings(ignore_warnings(warn_deprecation, category=DeprecationWarning))
    # Each row: (warning that must still surface, function, ignored category)
    for surfaced, func, ignored in [
        (DeprecationWarning, warn_deprecation, UserWarning),
        (UserWarning, warn_deprecation_then_user, FutureWarning),
        (DeprecationWarning, warn_deprecation_then_user, UserWarning),
    ]:
        with pytest.warns(surfaced):
            ignore_warnings(func, category=ignored)()
    assert_no_warnings(ignore_warnings(warn_deprecation, category=both_categories))

    # ---- ignore_warnings used as a decorator ----
    @ignore_warnings
    def decorated_all():
        warn_deprecation()
        warn_deprecation_then_user()

    @ignore_warnings(category=both_categories)
    def decorated_both():
        warn_deprecation_then_user()

    @ignore_warnings(category=DeprecationWarning)
    def decorated_deprecation():
        warn_deprecation()

    @ignore_warnings(category=UserWarning)
    def decorated_user():
        warn_deprecation()

    @ignore_warnings(category=DeprecationWarning)
    def decorated_deprecation_multi():
        warn_deprecation_then_user()

    @ignore_warnings(category=UserWarning)
    def decorated_user_multi():
        warn_deprecation_then_user()

    for silent in (decorated_all, decorated_both, decorated_deprecation):
        assert_no_warnings(silent)
    for surfaced, noisy in [
        (DeprecationWarning, decorated_user),
        (UserWarning, decorated_deprecation_multi),
        (DeprecationWarning, decorated_user_multi),
    ]:
        with pytest.warns(surfaced):
            noisy()

    # ---- ignore_warnings used as a context manager ----
    def managed_all():
        with ignore_warnings():
            warn_deprecation()

    def managed_both():
        with ignore_warnings(category=both_categories):
            warn_deprecation_then_user()

    def managed_deprecation():
        with ignore_warnings(category=DeprecationWarning):
            warn_deprecation()

    def managed_user():
        with ignore_warnings(category=UserWarning):
            warn_deprecation()

    def managed_deprecation_multi():
        with ignore_warnings(category=DeprecationWarning):
            warn_deprecation_then_user()

    def managed_user_multi():
        with ignore_warnings(category=UserWarning):
            warn_deprecation_then_user()

    for silent in (managed_all, managed_both, managed_deprecation):
        assert_no_warnings(silent)
    for surfaced, noisy in [
        (DeprecationWarning, managed_user),
        (UserWarning, managed_deprecation_multi),
        (DeprecationWarning, managed_user_multi),
    ]:
        with pytest.warns(surfaced):
            noisy()

    # ---- a warning class passed as first positional argument is rejected ----
    warning_class = UserWarning
    err_regex = "'obj' should be a callable.+you should use 'category=UserWarning'"

    with pytest.raises(ValueError, match=err_regex):
        silenced = ignore_warnings(warning_class)(warn_deprecation)
        silenced()

    with pytest.raises(ValueError, match=err_regex):

        @ignore_warnings(warning_class)
        def never_decorated():
            pass
class TestWarns(unittest.TestCase):
    """unittest-style checks for `assert_no_warnings`."""

    def test_warn(self):
        """A warning in the callable fails the check; otherwise its result is returned."""

        def warn_then_return():
            warnings.warn("yo")
            return 3

        def identity(value):
            return value

        with pytest.raises(AssertionError):
            assert_no_warnings(warn_then_return)
        assert assert_no_warnings(identity, 1) == 1
# Tests for docstrings:
# Docstring fixture: the documented parameters (a, b) match the signature.
# test_check_docstring_parameters expects an empty report for this function.
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Docstring fixture: the last section is titled "Results" rather than
# "Returns". test_check_docstring_parameters expects a RuntimeError matching
# "Unknown section Results" for this function.
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Docstring fixture: the signature is (b, a) while the docstring lists a
# before b. test_check_docstring_parameters expects a parameter name mismatch
# at index 0 ('b' != 'a') to be reported.
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Docstring fixture: the docstring documents a parameter `c` that the
# signature (a, b) does not have. test_check_docstring_parameters expects
# `c` to be reported as the first extra item.
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_too_many_param_docstring(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : int
        Parameter b
    c : int
        Parameter c

    Returns
    -------
    d : list
        Parameter c
    """
    d = a + b
    return d
# Docstring fixture: only `a` is documented although the signature is (a, b).
# test_check_docstring_parameters expects `b` to be reported as the first
# missing item, and expects an empty report when called with ignore=["b"].
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c
# Docstring fixture covering several spellings of "name : type" lines.
# test_check_docstring_parameters expects a "no space between the param name
# and colon" report for 'a: int', 'b:' and 'd:int', and none for `c :` or `e`.
# In this file only the signature and docstring are inspected; the function
# is never called.
# NOTE: the docstring itself is test data, keep its layout untouched.
def f_check_param_definition(a, b, c, d, e):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        This is parsed correctly in numpydoc 1.2
    d:int
        Parameter d
    e
        No typespec is allowed without colon
    """
    return a + b + c + d
# Fixture class whose methods have a missing or malformed docstring.
# NOTE: the docstrings are test data, keep their layout untouched.
class Klass:
    # No docstring at all: test_check_docstring_parameters expects the report
    # to name X as the first missing item (documented list is empty).
    def f_missing(self, X, y):
        pass

    # The first section is titled "Parameter" (singular); the test expects a
    # RuntimeError matching "Unknown section Parameter".
    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ---------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass
class MockEst:
    """Minimal stand-in estimator used as the delegate of the mock meta-estimators."""

    def __init__(self):
        """MockEstimator"""

    def fit(self, X, y):
        """Hand `X` back unchanged; `y` is not used."""
        return X

    def predict(self, X):
        """Hand `X` back unchanged."""
        return X

    def predict_proba(self, X):
        """Hand `X` back unchanged."""
        return X

    def score(self, X):
        """Give the constant 1.0 whatever `X` is."""
        return 1.0
# Fixture meta-estimator: each method is exposed through `available_if` only
# when the wrapped delegate has an attribute of the same name, and each
# docstring carries a flaw that test_check_docstring_parameters looks for.
# NOTE: the docstrings are test data, keep their layout untouched.
class MockMetaEstimator:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Documents `y` while the signature takes `X`: the test expects a
    # parameter name mismatch ('X' != 'y') to be reported.
    @available_if(lambda self: hasattr(self.delegate, "predict"))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # The underline below "Parameters" is one dash short: the test expects a
    # "potentially wrong underline length" report. The method has no body
    # besides its docstring, so it returns None.
    @available_if(lambda self: hasattr(self.delegate, "score"))
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Same short underline as `score`, with the correct parameter name.
    @available_if(lambda self: hasattr(self.delegate, "predict_proba"))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Deprecated method without any documented parameter: the test expects X
    # to be reported as the first missing item.
    @deprecated("Testing deprecated function with wrong params")
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
# Fixture meta-estimator with the same methods and docstring flaws as
# MockMetaEstimator, but delegating through `if_delegate_has_method`
# (test_check_docstring_parameters filters its deprecation warning).
# NOTE: the docstrings are test data, keep their layout untouched.
class MockMetaEstimatorDeprecatedDelegation:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    # Documents `y` while the signature takes `X`: the test expects a
    # parameter name mismatch ('X' != 'y') to be reported.
    @if_delegate_has_method(delegate="delegate")
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    # The underline below "Parameters" is one dash short: the test expects a
    # "potentially wrong underline length" report. The method has no body
    # besides its docstring, so it returns None.
    @if_delegate_has_method(delegate="delegate")
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    # Same short underline as `score`, with the correct parameter name.
    @if_delegate_has_method(delegate="delegate")
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    # Deprecated method without any documented parameter: the test expects X
    # to be reported as the first missing item.
    @deprecated("Testing deprecated function with wrong params")
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""
@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
@pytest.mark.parametrize(
    "mock_meta",
    [
        MockMetaEstimator(delegate=MockEst()),
        MockMetaEstimatorDeprecatedDelegation(delegate=MockEst()),
    ],
)
def test_check_docstring_parameters(mock_meta):
    """Check what `check_docstring_parameters` reports for each docstring fixture.

    Parameters
    ----------
    mock_meta : MockMetaEstimator or MockMetaEstimatorDeprecatedDelegation
        Meta-estimator instance whose delegated methods carry flawed
        docstrings; its class name is part of the expected messages.
    """
    pytest.importorskip(
        "numpydoc",
        reason="numpydoc is required to test the docstrings",
        minversion="1.2.0",
    )

    # Consistent docstring, or inconsistency explicitly ignored: empty report.
    assert check_docstring_parameters(f_ok) == []
    assert check_docstring_parameters(f_ok, ignore=["b"]) == []
    assert check_docstring_parameters(f_missing, ignore=["b"]) == []

    # Unknown section titles abort the check with a RuntimeError.
    with pytest.raises(RuntimeError, match="Unknown section Results"):
        check_docstring_parameters(f_bad_sections)
    with pytest.raises(RuntimeError, match="Unknown section Parameter"):
        check_docstring_parameters(Klass.f_bad_sections)

    module = "sklearn.utils.tests.test_testing"
    mock_meta_name = type(mock_meta).__name__
    in_mock = f"In function: {module}.{mock_meta_name}"

    incorrect = check_docstring_parameters(f_check_param_definition)
    assert incorrect == [
        f"{module}.f_check_param_definition There "
        "was no space between the param name and colon ('a: int')",
        f"{module}.f_check_param_definition There "
        "was no space between the param name and colon ('b:')",
        f"{module}.f_check_param_definition There "
        "was no space between the param name and colon ('d:int')",
    ]

    # Each case pairs a callable with the exact report lines expected for it.
    # NOTE(review): the "?" entries look like difflib.ndiff guide lines; their
    # inner padding places the markers under the characters that changed, so
    # it is part of the expected text and must not be collapsed -- confirm
    # against the output of check_docstring_parameters.
    cases = [
        (
            f_bad_order,
            [
                f"In function: {module}.f_bad_order",
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'b' != 'a'",
                "Full diff:",
                "- ['b', 'a']",
                "+ ['a', 'b']",
            ],
        ),
        (
            f_too_many_param_docstring,
            [
                f"In function: {module}.f_too_many_param_docstring",
                "Parameters in function docstring have more items w.r.t. function"
                " signature, first extra item: c",
                "Full diff:",
                "- ['a', 'b']",
                "+ ['a', 'b', 'c']",
                "?          +++++",
            ],
        ),
        (
            f_missing,
            [
                f"In function: {module}.f_missing",
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: b",
                "Full diff:",
                "- ['a', 'b']",
                "+ ['a']",
            ],
        ),
        (
            Klass.f_missing,
            [
                f"In function: {module}.Klass.f_missing",
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X",
                "Full diff:",
                "- ['X', 'y']",
                "+ []",
            ],
        ),
        (
            mock_meta.predict,
            [
                f"{in_mock}.predict",
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'X' != 'y'",
                "Full diff:",
                "- ['X']",
                "?   ^",
                "+ ['y']",
                "?   ^",
            ],
        ),
        (
            mock_meta.predict_proba,
            [
                f"{in_mock}.predict_proba",
                "potentially wrong underline length... ",
                "Parameters ",
                "--------- in ",
            ],
        ),
        (
            mock_meta.score,
            [
                f"{in_mock}.score",
                "potentially wrong underline length... ",
                "Parameters ",
                "--------- in ",
            ],
        ),
        (
            mock_meta.fit,
            [
                f"{in_mock}.fit",
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X",
                "Full diff:",
                "- ['X', 'y']",
                "+ []",
            ],
        ),
    ]

    for func, expected in cases:
        incorrect = check_docstring_parameters(func)
        assert expected == incorrect, f'\n"{expected}"\n not in \n"{incorrect}"'
class RegistrationCounter:
    """Callable that counts its calls; the tests install it as `atexit.register`."""

    def __init__(self):
        # Number of times the instance has been called so far.
        self.nb_calls = 0

    def __call__(self, to_register_func):
        """Count one registration and check which function was registered."""
        self.nb_calls = self.nb_calls + 1
        # The registered object must expose a `.func` attribute that is the
        # `_delete_folder` helper itself.
        wrapped = to_register_func.func
        assert wrapped is _delete_folder
def check_memmap(input_array, mmap_data, mmap_mode="r"):
    """Assert that `mmap_data` is a memmap holding the values of `input_array`.

    Parameters
    ----------
    input_array : array-like
        Reference values.
    mmap_data : object
        Candidate; must be an `np.memmap` instance.
    mmap_mode : str, default="r"
        Mode the memmap is supposed to have been opened with. With "r" the
        memmap must not be writeable; with any other value it must be.
    """
    assert isinstance(mmap_data, np.memmap)
    expect_writeable = not (mmap_mode == "r")
    assert mmap_data.flags.writeable is expect_writeable
    np.testing.assert_array_equal(input_array, mmap_data)
def test_tempmemmap(monkeypatch):
    """Check the `TempMemmap` context manager in default and "r+" modes."""
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", counter)

    source = np.ones(3)
    # First pass relies on the default mode, second pass requests "r+". The
    # same keyword arguments are forwarded to the checker so that it expects
    # the matching writeable flag.
    for expected_calls, mode_kwargs in enumerate(({}, {"mmap_mode": "r+"}), start=1):
        with TempMemmap(source, **mode_kwargs) as mmapped:
            check_memmap(source, mmapped, **mode_kwargs)
            folder = os.path.dirname(mmapped.filename)
        # Once the block is left the backing folder must be gone; this part
        # is not checked on Windows (os.name == "nt").
        if os.name != "nt":
            assert not os.path.exists(folder)
        assert counter.nb_calls == expected_calls
@pytest.mark.parametrize("aligned", [False, True])
def test_create_memmap_backed_data(monkeypatch, aligned):
    """Check `create_memmap_backed_data` for arrays, lists and invalid input."""
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", counter)

    source = np.ones(3)

    # Single array, default mode.
    mmapped = create_memmap_backed_data(source, aligned=aligned)
    check_memmap(source, mmapped)
    assert counter.nb_calls == 1

    # Single array, also asking for the backing folder.
    mmapped, folder = create_memmap_backed_data(
        source, return_folder=True, aligned=aligned
    )
    check_memmap(source, mmapped)
    assert folder == os.path.dirname(mmapped.filename)
    assert counter.nb_calls == 2

    # Single array, explicit read/write mode.
    mode = "r+"
    mmapped = create_memmap_backed_data(source, mmap_mode=mode, aligned=aligned)
    check_memmap(source, mmapped, mode)
    assert counter.nb_calls == 3

    # A list of arrays is converted element by element, with one registration.
    sources = [source, source + 1, source + 2]
    mmapped_list = create_memmap_backed_data(sources, aligned=aligned)
    for expected, mmapped in zip(sources, mmapped_list):
        check_memmap(expected, mmapped)
    assert counter.nb_calls == 4

    # A sequence holding a non-array is rejected when alignment is requested.
    expected_error = (
        "When creating aligned memmap-backed arrays, input must be a single array"
        " or a sequence of arrays"
    )
    with pytest.raises(ValueError, match=expected_error):
        create_memmap_backed_data([sources[-1], "not-an-array"], aligned=True)
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32, np.int64])
def test_memmap_on_contiguous_data(dtype):
    """Check that `_test_sum` gives the same result on memory-mapped data."""
    values = np.arange(10).astype(dtype)
    flags = values.flags
    assert flags.c_contiguous
    assert flags.aligned

    # _test_sum consumes contiguous arrays
    # def _test_sum(NUM_TYPES[::1] x):
    reference_sum = _test_sum(values)

    # Same computation on memory mapped data.
    # aligned=True so avoid https://github.com/joblib/joblib/issues/563
    # without alignment, this can produce segmentation faults, see
    # https://github.com/scikit-learn/scikit-learn/pull/21654
    values_mmap = create_memmap_backed_data(values, mmap_mode="r+", aligned=True)
    assert _test_sum(values_mmap) == pytest.approx(reference_sum, rel=1e-11)
@pytest.mark.parametrize(
    "constructor_name, container_type",
    [
        ("list", list),
        ("tuple", tuple),
        ("array", np.ndarray),
        ("sparse", sparse.csr_matrix),
        ("sparse_csr", sparse.csr_matrix),
        ("sparse_csc", sparse.csc_matrix),
        ("dataframe", lambda: pytest.importorskip("pandas").DataFrame),
        ("series", lambda: pytest.importorskip("pandas").Series),
        ("index", lambda: pytest.importorskip("pandas").Index),
        ("slice", slice),
    ],
)
@pytest.mark.parametrize(
    "dtype, superdtype",
    [
        (np.int32, np.integer),
        (np.int64, np.integer),
        (np.float32, np.floating),
        (np.float64, np.floating),
    ],
)
def test_convert_container(
    constructor_name,
    container_type,
    dtype,
    superdtype,
):
    """Check the container type and data type produced by `_convert_container`."""
    needs_pandas = {"dataframe", "series", "index"}
    if constructor_name in needs_pandas:
        # For these cases `container_type` is a zero-argument callable: the
        # pandas import is delayed to this point so that a missing pandas
        # only skips this test instead of the whole file.
        container_type = container_type()

    converted = _convert_container([0, 1], constructor_name, dtype=dtype)
    assert isinstance(converted, container_type)

    scalar_checked = {"list", "tuple", "index"}
    if constructor_name in scalar_checked:
        # list and tuple will use Python class dtype: int, float
        # pandas index will always use high precision: np.int64 and np.float64
        first_item = converted[0]
        assert np.issubdtype(type(first_item), superdtype)
    elif hasattr(converted, "dtype"):
        assert converted.dtype == dtype
    elif hasattr(converted, "dtypes"):
        assert converted.dtypes[0] == dtype
def test_raises():
    """Walk through the possible outcomes of the `raises` context manager."""
    # Expected type raised, no pattern requested.
    with raises(TypeError):
        raise TypeError()

    # Expected type raised, single pattern found in the message.
    with raises(TypeError, match="how are you") as ctx:
        raise TypeError("hello how are you")
    assert ctx.raised_and_matched

    # Expected type raised, one of several patterns found in the message.
    with raises(TypeError, match=["not this one", "how are you"]) as ctx:
        raise TypeError("hello how are you")
    assert ctx.raised_and_matched

    # Unexpected type, no pattern: the raised exception reaches the caller.
    propagated = pytest.raises(ValueError, match="this will be raised")
    with propagated, raises(TypeError) as ctx:
        raise ValueError("this will be raised")
    assert not ctx.raised_and_matched

    # Unexpected type, no pattern, but an err_msg: AssertionError with it.
    custom_failure = pytest.raises(AssertionError, match="the failure message")
    with custom_failure, raises(TypeError, err_msg="the failure message") as ctx:
        raise ValueError()
    assert not ctx.raised_and_matched

    # Unexpected type with a pattern (the pattern is ignored anyway).
    propagated = pytest.raises(ValueError, match="this will be raised")
    with propagated, raises(TypeError, match="this is ignored") as ctx:
        raise ValueError("this will be raised")
    assert not ctx.raised_and_matched

    # Expected type, but the pattern is not found in the message.
    pattern_failure = pytest.raises(
        AssertionError, match="should contain one of the following patterns"
    )
    with pattern_failure, raises(TypeError, match="hello") as ctx:
        raise TypeError("Bad message")
    assert not ctx.raised_and_matched

    # Expected type, pattern not found, with an err_msg.
    custom_failure = pytest.raises(AssertionError, match="the failure message")
    checker = raises(TypeError, match="hello", err_msg="the failure message")
    with custom_failure, checker as ctx:
        raise TypeError("Bad message")
    assert not ctx.raised_and_matched

    # Nothing raised, with the default may_pass=False.
    with pytest.raises(AssertionError, match="Did not raise"), raises(TypeError) as ctx:
        pass
    assert not ctx.raised_and_matched

    # Nothing raised, with may_pass=True: still OK.
    with raises(TypeError, match="hello", may_pass=True) as ctx:
        pass
    assert not ctx.raised_and_matched

    # Multiple exception types: each listed type is accepted ...
    accepted = (TypeError, ValueError)
    for exc_type in accepted:
        with raises(accepted):
            raise exc_type()
    # ... and raising nothing is still a failure.
    with pytest.raises(AssertionError), raises(accepted):
        pass
def test_float32_aware_assert_allclose():
    """Check the dtype-dependent default tolerances of `assert_allclose`."""
    # Each row: (dtype, relative offset that passes, relative offset that fails).
    # The relative tolerance for float32 inputs is 1e-4; for other inputs it
    # is left to 1e-7 as in the original numpy version.
    for dtype, accepted_offset, rejected_offset in [
        (np.float32, 2e-5, 2e-4),
        (np.float64, 2e-8, 2e-7),
    ]:
        assert_allclose(np.array([1.0 + accepted_offset], dtype=dtype), 1.0)
        with pytest.raises(AssertionError):
            assert_allclose(np.array([1.0 + rejected_offset], dtype=dtype), 1.0)

    # atol is left to 0.0 by default, even for float32
    near_zero = np.array([1e-5], dtype=np.float32)
    with pytest.raises(AssertionError):
        assert_allclose(near_zero, 0.0)
    # A second, separately built array is used for the passing call.
    assert_allclose(np.array([1e-5], dtype=np.float32), 0.0, atol=2e-5)