190 lines
6.6 KiB
Python
190 lines
6.6 KiB
Python
|
import numpy
|
||
|
from numpy.testing import assert_array_equal
|
||
|
import pytest
|
||
|
|
||
|
from sklearn.base import BaseEstimator
|
||
|
from sklearn.utils._array_api import get_namespace
|
||
|
from sklearn.utils._array_api import _NumPyApiWrapper
|
||
|
from sklearn.utils._array_api import _ArrayAPIWrapper
|
||
|
from sklearn.utils._array_api import _asarray_with_order
|
||
|
from sklearn.utils._array_api import _convert_to_numpy
|
||
|
from sklearn.utils._array_api import _estimator_with_converted_arrays
|
||
|
from sklearn._config import config_context
|
||
|
|
||
|
# Tests in this module import ``numpy.array_api``; silence, module-wide, the
# UserWarning whose message starts with "The numpy.array_api submodule".
_ARRAY_API_WARNING_FILTER = "ignore:The numpy.array_api submodule:UserWarning"
pytestmark = pytest.mark.filterwarnings(_ARRAY_API_WARNING_FILTER)
|
||
|
|
||
|
|
||
|
def test_get_namespace_ndarray():
    """Test get_namespace on NumPy ndarrays.

    A plain ``numpy.ndarray`` must resolve to the ``_NumPyApiWrapper`` with
    ``is_array_api`` False, whatever the ``array_api_dispatch`` setting.
    """
    pytest.importorskip("numpy.array_api")

    X_np = numpy.asarray([[1, 2, 3]])

    # Dispatching on NumPy regardless of the value of array_api_dispatch.
    for array_api_dispatch in [True, False]:
        with config_context(array_api_dispatch=array_api_dispatch):
            xp_out, is_array_api = get_namespace(X_np)
            assert not is_array_api
            assert isinstance(xp_out, _NumPyApiWrapper)
|
||
|
|
||
|
|
||
|
def test_get_namespace_array_api():
    """Check get_namespace on Array API inputs, including its error paths."""
    xp = pytest.importorskip("numpy.array_api")

    X_np = numpy.asarray([[1, 2, 3]])
    X_xp = xp.asarray(X_np)

    with config_context(array_api_dispatch=True):
        namespace, is_array_api = get_namespace(X_xp)
        assert is_array_api
        assert isinstance(namespace, _ArrayAPIWrapper)

        # Arrays coming from two different namespaces are rejected.
        with pytest.raises(ValueError, match="Multiple namespaces"):
            get_namespace(X_np, X_xp)

        # Non-array inputs are rejected as well.
        with pytest.raises(ValueError, match="Unrecognized array input"):
            get_namespace(1)
|
||
|
|
||
|
|
||
|
class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
    """Test-only ``_ArrayAPIWrapper`` whose ``__name__`` is set by the caller.

    Parameters
    ----------
    array_namespace : module
        Namespace to wrap (the tests in this module pass ``numpy.array_api``).
    name : str
        Value stored as the instance's ``__name__``. NOTE(review): presumably
        the code under test inspects ``__name__`` to pick a code path — confirm.
    """

    def __init__(self, array_namespace, name):
        _ArrayAPIWrapper.__init__(self, array_namespace=array_namespace)
        self.__name__ = name
|
||
|
|
||
|
|
||
|
def test_array_api_wrapper_astype():
    """Test _ArrayAPIWrapper for ArrayAPIs that are not NumPy."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    # Wrap under a name other than "numpy.array_api" so the namespace is
    # handled as a generic Array API namespace.
    xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api")
    xp = _ArrayAPIWrapper(xp_)

    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)

    # Explicit dtype conversion through astype.
    X_converted = xp.astype(X, xp.float32)
    assert X_converted.dtype == xp.float32

    # Dtype conversion through asarray with an explicit dtype.
    X_converted = xp.asarray(X, dtype=xp.float32)
    assert X_converted.dtype == xp.float32
|
||
|
|
||
|
|
||
|
def test_array_api_wrapper_take_for_numpy_api():
    """Test that fast path is called for numpy.array_api."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    # Use the same name as numpy.array_api so the wrapper takes its fast path.
    xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "numpy.array_api")
    xp = _ArrayAPIWrapper(xp_)

    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    X_take = xp.take(X, xp.asarray([1]), axis=0)
    # The result must remain an Array API array and agree with numpy.take.
    assert hasattr(X_take, "__array_namespace__")
    assert_array_equal(X_take, numpy.take(X, [1], axis=0))
