Inzynierka/Lib/site-packages/sklearn/utils/tests/test_array_api.py
2023-06-02 12:51:02 +02:00

190 lines
6.6 KiB
Python

import numpy
from numpy.testing import assert_array_equal
import pytest
from sklearn.base import BaseEstimator
from sklearn.utils._array_api import get_namespace
from sklearn.utils._array_api import _NumPyApiWrapper
from sklearn.utils._array_api import _ArrayAPIWrapper
from sklearn.utils._array_api import _asarray_with_order
from sklearn.utils._array_api import _convert_to_numpy
from sklearn.utils._array_api import _estimator_with_converted_arrays
from sklearn._config import config_context
# Module-level pytest mark: ignore the UserWarning whose message starts with
# "The numpy.array_api submodule" in the tests of this module.
pytestmark = pytest.mark.filterwarnings(
    "ignore:The numpy.array_api submodule:UserWarning"
)
def test_get_namespace_ndarray():
    """Check that NumPy ndarrays resolve to the NumPy wrapper namespace."""
    pytest.importorskip("numpy.array_api")

    ndarray_input = numpy.asarray([[1, 2, 3]])

    # NumPy input is dispatched on NumPy regardless of the value of
    # array_api_dispatch, so both settings must give the same answer.
    for dispatch_flag in (True, False):
        with config_context(array_api_dispatch=dispatch_flag):
            namespace, uses_array_api = get_namespace(ndarray_input)
            assert not uses_array_api
            assert isinstance(namespace, _NumPyApiWrapper)
def test_get_namespace_array_api():
    """Check get_namespace on Array API arrays, including its error cases."""
    xp = pytest.importorskip("numpy.array_api")

    ndarray_input = numpy.asarray([[1, 2, 3]])
    array_api_input = xp.asarray(ndarray_input)

    with config_context(array_api_dispatch=True):
        namespace, uses_array_api = get_namespace(array_api_input)
        assert uses_array_api
        assert isinstance(namespace, _ArrayAPIWrapper)

        # Error cases are exercised while dispatch is still enabled.
        # Mixing an ndarray with an Array API array must be rejected.
        with pytest.raises(ValueError, match="Multiple namespaces"):
            get_namespace(ndarray_input, array_api_input)

        # A plain Python scalar is not an array input.
        with pytest.raises(ValueError, match="Unrecognized array input"):
            get_namespace(1)
class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
    """``_ArrayAPIWrapper`` subclass whose ``__name__`` is chosen by the caller.

    Test-only helper: it lets a test present the same underlying namespace
    under an arbitrary name.

    Parameters
    ----------
    array_namespace : namespace object
        Forwarded unchanged to ``_ArrayAPIWrapper.__init__``.
    name : str
        Value stored on the instance as ``__name__``.
    """

    def __init__(self, array_namespace, name):
        super().__init__(array_namespace=array_namespace)
        # Set after the parent initializer so the caller's name is the one kept.
        self.__name__ = name
def test_array_api_wrapper_astype():
    """Check dtype conversion through _ArrayAPIWrapper for a non-NumPy name."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    renamed = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api")
    xp = _ArrayAPIWrapper(renamed)

    source = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)

    # Conversion through the dedicated ``astype`` entry point.
    via_astype = xp.astype(source, xp.float32)
    assert via_astype.dtype == xp.float32

    # Conversion through ``asarray`` with an explicit dtype.
    via_asarray = xp.asarray(source, dtype=xp.float32)
    assert via_asarray.dtype == xp.float32
def test_array_api_wrapper_take_for_numpy_api():
    """Check ``take`` when the wrapped namespace is named ``numpy.array_api``."""
    numpy_array_api = pytest.importorskip("numpy.array_api")

    # Use the same name as numpy.array_api.
    # NOTE(review): the wrapper presumably picks its fast path from this
    # name; the assertions below only verify the result, not the path taken.
    named_like_numpy = _AdjustableNameAPITestWrapper(numpy_array_api, "numpy.array_api")
    xp = _ArrayAPIWrapper(named_like_numpy)

    matrix = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    taken = xp.take(matrix, xp.asarray([1]), axis=0)

    assert hasattr(taken, "__array_namespace__")
    assert_array_equal(taken, numpy.take(matrix, [1], axis=0))
def test_array_api_wrapper_take():
    """Check ``_ArrayAPIWrapper.take`` against ``numpy.take`` and its errors."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    renamed = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api")
    xp = _ArrayAPIWrapper(renamed)

    vector = xp.asarray([1, 2, 3], dtype=xp.float64)
    matrix = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)

    # (data, indices, axis): 1-D along axis 0, 2-D along axis 0, 2-D along axis 1.
    supported_cases = (
        (vector, [1], 0),
        (matrix, [0], 0),
        (matrix, [0, 2], 1),
    )
    for data, positions, axis in supported_cases:
        taken = xp.take(data, xp.asarray(positions), axis=axis)
        # The result must remain an Array API array and match NumPy's take.
        assert hasattr(taken, "__array_namespace__")
        assert_array_equal(taken, numpy.take(data, positions, axis=axis))

    # An axis other than 0 or 1 is rejected.
    with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"):
        xp.take(matrix, xp.asarray([0]), axis=2)

    # A 3-D input is rejected.
    with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"):
        xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0)
