3RNN/Lib/site-packages/sklearn/utils/__init__.py
2024-05-26 19:49:15 +02:00

123 lines
3.2 KiB
Python

"""
The :mod:`sklearn.utils` module includes various utilities.
"""
import warnings
from collections.abc import Sequence
import numpy as np
from ..exceptions import DataConversionWarning
from . import _joblib, metadata_routing
from ._bunch import Bunch
from ._chunking import gen_batches, gen_even_slices
from ._estimator_html_repr import estimator_html_repr
# Make _safe_indexing importable from here for backward compat as this particular
# helper is considered semi-private and typically very useful for third-party
# libraries that want to comply with scikit-learn's estimator API. In particular,
# _safe_indexing was included in our public API documentation despite the leading
# `_` in its name.
from ._indexing import (
_safe_indexing, # noqa
resample,
shuffle,
)
from ._mask import safe_mask
from .class_weight import compute_class_weight, compute_sample_weight
from .deprecation import deprecated
from .discovery import all_estimators
from .extmath import safe_sqr
from .murmurhash import murmurhash3_32
from .validation import (
as_float_array,
assert_all_finite,
check_array,
check_consistent_length,
check_random_state,
check_scalar,
check_symmetric,
check_X_y,
column_or_1d,
indexable,
)
# TODO(1.7): remove parallel_backend and register_parallel_backend
msg = "deprecated in 1.5 to be removed in 1.7. Use joblib.{} instead."
# `msg` is a template: the placeholder must be filled with the name of the joblib
# replacement, otherwise the emitted warning contains a literal "{}".
register_parallel_backend = deprecated(msg.format("register_parallel_backend"))(
    _joblib.register_parallel_backend
)


# When applied to a class, `deprecated` changes the class object itself, which
# would alter the object living in the `_joblib` module, so a subclass is
# decorated instead.
@deprecated(msg.format("parallel_backend"))
class parallel_backend(_joblib.parallel_backend):
    """Deprecated subclass of ``_joblib.parallel_backend``; adds no behavior."""
# Names bound by `from sklearn.utils import *`.
# `tosequence` (defined in this module) and `_safe_indexing` (imported above)
# are available as attributes of the module but are not listed here.
__all__ = [
"murmurhash3_32",
"as_float_array",
"assert_all_finite",
"check_array",
"check_random_state",
"compute_class_weight",
"compute_sample_weight",
"column_or_1d",
"check_consistent_length",
"check_X_y",
"check_scalar",
"indexable",
"check_symmetric",
"deprecated",
"parallel_backend",
"register_parallel_backend",
"resample",
"shuffle",
"all_estimators",
"DataConversionWarning",
"estimator_html_repr",
"Bunch",
"metadata_routing",
"safe_sqr",
"safe_mask",
"gen_batches",
"gen_even_slices",
]
# TODO(1.7): remove
def __getattr__(name):
    """Resolve module attributes that the normal lookup did not find.

    Only the deprecated ``IS_PYPY`` name is served; every other name is
    reported as missing.

    Parameters
    ----------
    name : str
        Name of the attribute requested on this module.

    Returns
    -------
    object
        The value of ``_IS_PYPY`` from ``.fixes`` when ``name`` is
        ``"IS_PYPY"``.

    Raises
    ------
    AttributeError
        If ``name`` is anything other than ``"IS_PYPY"``.
    """
    if name == "IS_PYPY":
        warnings.warn(
            "IS_PYPY is deprecated and will be removed in 1.7.",
            FutureWarning,
            # Attribute the warning to the code that accessed `IS_PYPY`
            # rather than to this module.
            stacklevel=2,
        )
        # Imported lazily, only once the deprecated name is actually requested.
        from .fixes import _IS_PYPY

        return _IS_PYPY
    raise AttributeError(f"module {__name__} has no attribute {name}")
# TODO(1.7): remove tosequence
@deprecated("tosequence was deprecated in 1.5 and will be removed in 1.7")
def tosequence(x):
    """Return `x` as a Sequence, copying only when that cannot be avoided.

    Parameters
    ----------
    x : iterable
        The iterable to convert.

    Returns
    -------
    x : Sequence
        For a NumPy array, the result of ``np.asarray(x)``. For any other
        `Sequence`, `x` itself. Any other iterable is materialized into a
        new list.
    """
    # Arrays are checked first and always go through `np.asarray`.
    if isinstance(x, np.ndarray):
        return np.asarray(x)
    return x if isinstance(x, Sequence) else list(x)