233 lines
7.6 KiB
Python
233 lines
7.6 KiB
Python
"""Tools to support array_api."""
|
|
import numpy
|
|
from .._config import get_config
|
|
import scipy.special as special
|
|
|
|
|
|
class _ArrayAPIWrapper:
|
|
"""sklearn specific Array API compatibility wrapper
|
|
|
|
This wrapper makes it possible for scikit-learn maintainers to
|
|
deal with discrepancies between different implementations of the
|
|
Python array API standard and its evolution over time.
|
|
|
|
The Python array API standard specification:
|
|
https://data-apis.org/array-api/latest/
|
|
|
|
Documentation of the NumPy implementation:
|
|
https://numpy.org/neps/nep-0047-array-api-standard.html
|
|
"""
|
|
|
|
def __init__(self, array_namespace):
|
|
self._namespace = array_namespace
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(self._namespace, name)
|
|
|
|
def take(self, X, indices, *, axis):
|
|
# When array_api supports `take` we can use this directly
|
|
# https://github.com/data-apis/array-api/issues/177
|
|
if self._namespace.__name__ == "numpy.array_api":
|
|
X_np = numpy.take(X, indices, axis=axis)
|
|
return self._namespace.asarray(X_np)
|
|
|
|
# We only support axis in (0, 1) and ndim in (1, 2) because that is all we need
|
|
# in scikit-learn
|
|
if axis not in {0, 1}:
|
|
raise ValueError(f"Only axis in (0, 1) is supported. Got {axis}")
|
|
|
|
if X.ndim not in {1, 2}:
|
|
raise ValueError(f"Only X.ndim in (1, 2) is supported. Got {X.ndim}")
|
|
|
|
if axis == 0:
|
|
if X.ndim == 1:
|
|
selected = [X[i] for i in indices]
|
|
else: # X.ndim == 2
|
|
selected = [X[i, :] for i in indices]
|
|
else: # axis == 1
|
|
selected = [X[:, i] for i in indices]
|
|
return self._namespace.stack(selected, axis=axis)
|
|
|
|
|
|
class _NumPyApiWrapper:
|
|
"""Array API compat wrapper for any numpy version
|
|
|
|
NumPy < 1.22 does not expose the numpy.array_api namespace. This
|
|
wrapper makes it possible to write code that uses the standard
|
|
Array API while working with any version of NumPy supported by
|
|
scikit-learn.
|
|
|
|
See the `get_namespace()` public function for more details.
|
|
"""
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(numpy, name)
|
|
|
|
def astype(self, x, dtype, *, copy=True, casting="unsafe"):
|
|
# astype is not defined in the top level NumPy namespace
|
|
return x.astype(dtype, copy=copy, casting=casting)
|
|
|
|
def asarray(self, x, *, dtype=None, device=None, copy=None):
|
|
# Support copy in NumPy namespace
|
|
if copy is True:
|
|
return numpy.array(x, copy=True, dtype=dtype)
|
|
else:
|
|
return numpy.asarray(x, dtype=dtype)
|
|
|
|
def unique_inverse(self, x):
|
|
return numpy.unique(x, return_inverse=True)
|
|
|
|
def unique_counts(self, x):
|
|
return numpy.unique(x, return_counts=True)
|
|
|
|
def unique_values(self, x):
|
|
return numpy.unique(x)
|
|
|
|
def concat(self, arrays, *, axis=None):
|
|
return numpy.concatenate(arrays, axis=axis)
|
|
|
|
|
|
def get_namespace(*arrays):
|
|
"""Get namespace of arrays.
|
|
|
|
Introspect `arrays` arguments and return their common Array API
|
|
compatible namespace object, if any. NumPy 1.22 and later can
|
|
construct such containers using the `numpy.array_api` namespace
|
|
for instance.
|
|
|
|
See: https://numpy.org/neps/nep-0047-array-api-standard.html
|
|
|
|
If `arrays` are regular numpy arrays, an instance of the
|
|
`_NumPyApiWrapper` compatibility wrapper is returned instead.
|
|
|
|
Namespace support is not enabled by default. To enabled it
|
|
call:
|
|
|
|
sklearn.set_config(array_api_dispatch=True)
|
|
|
|
or:
|
|
|
|
with sklearn.config_context(array_api_dispatch=True):
|
|
# your code here
|
|
|
|
Otherwise an instance of the `_NumPyApiWrapper`
|
|
compatibility wrapper is always returned irrespective of
|
|
the fact that arrays implement the `__array_namespace__`
|
|
protocol or not.
|
|
|
|
Parameters
|
|
----------
|
|
*arrays : array objects
|
|
Array objects.
|
|
|
|
Returns
|
|
-------
|
|
namespace : module
|
|
Namespace shared by array objects.
|
|
|
|
is_array_api : bool
|
|
True of the arrays are containers that implement the Array API spec.
|
|
"""
|
|
# `arrays` contains one or more arrays, or possibly Python scalars (accepting
|
|
# those is a matter of taste, but doesn't seem unreasonable).
|
|
# Returns a tuple: (array_namespace, is_array_api)
|
|
|
|
if not get_config()["array_api_dispatch"]:
|
|
return _NumPyApiWrapper(), False
|
|
|
|
namespaces = {
|
|
x.__array_namespace__() if hasattr(x, "__array_namespace__") else None
|
|
for x in arrays
|
|
if not isinstance(x, (bool, int, float, complex))
|
|
}
|
|
|
|
if not namespaces:
|
|
# one could special-case np.ndarray above or use np.asarray here if
|
|
# older numpy versions need to be supported.
|
|
raise ValueError("Unrecognized array input")
|
|
|
|
if len(namespaces) != 1:
|
|
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
|
|
|
|
(xp,) = namespaces
|
|
if xp is None:
|
|
# Use numpy as default
|
|
return _NumPyApiWrapper(), False
|
|
|
|
return _ArrayAPIWrapper(xp), True
|
|
|
|
|
|
def _expit(X):
|
|
xp, _ = get_namespace(X)
|
|
if xp.__name__ in {"numpy", "numpy.array_api"}:
|
|
return xp.asarray(special.expit(numpy.asarray(X)))
|
|
|
|
return 1.0 / (1.0 + xp.exp(-X))
|
|
|
|
|
|
def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None):
|
|
"""Helper to support the order kwarg only for NumPy-backed arrays
|
|
|
|
Memory layout parameter `order` is not exposed in the Array API standard,
|
|
however some input validation code in scikit-learn needs to work both
|
|
for classes and functions that will leverage Array API only operations
|
|
and for code that inherently relies on NumPy backed data containers with
|
|
specific memory layout constraints (e.g. our own Cython code). The
|
|
purpose of this helper is to make it possible to share code for data
|
|
container validation without memory copies for both downstream use cases:
|
|
the `order` parameter is only enforced if the input array implementation
|
|
is NumPy based, otherwise `order` is just silently ignored.
|
|
"""
|
|
if xp is None:
|
|
xp, _ = get_namespace(array)
|
|
if xp.__name__ in {"numpy", "numpy.array_api"}:
|
|
# Use NumPy API to support order
|
|
array = numpy.asarray(array, order=order, dtype=dtype)
|
|
return xp.asarray(array, copy=copy)
|
|
else:
|
|
return xp.asarray(array, dtype=dtype, copy=copy)
|
|
|
|
|
|
def _convert_to_numpy(array, xp):
|
|
"""Convert X into a NumPy ndarray.
|
|
|
|
Only works on cupy.array_api and numpy.array_api and is used for testing.
|
|
"""
|
|
supported_array_api = ["numpy.array_api", "cupy.array_api"]
|
|
if xp.__name__ not in supported_array_api:
|
|
support_array_api_str = ", ".join(supported_array_api)
|
|
raise ValueError(f"Supported namespaces are: {support_array_api_str}")
|
|
|
|
if xp.__name__ == "cupy.array_api":
|
|
return array._array.get()
|
|
else:
|
|
return numpy.asarray(array)
|
|
|
|
|
|
def _estimator_with_converted_arrays(estimator, converter):
|
|
"""Create new estimator which converting all attributes that are arrays.
|
|
|
|
Parameters
|
|
----------
|
|
estimator : Estimator
|
|
Estimator to convert
|
|
|
|
converter : callable
|
|
Callable that takes an array attribute and returns the converted array.
|
|
|
|
Returns
|
|
-------
|
|
new_estimator : Estimator
|
|
Convert estimator
|
|
"""
|
|
from sklearn.base import clone
|
|
|
|
new_estimator = clone(estimator)
|
|
for key, attribute in vars(estimator).items():
|
|
if hasattr(attribute, "__array_namespace__") or isinstance(
|
|
attribute, numpy.ndarray
|
|
):
|
|
attribute = converter(attribute)
|
|
setattr(new_estimator, key, attribute)
|
|
return new_estimator
|