521 lines
17 KiB
Python
521 lines
17 KiB
Python
|
from functools import partial
|
||
|
|
||
|
import numpy as np
|
||
|
from numpy.testing import assert_array_equal
|
||
|
|
||
|
from sklearn.base import (
|
||
|
BaseEstimator,
|
||
|
ClassifierMixin,
|
||
|
MetaEstimatorMixin,
|
||
|
RegressorMixin,
|
||
|
TransformerMixin,
|
||
|
clone,
|
||
|
)
|
||
|
from sklearn.metrics._scorer import _Scorer, mean_squared_error
|
||
|
from sklearn.model_selection import BaseCrossValidator
|
||
|
from sklearn.model_selection._split import GroupsConsumerMixin
|
||
|
from sklearn.utils._metadata_requests import (
|
||
|
SIMPLE_METHODS,
|
||
|
)
|
||
|
from sklearn.utils.metadata_routing import (
|
||
|
MetadataRouter,
|
||
|
MethodMapping,
|
||
|
process_routing,
|
||
|
)
|
||
|
from sklearn.utils.multiclass import _check_partial_fit_first_call
|
||
|
|
||
|
|
||
|
def record_metadata(obj, method, record_default=True, **kwargs):
    """Store on ``obj`` the metadata that was passed to ``method``.

    The keyword arguments end up in ``obj._records[method]``, replacing any
    earlier record for that method. ``obj._records`` is created on first use.

    When ``record_default`` is False, keyword arguments whose value is the
    string ``"default"`` are dropped. This lets checks ignore keyword
    arguments whose default was never changed by the caller.
    """
    if not hasattr(obj, "_records"):
        obj._records = {}

    if record_default:
        recorded = kwargs
    else:
        # Drop only values that are exactly the "default" sentinel string.
        recorded = {
            name: value
            for name, value in kwargs.items()
            if not (isinstance(value, str) and value == "default")
        }

    obj._records[method] = recorded
|
||
|
|
||
|
|
||
|
def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
    """Assert that ``method`` of ``obj`` received exactly the expected metadata.

    Parameters
    ----------
    obj : estimator object
        Sub-estimator whose recorded metadata (``obj._records``) is checked.
        An object with no records is treated as having recorded nothing.

    method : str
        Name of the sub-estimator method whose recorded metadata is checked.

    split_params : tuple, default=()
        Names of parameters whose recorded value only needs to be a subset of
        the expected value (checked with ``np.isin``). If such a parameter was
        recorded as None, it is compared like any other parameter instead.

    **kwargs : dict
        Expected metadata, keyed by parameter name. The recorded names must
        match these names exactly. ndarray records are compared elementwise;
        any other recorded value must be the very same object (``is``).

    Raises
    ------
    AssertionError
        If the names or any of the values do not match.
    """
    per_method = getattr(obj, "_records", dict())
    records = per_method.get(method, dict())

    expected_names = set(kwargs.keys())
    recorded_names = set(records.keys())
    assert (
        expected_names == recorded_names
    ), f"Expected {kwargs.keys()} vs {records.keys()}"

    for name, expected in kwargs.items():
        received = records[name]
        # Split-dependent params (e.g. CV-sliced groups) only need to be a
        # subset of the full values passed by the caller.
        if name in split_params and received is not None:
            assert np.isin(received, expected).all()
            continue
        if isinstance(received, np.ndarray):
            assert_array_equal(received, expected)
        else:
            assert received is expected, f"Expected {received} vs {expected}"
|
||
|
|
||
|
|
||
|
# ``record_metadata`` with ``record_default=False``: keyword arguments still
# equal to the "default" sentinel string are not recorded, so only metadata
# that was explicitly passed shows up in ``obj._records``.
record_metadata_not_default = partial(record_metadata, record_default=False)
|
||
|
|
||
|
|
||
|
def assert_request_is_empty(metadata_request, exclude=None):
    """Assert that a metadata request does not request any metadata.

    A method counts as empty when every one of its requests is None.
    Methods listed in ``exclude`` are skipped. For a ``MetadataRouter``, each
    routed child is checked recursively, and ``exclude`` may then be a
    mapping of the form ``{"object": [method, ...]}`` giving per-child
    exclusions.
    """
    if isinstance(metadata_request, MetadataRouter):
        for name, route_mapping in metadata_request:
            child_exclude = (
                exclude[name] if exclude is not None and name in exclude else None
            )
            assert_request_is_empty(route_mapping.router, exclude=child_exclude)
        return

    excluded = exclude if exclude is not None else []
    for method in SIMPLE_METHODS:
        if method in excluded:
            continue
        requests = getattr(metadata_request, method).requests
        # Any non-None value (True, False or a str alias) is a set request.
        requested = [prop for prop, alias in requests.items() if alias is not None]
        assert not requested