|
||
|
|
||
|
|
||
|
def test_array_api_wrapper_take():
    """Check _ArrayAPIWrapper.take against numpy.take for a generic namespace."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api")
    xp = _ArrayAPIWrapper(xp_)

    def _check_take(array, indices, axis):
        # The result must stay an Array API array and match NumPy's take.
        taken = xp.take(array, xp.asarray(indices), axis=axis)
        assert hasattr(taken, "__array_namespace__")
        assert_array_equal(taken, numpy.take(array, indices, axis=axis))

    # axis=0 on 1D and 2D inputs
    _check_take(xp.asarray([1, 2, 3], dtype=xp.float64), [1], axis=0)
    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    _check_take(X, [0], axis=0)

    # axis=1 on a 2D input
    _check_take(X, [0, 2], axis=1)

    # Unsupported axis values and dimensionalities are rejected.
    with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"):
        xp.take(X, xp.asarray([0]), axis=2)

    with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"):
        xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("is_array_api", [True, False])
def test_asarray_with_order(is_array_api):
    """Check that _asarray_with_order honours order="F" for NumPy-backed input."""
    xp = pytest.importorskip("numpy.array_api") if is_array_api else numpy

    X_new = _asarray_with_order(xp.asarray([1.2, 3.4, 5.1]), order="F")

    assert numpy.asarray(X_new).flags["F_CONTIGUOUS"]
|
||
|
|
||
|
|
||
|
def test_asarray_with_order_ignored():
    """Check that _asarray_with_order ignores order for a generic Array API."""
    xp = pytest.importorskip("numpy.array_api")
    xp_ = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api")

    X_c = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C")
    X = xp_.asarray(X_c)

    X_new_np = numpy.asarray(_asarray_with_order(X, order="F", xp=xp_))

    # The requested Fortran layout is not applied: the data stays C-ordered.
    assert X_new_np.flags["C_CONTIGUOUS"]
    assert not X_new_np.flags["F_CONTIGUOUS"]
|
||
|
|
||
|
|
||
|
def test_convert_to_numpy_error():
    """Check that _convert_to_numpy rejects a namespace it does not support."""
    wrapped_xp = _AdjustableNameAPITestWrapper(
        pytest.importorskip("numpy.array_api"), "wrapped.array_api"
    )
    X = wrapped_xp.asarray([1.2, 3.4])

    with pytest.raises(ValueError, match="Supported namespaces are:"):
        _convert_to_numpy(X, xp=wrapped_xp)
|
||
|
|
||
|
|
||
|
class SimpleEstimator(BaseEstimator):
    """Minimal estimator that stores its training data as a fitted attribute.

    Used in this module to check that ``_estimator_with_converted_arrays``
    converts array-valued fitted attributes such as ``X_``.
    """

    def fit(self, X, y=None):
        """Store ``X`` and its number of features.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Training data, kept as-is in ``X_``. Must be 2D.
        y : None
            Ignored.

        Returns
        -------
        self : SimpleEstimator
            The fitted estimator.
        """
        self.X_ = X
        # Features live on the second axis; shape[0] is the number of samples.
        self.n_features_ = X.shape[1]
        return self
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_convert_estimator_to_ndarray(array_namespace):
    """Convert estimator attributes to ndarray."""
    xp = pytest.importorskip(array_namespace)

    # Named functions instead of lambda assignments (PEP 8 E731).
    if array_namespace == "numpy.array_api":

        def converter(array):
            return numpy.asarray(array)

    else:  # pragma: no cover

        def converter(array):
            # ``_array`` is the wrapped CuPy array; ``.get()`` copies it to a
            # host NumPy array.
            return array._array.get()

    X = xp.asarray([[1.3, 4.5]])
    est = SimpleEstimator().fit(X)

    new_est = _estimator_with_converted_arrays(est, converter)
    assert isinstance(new_est.X_, numpy.ndarray)
|
||
|
|
||
|
|
||
|
def test_convert_estimator_to_array_api():
    """Check that NumPy fitted attributes can be converted to Array API arrays."""
    xp = pytest.importorskip("numpy.array_api")

    def to_array_api(array):
        return xp.asarray(array)

    fitted = SimpleEstimator().fit(numpy.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(fitted, to_array_api)

    assert hasattr(converted.X_, "__array_namespace__")
|