161 lines
5.3 KiB
Python
161 lines
5.3 KiB
Python
import re

import numpy as np
import pytest
from scipy import sparse

from sklearn.preprocessing import FunctionTransformer
from sklearn.utils._testing import (assert_array_equal,
                                    assert_allclose_dense_sparse)
from sklearn.utils._testing import assert_warns_message, assert_no_warnings
|
|
|
|
|
|
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
    """Build a callable that records how it was called, then delegates.

    Parameters
    ----------
    args_store : list
        On every call, extended with ``X`` followed by any extra
        positional arguments.
    kwargs_store : dict
        On every call, updated with any keyword arguments.
    func : callable, optional
        Invoked as ``func(X)`` after recording; its result is returned.
        The default hands ``X`` back unchanged.

    Returns
    -------
    callable
        The recording wrapper.
    """
    def _recording_func(X, *args, **kwargs):
        # X first, then the extra positionals, in call order.
        args_store.extend((X,) + args)
        kwargs_store.update(kwargs)
        # Only X is forwarded; extra arguments are recorded, not passed on.
        return func(X)

    return _recording_func
|
|
|
|
|
|
def test_delegate_to_func():
    """FunctionTransformer must call ``func`` with X alone and no kwargs."""
    # (args|kwargs)_store will hold the positional and keyword arguments
    # passed to the function inside the FunctionTransformer.
    args_store = []
    kwargs_store = {}
    X = np.arange(10).reshape((5, 2))

    # Check both the plain ``transform`` path and the ``fit_transform`` path,
    # so the second round is not a verbatim repeat of the first.
    for method in ('transform', 'fit_transform'):
        # reset the argument stores.
        args_store[:] = []
        kwargs_store.clear()

        transformer = FunctionTransformer(
            _make_func(args_store, kwargs_store))
        transformed = getattr(transformer, method)(X)

        assert_array_equal(
            transformed, X,
            err_msg='{} should have returned X unchanged'.format(method))

        # The function should only have received X. Compare element-wise:
        # ``args_store == [X]`` only works through the list identity
        # short-cut and raises "truth value of an array is ambiguous" when
        # the stored array is an equal but distinct object.
        assert len(args_store) == 1, ('Incorrect positional arguments passed '
                                      'to func: {args}'.format(args=args_store))
        assert_array_equal(args_store[0], X)

        assert not kwargs_store, ('Unexpected keyword arguments passed to '
                                  'func: {args}'.format(args=kwargs_store))
|
|
|
|
|
|
def test_np_log():
    """Wrapping ``np.log1p`` matches calling ``np.log1p`` directly."""
    data = np.arange(10).reshape((5, 2))
    expected = np.log1p(data)
    result = FunctionTransformer(np.log1p).transform(data)
    assert_array_equal(result, expected)
|
|
|
|
|
|
def test_kw_arg():
    """``kw_args`` given at construction are forwarded to ``func``."""
    X = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Rounding through the transformer must equal rounding directly.
    expected = np.around(X, decimals=3)
    assert_array_equal(rounder.transform(X), expected)
|
|
|
|
|
|
def test_kw_arg_update():
    """Mutating the ``kw_args`` dict in place is seen by ``transform``."""
    X = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Edit the existing dict rather than assigning a new one.
    rounder.kw_args['decimals'] = 1

    expected = np.around(X, decimals=1)
    assert_array_equal(rounder.transform(X), expected)
|
|
|
|
|
|
def test_kw_arg_reset():
    """Replacing ``kw_args`` with a new dict is seen by ``transform``."""
    X = np.linspace(0, 1, num=10).reshape((5, 2))
    rounder = FunctionTransformer(np.around, kw_args={'decimals': 3})

    # Rebind the attribute to a brand-new dict.
    rounder.kw_args = {'decimals': 1}

    expected = np.around(X, decimals=1)
    assert_array_equal(rounder.transform(X), expected)
|
|
|
|
|
|
def test_inverse_transform():
    """``inverse_transform`` applies ``inverse_func`` with ``inv_kw_args``."""
    X = np.array([1, 4, 9, 16]).reshape((2, 2))
    transformer = FunctionTransformer(
        func=np.sqrt,
        inverse_func=np.around,
        inv_kw_args={'decimals': 3},
    )

    forward = transformer.transform(X)
    round_trip = transformer.inverse_transform(forward)
    assert_array_equal(round_trip, np.around(np.sqrt(X), decimals=3))
|
|
|
|
|
|
def test_check_inverse():
    """``check_inverse=True`` warns only when func/inverse_func disagree."""
    X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2))
    X_list = [X_dense,
              sparse.csr_matrix(X_dense),
              sparse.csc_matrix(X_dense)]

    # Loop-invariant. Escaped so pytest matches it literally: the message
    # contains regex metacharacters such as '.'.
    warning_message = re.escape(
        "The provided functions are not strictly"
        " inverse of each other. If you are sure you"
        " want to proceed regardless, set"
        " 'check_inverse=False'.")

    for X in X_list:
        # Allow sparse input only for the sparse variants of X.
        accept_sparse = sparse.issparse(X)

        # sqrt maps [[1, 4], [9, 16]] to [[1, 2], [3, 4]], which around
        # leaves unchanged, so the round-trip fails and fit must warn.
        trans = FunctionTransformer(func=np.sqrt,
                                    inverse_func=np.around,
                                    accept_sparse=accept_sparse,
                                    check_inverse=True,
                                    validate=True)
        with pytest.warns(UserWarning, match=warning_message):
            trans.fit(X)

        # expm1 / log1p round-trip: no warning, and the inverse recovers X.
        trans = FunctionTransformer(func=np.expm1,
                                    inverse_func=np.log1p,
                                    accept_sparse=accept_sparse,
                                    check_inverse=True,
                                    validate=True)
        Xt = assert_no_warnings(trans.fit_transform, X)
        assert_allclose_dense_sparse(X, trans.inverse_transform(Xt))

    # check that we don't check inverse when one of the func or inverse is not
    # provided.
    trans = FunctionTransformer(func=np.expm1, inverse_func=None,
                                check_inverse=True, validate=True)
    assert_no_warnings(trans.fit, X_dense)
    trans = FunctionTransformer(func=None, inverse_func=np.expm1,
                                check_inverse=True, validate=True)
    assert_no_warnings(trans.fit, X_dense)
|
|
|
|
|
|
def test_function_transformer_frame():
    """A FunctionTransformer built with no ``func`` hands back a pandas-like
    object (one exposing ``.loc``) when fed a DataFrame."""
    pd = pytest.importorskip('pandas')
    # Seeded local RNG: reproducible data, and numpy's global random state
    # is left untouched for other tests.
    rng = np.random.RandomState(0)
    X_df = pd.DataFrame(rng.randn(100, 10))
    transformer = FunctionTransformer()
    X_df_trans = transformer.fit_transform(X_df)
    assert hasattr(X_df_trans, 'loc')
|