|
||
|
|
||
|
|
||
|
def assert_request_equal(request, dictionary):
    """Assert that ``request`` holds exactly the requests in ``dictionary``.

    Every method named in ``dictionary`` must have exactly the given requests.
    Every method in ``SIMPLE_METHODS`` that is not named there must have no
    requests at all.
    """
    for method_name, expected in dictionary.items():
        assert getattr(request, method_name).requests == expected

    for method_name in SIMPLE_METHODS:
        if method_name in dictionary:
            continue
        assert len(getattr(request, method_name).requests) == 0
|
||
|
|
||
|
|
||
|
class _Registry(list):
|
||
|
# This list is used to get a reference to the sub-estimators, which are not
|
||
|
# necessarily stored on the metaestimator. We need to override __deepcopy__
|
||
|
# because the sub-estimators are probably cloned, which would result in a
|
||
|
# new copy of the list, but we need copy and deep copy both to return the
|
||
|
# same instance.
|
||
|
def __deepcopy__(self, memo):
|
||
|
return self
|
||
|
|
||
|
def __copy__(self):
|
||
|
return self
|
||
|
|
||
|
|
||
|
class ConsumingRegressor(RegressorMixin, BaseEstimator):
    """Regressor that records the metadata its methods receive.

    ``fit``, ``partial_fit``, ``predict`` and ``score`` each record the
    ``sample_weight`` and ``metadata`` they are given. Arguments left at the
    ``"default"`` sentinel string are not recorded.

    Parameters
    ----------
    registry : list, default=None
        When a list is given, ``fit`` and ``partial_fit`` append the estimator
        to it, so tests can get a reference to the instance later on. Leave
        as None when such a reference is not needed.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata(
            self,
            "fit",
            record_default=False,
            sample_weight=sample_weight,
            metadata=metadata,
        )
        return self

    def partial_fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata(
            self,
            "partial_fit",
            record_default=False,
            sample_weight=sample_weight,
            metadata=metadata,
        )
        return self

    def predict(self, X, y=None, sample_weight="default", metadata="default"):
        record_metadata(
            self,
            "predict",
            record_default=False,
            sample_weight=sample_weight,
            metadata=metadata,
        )
        n_samples = len(X)
        return np.zeros(shape=(n_samples,))

    def score(self, X, y, sample_weight="default", metadata="default"):
        record_metadata(
            self,
            "score",
            record_default=False,
            sample_weight=sample_weight,
            metadata=metadata,
        )
        return 1
|
||
|
|
||
|
|
||
|
class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
    """Classifier whose methods accept no metadata at all."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        return self

    def partial_fit(self, X, y, classes=None):
        return self

    def decision_function(self, X):
        # The fixed 0/1 predictions double as decision scores.
        return self.predict(X)

    def predict(self, X):
        # First half of the samples gets 0.0, the remaining half 1.0.
        n_samples = len(X)
        y_pred = np.ones(shape=(n_samples,))
        y_pred[: n_samples // 2] = 0
        return y_pred
|
||
|
|
||
|
|
||
|
class NonConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor which accepts no metadata on any method.

    ``fit`` and ``partial_fit`` do nothing and return the estimator, and
    ``predict`` returns an array of ones with one entry per sample.
    """

    def fit(self, X, y):
        return self

    def partial_fit(self, X, y):
        return self

    def predict(self, X):
        return np.ones(len(X))  # pragma: no cover
|
||
|
|
||
|
|
||
|
class ConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier consuming metadata.

    ``fit``, ``partial_fit``, ``predict``, ``predict_proba`` and
    ``decision_function`` each record the ``sample_weight`` and ``metadata``
    they receive, under their own method name. Arguments left at the
    ``"default"`` sentinel string are not recorded.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    """

    def __init__(self, registry=None, alpha=0.0):
        self.alpha = alpha
        self.registry = registry

    def partial_fit(
        self, X, y, classes=None, sample_weight="default", metadata="default"
    ):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
        )
        _check_partial_fit_first_call(self, classes)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "fit", sample_weight=sample_weight, metadata=metadata
        )

        self.classes_ = np.unique(y)
        return self

    def predict(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, "predict", sample_weight=sample_weight, metadata=metadata
        )
        # NOTE(review): this labels the first half as 1 and the second half
        # as 0, while predict_proba gives the first half probability 1.0 for
        # column 0. The two disagree. It is left as-is because tests here
        # only check routed metadata; confirm before relying on the outputs.
        y_score = np.empty(shape=(len(X),), dtype="int8")
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def predict_proba(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(
            self, "predict_proba", sample_weight=sample_weight, metadata=metadata
        )
        y_proba = np.empty(shape=(len(X), 2))
        y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
        y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
        return y_proba

    def predict_log_proba(self, X, sample_weight="default", metadata="default"):
        pass  # pragma: no cover

        # uncomment when needed
        # record_metadata_not_default(
        #     self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
        # )
        # return np.zeros(shape=(len(X), 2))

    def decision_function(self, X, sample_weight="default", metadata="default"):
        # Record under this method's own name. Recording it as
        # "predict_proba" made routed decision_function metadata look like
        # (and overwrite) predict_proba's records.
        record_metadata_not_default(
            self, "decision_function", sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),))
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    # uncomment when needed
    # def score(self, X, y, sample_weight="default", metadata="default"):
    #     record_metadata_not_default(
    #         self, "score", sample_weight=sample_weight, metadata=metadata
    #     )
    #     return 1
