# Source listing: projektAI/venv/Lib/site-packages/sklearn/preprocessing/tests/test_function_transformer.py
# Last modified: 2021-06-06 22:13:05 +02:00
#
# 161 lines
# 5.3 KiB
# Python

import pytest
import numpy as np
from scipy import sparse
from sklearn.preprocessing import FunctionTransformer
from sklearn.utils._testing import (assert_array_equal,
assert_allclose_dense_sparse)
from sklearn.utils._testing import assert_warns_message, assert_no_warnings
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
    """Build a recording wrapper around ``func``.

    Parameters
    ----------
    args_store : list
        Extended in place with ``X`` followed by any extra positional
        arguments on every call of the returned wrapper.
    kwargs_store : dict
        Updated in place with the keyword arguments of every call.
    func : callable, default=identity
        Function whose result the wrapper returns. It is invoked as
        ``func(X)``.

    Returns
    -------
    callable
        ``_func(X, *args, **kwargs)`` which records its arguments in the
        two stores and returns ``func(X)``.
    """
    def _func(X, *args, **kwargs):
        args_store.append(X)
        args_store.extend(args)
        kwargs_store.update(kwargs)
        # Extra arguments are recorded above but not forwarded: only X is
        # handed to ``func``.
        return func(X)

    return _func
def test_delegate_to_func():
    """Check that ``transform`` calls the wrapped function with ``X`` only."""
    # (args|kwargs)_store will hold the positional and keyword arguments
    # passed to the function inside the FunctionTransformer.
    args_store = []
    kwargs_store = {}
    X = np.arange(10).reshape((5, 2))

    # Two passes, each with a freshly built transformer; the stores are
    # emptied in place between passes so the wrapper keeps writing into the
    # same objects.
    for _ in range(2):
        transformed = FunctionTransformer(
            _make_func(args_store, kwargs_store),
        ).transform(X)
        assert_array_equal(
            transformed, X,
            err_msg='transform should have returned X unchanged',
        )

        # The function should only have received X. Compare the recorded
        # array explicitly: ``args_store == [X]`` would only succeed through
        # the list identity shortcut and raise an ambiguous-truth-value
        # error (not an assertion failure) for an equal-but-distinct array.
        assert len(args_store) == 1, (
            'Incorrect positional arguments passed to '
            'func: {args}'.format(args=args_store))
        assert_array_equal(
            args_store[0], X,
            err_msg=('Incorrect positional arguments passed to '
                     'func: {args}'.format(args=args_store)),
        )
        assert not kwargs_store, (
            'Unexpected keyword arguments passed to '
            'func: {args}'.format(args=kwargs_store))

        # reset the argument stores.
        args_store[:] = []
        kwargs_store.clear()
def test_np_log():
    """A transformer wrapping ``np.log1p`` matches a direct ``np.log1p`` call."""
    data = np.arange(10).reshape((5, 2))
    log_transformer = FunctionTransformer(np.log1p)

    # The transformer output must equal applying the ufunc directly.
    observed = log_transformer.transform(data)
    expected = np.log1p(data)
    assert_array_equal(observed, expected)
def test_kw_arg():
    """``kw_args`` given at construction are forwarded to the function."""
    data = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Rounding through the transformer must match rounding directly.
    observed = rounder.transform(data)
    expected = np.around(data, decimals=3)
    assert_array_equal(observed, expected)
def test_kw_arg_update():
    """Mutating an entry of ``kw_args`` after construction takes effect."""
    data = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Change the existing dict in place rather than replacing it.
    rounder.kw_args['decimals'] = 1

    observed = rounder.transform(data)
    expected = np.around(data, decimals=1)
    assert_array_equal(observed, expected)
def test_kw_arg_reset():
    """Replacing the ``kw_args`` attribute after construction takes effect."""
    data = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Swap in a brand-new dict instead of editing the original one.
    rounder.kw_args = {'decimals': 1}

    observed = rounder.transform(data)
    expected = np.around(data, decimals=1)
    assert_array_equal(observed, expected)
def test_inverse_transform():
    """``inverse_transform`` applies ``inverse_func`` with ``inv_kw_args``."""
    squares = np.array([1, 4, 9, 16]).reshape((2, 2))
    transformer = FunctionTransformer(
        func=np.sqrt,
        inverse_func=np.around,
        inv_kw_args={'decimals': 3},
    )

    # Forward pass takes the square root; the "inverse" here only rounds,
    # so the round trip equals the rounded square roots.
    round_trip = transformer.inverse_transform(transformer.transform(squares))
    expected = np.around(np.sqrt(squares), decimals=3)
    assert_array_equal(round_trip, expected)
def test_check_inverse():
    """Exercise ``check_inverse=True`` on dense, CSR and CSC input.

    For each input kind: fitting with a non-inverse function pair must emit
    the UserWarning, while a true inverse pair must fit silently and
    round-trip the data. Finally, no warning is expected when either
    ``func`` or ``inverse_func`` is missing.
    """
    X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2))
    X_list = [X_dense,
              sparse.csr_matrix(X_dense),
              sparse.csc_matrix(X_dense)]

    # ``pytest.warns(match=...)`` performs a regex search, so the literal
    # dots of the expected message are escaped.
    not_inverse_msg = (r"The provided functions are not strictly"
                       r" inverse of each other\. If you are sure you"
                       r" want to proceed regardless, set"
                       r" 'check_inverse=False'\.")

    for X in X_list:
        # Sparse inputs are only accepted when explicitly allowed.
        accept_sparse = sparse.issparse(X)

        # sqrt / around are not inverses of each other -> warning expected.
        trans = FunctionTransformer(func=np.sqrt,
                                    inverse_func=np.around,
                                    accept_sparse=accept_sparse,
                                    check_inverse=True,
                                    validate=True)
        with pytest.warns(UserWarning, match=not_inverse_msg):
            trans.fit(X)

        # expm1 / log1p are true inverses -> silent fit and exact round trip.
        trans = FunctionTransformer(func=np.expm1,
                                    inverse_func=np.log1p,
                                    accept_sparse=accept_sparse,
                                    check_inverse=True,
                                    validate=True)
        Xt = assert_no_warnings(trans.fit_transform, X)
        assert_allclose_dense_sparse(X, trans.inverse_transform(Xt))

    # check that we don't check inverse when one of the func or inverse is not
    # provided.
    trans = FunctionTransformer(func=np.expm1, inverse_func=None,
                                check_inverse=True, validate=True)
    assert_no_warnings(trans.fit, X_dense)
    trans = FunctionTransformer(func=None, inverse_func=np.expm1,
                                check_inverse=True, validate=True)
    assert_no_warnings(trans.fit, X_dense)
def test_function_transformer_frame():
    """A default FunctionTransformer returns a DataFrame-like object."""
    # Skip the test when pandas is not installed.
    pd = pytest.importorskip('pandas')

    frame = pd.DataFrame(np.random.randn(100, 10))
    result = FunctionTransformer().fit_transform(frame)

    # ``loc`` is used as the marker that the output is still a DataFrame.
    assert hasattr(result, 'loc')