@pytest.mark.parametrize("is_array_api", [True, False])
def test_asarray_with_order(is_array_api):
    """Check that ``_asarray_with_order`` honours ``order="F"``.

    Runs once with ``numpy.array_api`` (skipped when unavailable) and once
    with plain ``numpy`` as the array namespace.
    """
    xp = pytest.importorskip("numpy.array_api") if is_array_api else numpy

    values = xp.asarray([1.2, 3.4, 5.1])
    reordered = _asarray_with_order(values, order="F")

    # Inspect the memory layout flags through a NumPy view of the result.
    assert numpy.asarray(reordered).flags["F_CONTIGUOUS"]
def test_asarray_with_order_ignored():
    """Check that ``order`` is ignored for a namespace with a generic name."""
    xp = pytest.importorskip("numpy.array_api")
    generic_xp = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api")

    c_ordered = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C")
    wrapped_input = generic_xp.asarray(c_ordered)

    result = _asarray_with_order(wrapped_input, order="F", xp=generic_xp)

    # The requested "F" order must not be applied: the data stays C-ordered.
    layout = numpy.asarray(result).flags
    assert layout["C_CONTIGUOUS"]
    assert not layout["F_CONTIGUOUS"]
def test_convert_to_numpy_error():
    """Check that ``_convert_to_numpy`` rejects a namespace it does not know."""
    xp = pytest.importorskip("numpy.array_api")
    unknown_xp = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api")

    values = unknown_xp.asarray([1.2, 3.4])

    # The namespace is presented under the name "wrapped.array_api".
    with pytest.raises(ValueError, match="Supported namespaces are:"):
        _convert_to_numpy(values, xp=unknown_xp)
class SimpleEstimator(BaseEstimator):
    """Minimal estimator whose ``fit`` only records attributes on itself.

    Used as a fixture by the conversion tests in this module: after ``fit``
    it carries one array attribute (``X_``) and one plain-integer attribute
    (``n_features_``).
    """

    def fit(self, X, y=None):
        """Store ``X`` and its feature count, then return ``self``.

        Parameters
        ----------
        X : array exposing a ``shape`` with at least two entries
            Stored unchanged as ``X_``.
        y : object, default=None
            Accepted but never read.

        Returns
        -------
        self : SimpleEstimator
            The same instance, to allow call chaining.
        """
        self.X_ = X
        # scikit-learn data is laid out as (n_samples, n_features), so the
        # feature count is the second axis, not the first.
        self.n_features_ = X.shape[1]
        return self
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_convert_estimator_to_ndarray(array_namespace):
    """Check that fitted array attributes can be converted to ndarrays.

    Each parametrized namespace is skipped when it cannot be imported.
    """
    xp = pytest.importorskip(array_namespace)

    if array_namespace == "numpy.array_api":

        def to_ndarray(array):
            return numpy.asarray(array)

    else:  # pragma: no cover

        def to_ndarray(array):
            # NOTE(review): presumably ``_array`` is the wrapped cupy array
            # and ``get()`` copies it to host memory; confirm with cupy docs.
            return array._array.get()

    fitted = SimpleEstimator().fit(xp.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(fitted, to_ndarray)

    assert isinstance(converted.X_, numpy.ndarray)
def test_convert_estimator_to_array_api():
    """Check that fitted ndarray attributes can be converted to Array API arrays."""
    xp = pytest.importorskip("numpy.array_api")

    fitted = SimpleEstimator().fit(numpy.asarray([[1.3, 4.5]]))

    def to_array_api(array):
        return xp.asarray(array)

    converted = _estimator_with_converted_arrays(fitted, to_array_api)

    # Array API arrays expose ``__array_namespace__``.
    assert hasattr(converted.X_, "__array_namespace__")