|
||
|
|
||
|
|
||
|
class ConsumingTransformer(TransformerMixin, BaseEstimator):
    """Transformer that records the metadata given to its fit/transform methods.

    ``fit``, ``transform``, ``fit_transform`` and ``inverse_transform`` each
    record the ``sample_weight`` and ``metadata`` they receive. ``transform``
    and ``inverse_transform`` return ``X`` unchanged.

    Parameters
    ----------
    registry : list, default=None
        When a list is given, ``fit`` appends the estimator to it, so tests
        can get a reference to the instance later on. Leave as None when such
        a reference is not needed.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata(
            self,
            "fit",
            record_default=False,
            sample_weight=sample_weight,
            metadata=metadata,
        )
        return self

    def transform(self, X, sample_weight=None, metadata=None):
        record_metadata(
            self, "transform", sample_weight=sample_weight, metadata=metadata
        )
        return X

    def fit_transform(self, X, y, sample_weight=None, metadata=None):
        # Defined explicitly because ``TransformerMixin.fit_transform`` routes
        # no metadata to ``transform``. Here ``transform`` must also receive
        # ``sample_weight`` and ``metadata``.
        record_metadata(
            self, "fit_transform", sample_weight=sample_weight, metadata=metadata
        )
        fitted = self.fit(X, y, sample_weight=sample_weight, metadata=metadata)
        return fitted.transform(X, sample_weight=sample_weight, metadata=metadata)

    def inverse_transform(self, X, sample_weight=None, metadata=None):
        record_metadata(
            self, "inverse_transform", sample_weight=sample_weight, metadata=metadata
        )
        return X
|
||
|
|
||
|
|
||
|
class ConsumingNoFitTransformTransformer(BaseEstimator):
    """Metadata-consuming transformer without a ``fit_transform`` method.

    It does not inherit from ``TransformerMixin``, so it has no
    ``fit_transform``. Note that ``TransformerMixin.fit_transform`` would not
    route metadata to ``transform`` anyway. ``fit`` and ``transform`` record
    the ``sample_weight`` and ``metadata`` they receive, and ``transform``
    returns ``X`` unchanged.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        registry = self.registry
        if registry is not None:
            registry.append(self)
        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
        return self

    def transform(self, X, sample_weight=None, metadata=None):
        record_metadata(
            self, "transform", sample_weight=sample_weight, metadata=metadata
        )
        return X
|
||
|
|
||
|
|
||
|
class ConsumingScorer(_Scorer):
    """Scorer that records the metadata it is called with.

    It computes ``mean_squared_error`` on the estimator's ``predict`` output
    with ``sign=1``. Only ``sample_weight`` (if present) is forwarded to the
    underlying scoring; all received metadata is recorded under ``"score"``.

    Parameters
    ----------
    registry : list, default=None
        When a list is given, every scoring call appends the scorer to it.
    """

    def __init__(self, registry=None):
        super().__init__(
            score_func=mean_squared_error,
            sign=1,
            kwargs={},
            response_method="predict",
        )
        self.registry = registry

    def _score(self, method_caller, clf, X, y, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, "score", **kwargs)

        return super()._score(
            method_caller, clf, X, y, sample_weight=kwargs.get("sample_weight")
        )
|
||
|
|
||
|
|
||
|
class ConsumingSplitter(GroupsConsumerMixin, BaseCrossValidator):
    """Two-fold splitter that records the metadata passed to ``split``.

    The samples are cut into a first half and a second half. Each half is
    used once as the training set and once as the test set.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def split(self, X, y=None, groups="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(
            self, "split", record_default=False, groups=groups, metadata=metadata
        )

        half = len(X) // 2
        first_half = list(range(half))
        second_half = list(range(half, len(X)))
        # Fold 1 trains on the second half; fold 2 trains on the first half.
        for train, test in ((second_half, first_half), (first_half, second_half)):
            yield train, test

    def get_n_splits(self, X=None, y=None, groups=None, metadata=None):
        return 2

    def _iter_test_indices(self, X=None, y=None, groups=None):
        half = len(X) // 2
        first_half, second_half = list(range(half)), list(range(half, len(X)))
        yield second_half
        yield first_half
|
||
|
|
||
|
|
||
|
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is only a router.

    It requests no metadata for itself. Metadata passed to ``fit`` is routed
    (through ``process_routing``) to the sub-estimator's ``fit``.

    Parameters
    ----------
    estimator : estimator object
        Sub-estimator. A clone of it is fitted and stored as ``estimator_``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        # Return the fitted meta-estimator, as the estimator API expects and as
        # every other ``fit`` in this module does. Previously this returned None.
        return self

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router
|
||
|
|
||
|
|
||
|
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """Meta-regressor that routes metadata and also consumes some itself.

    ``fit`` records its own ``sample_weight`` and then routes ``fit`` and
    ``predict`` metadata to the sub-estimator.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **fit_params):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, "fit", sample_weight=sample_weight)
        routed = process_routing(
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        sub_estimator = clone(self.estimator)
        self.estimator_ = sub_estimator.fit(X, y, **routed.estimator.fit)
        return self

    def predict(self, X, **predict_params):
        routed = process_routing(self, "predict", **predict_params)
        return self.estimator_.predict(X, **routed.estimator.predict)

    def get_metadata_routing(self):
        mapping = (
            MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
        )
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.estimator, method_mapping=mapping)
        )
|
||
|
|
||
|
|
||
|
class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """Meta-classifier that consumes ``sample_weight`` itself in ``fit``.

    After recording its own ``sample_weight``, ``fit`` routes the remaining
    metadata to the sub-estimator's ``fit``.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, "fit", sample_weight=sample_weight)
        routed = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
        sub_estimator = clone(self.estimator)
        self.estimator_ = sub_estimator.fit(X, y, **routed.estimator.fit)
        return self

    def get_metadata_routing(self):
        mapping = MethodMapping().add(caller="fit", callee="fit")
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.estimator, method_mapping=mapping)
        )
|
||
|
|
||
|
|
||
|
class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
    """Meta-transformer that only routes metadata to its inner transformer."""

    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X, y=None, **fit_params):
        routed = process_routing(self, "fit", **fit_params)
        inner = clone(self.transformer)
        self.transformer_ = inner.fit(X, y, **routed.transformer.fit)
        return self

    def transform(self, X, y=None, **transform_params):
        routed = process_routing(self, "transform", **transform_params)
        return self.transformer_.transform(X, **routed.transformer.transform)

    def get_metadata_routing(self):
        mapping = (
            MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="transform", callee="transform")
        )
        return MetadataRouter(owner=self.__class__.__name__).add(
            transformer=self.transformer,
            method_mapping=mapping,
        )
|