from functools import partial

import numpy as np
from numpy.testing import assert_array_equal

from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.metrics._scorer import _Scorer, mean_squared_error
from sklearn.model_selection import BaseCrossValidator
from sklearn.model_selection._split import GroupsConsumerMixin
from sklearn.utils._metadata_requests import (
    SIMPLE_METHODS,
)
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    process_routing,
)
from sklearn.utils.multiclass import _check_partial_fit_first_call


def record_metadata(obj, method, record_default=True, **kwargs):
    """Utility function to store passed metadata to a method.

    The metadata is stored in ``obj._records[method]``, replacing whatever was
    previously recorded for that method.

    If record_default is False, kwargs whose values are "default" are skipped.
    This is so that checks on keyword arguments whose default was not changed
    are skipped.

    Parameters
    ----------
    obj : object
        Object on which the ``_records`` dict is created (if missing) and
        updated.

    method : str
        Name of the method under which ``kwargs`` are recorded.

    record_default : bool, default=True
        Whether to record kwargs whose value is the string ``"default"``.

    **kwargs : dict
        Metadata to record.
    """
    if not hasattr(obj, "_records"):
        obj._records = {}
    if not record_default:
        # The ``isinstance`` guard keeps non-string values (e.g. arrays) from
        # being compared against the "default" string sentinel.
        kwargs = {
            key: val
            for key, val in kwargs.items()
            if not isinstance(val, str) or (val != "default")
        }
    obj._records[method] = kwargs


def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
    """Check whether the expected metadata is passed to the object's method.

    Parameters
    ----------
    obj : estimator object
        sub-estimator to check routed params for
    method : str
        sub-estimator's method where metadata is routed to
    split_params : tuple, default=empty
        specifies any parameters which are to be checked as being a subset
        of the original values
    **kwargs : dict
        passed metadata
    """
    # A missing ``_records`` attribute or a missing entry for ``method`` is
    # treated as "nothing was recorded".
    records = getattr(obj, "_records", dict()).get(method, dict())
    assert set(kwargs.keys()) == set(
        records.keys()
    ), f"Expected {kwargs.keys()} vs {records.keys()}"
    for key, value in kwargs.items():
        recorded_value = records[key]
        # The following condition is used to check for any specified parameters
        # being a subset of the original values
        if key in split_params and recorded_value is not None:
            assert np.isin(recorded_value, value).all()
        else:
            if isinstance(recorded_value, np.ndarray):
                assert_array_equal(recorded_value, value)
            else:
                # Non-array values are compared by identity, not equality.
                assert recorded_value is value, f"Expected {recorded_value} vs {value}"


# Same as ``record_metadata`` but skips kwargs left at the "default" sentinel.
record_metadata_not_default = partial(record_metadata, record_default=False)


def assert_request_is_empty(metadata_request, exclude=None):
    """Check if a metadata request dict is empty.

    One can exclude a method or a list of methods from the check using the
    ``exclude`` parameter. If metadata_request is a MetadataRouter, then
    ``exclude`` can be of the form ``{"object" : [method, ...]}``.
    """
    if isinstance(metadata_request, MetadataRouter):
        # Recurse into every routed object, forwarding only the exclusions
        # registered under that object's name.
        for name, route_mapping in metadata_request:
            if exclude is not None and name in exclude:
                _exclude = exclude[name]
            else:
                _exclude = None
            assert_request_is_empty(route_mapping.router, exclude=_exclude)
        return

    exclude = [] if exclude is None else exclude
    for method in SIMPLE_METHODS:
        if method in exclude:
            continue
        mmr = getattr(metadata_request, method)
        # Any alias other than ``None`` counts as a non-empty request.
        props = [prop for prop, alias in mmr.requests.items() if alias is not None]
        assert not props


def assert_request_equal(request, dictionary):
    """Check that ``request`` holds exactly the requests given in ``dictionary``.

    Parameters
    ----------
    request : object
        Object exposing, for each method name, an attribute with a
        ``requests`` mapping.

    dictionary : dict
        Maps method names to the expected ``requests`` mapping. Methods in
        ``SIMPLE_METHODS`` which are absent from this dict are expected to
        have no requests at all.
    """
    for method, requests in dictionary.items():
        mmr = getattr(request, method)
        assert mmr.requests == requests

    empty_methods = [method for method in SIMPLE_METHODS if method not in dictionary]
    for method in empty_methods:
        assert not getattr(request, method).requests


class _Registry(list):
    # This list is used to get a reference to the sub-estimators, which are not
    # necessarily stored on the metaestimator. We need to override __deepcopy__
    # because the sub-estimators are probably cloned, which would result in a
    # new copy of the list, but we need copy and deep copy both to return the
    # same instance.
    def __deepcopy__(self, memo):
        return self

    def __copy__(self):
        return self


class ConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def partial_fit(self, X, y, sample_weight="default", metadata="default"):
        """Register the estimator (if requested) and record passed metadata."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
        )
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        """Register the estimator (if requested) and record passed metadata."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "fit", sample_weight=sample_weight, metadata=metadata
        )
        return self

    def predict(self, X, y=None, sample_weight="default", metadata="default"):
        """Record passed metadata and return an array of ``len(X)`` zeros."""
        record_metadata_not_default(
            self, "predict", sample_weight=sample_weight, metadata=metadata
        )
        return np.zeros(shape=(len(X),))

    def score(self, X, y, sample_weight="default", metadata="default"):
        """Record passed metadata and return the constant score 1."""
        record_metadata_not_default(
            self, "score", sample_weight=sample_weight, metadata=metadata
        )
        return 1


class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier which accepts no metadata on any method."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def fit(self, X, y):
        """Store the unique values of ``y`` in ``classes_``."""
        self.classes_ = np.unique(y)
        return self

    def partial_fit(self, X, y, classes=None):
        """Do nothing and return ``self``."""
        return self

    def decision_function(self, X):
        """Return the same values as ``predict``."""
        return self.predict(X)

    def predict(self, X):
        """Return 0 for the first half of the samples and 1 for the rest."""
        y_pred = np.empty(shape=(len(X),))
        y_pred[: len(X) // 2] = 0
        y_pred[len(X) // 2 :] = 1
        return y_pred


class NonConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor which accepts no metadata on any method."""

    def fit(self, X, y):
        """Do nothing and return ``self``."""
        return self

    def partial_fit(self, X, y):
        """Do nothing and return ``self``."""
        return self

    def predict(self, X):
        """Return an array of ``len(X)`` ones."""
        return np.ones(len(X))  # pragma: no cover


class ConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    """

    def __init__(self, registry=None, alpha=0.0):
        self.alpha = alpha
        self.registry = registry

    def partial_fit(
        self, X, y, classes=None, sample_weight="default", metadata="default"
    ):
        """Register the estimator (if requested) and record passed metadata."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
        )
        _check_partial_fit_first_call(self, classes)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        """Register the estimator (if requested), record metadata, set classes_."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "fit", sample_weight=sample_weight, metadata=metadata
        )
        self.classes_ = np.unique(y)
        return self

    def predict(self, X, sample_weight="default", metadata="default"):
        """Record metadata; return 1 for the first half of samples, 0 after."""
        record_metadata_not_default(
            self, "predict", sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),), dtype="int8")
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def predict_proba(self, X, sample_weight="default", metadata="default"):
        """Record metadata; return one-hot rows, [1, 0] first half, [0, 1] after."""
        record_metadata_not_default(
            self, "predict_proba", sample_weight=sample_weight, metadata=metadata
        )
        y_proba = np.empty(shape=(len(X), 2))
        y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
        y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
        return y_proba

    def predict_log_proba(self, X, sample_weight="default", metadata="default"):
        """Placeholder: currently records nothing and returns ``None``."""
        pass  # pragma: no cover

        # uncomment when needed
        # record_metadata_not_default(
        #     self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
        # )
        # return np.zeros(shape=(len(X), 2))

    def decision_function(self, X, sample_weight="default", metadata="default"):
        """Record metadata; return 1 for the first half of samples, 0 after."""
        # Record under this method's own name so that checks on
        # "decision_function" find the metadata and "predict_proba" records
        # are not overwritten.
        record_metadata_not_default(
            self, "decision_function", sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),))
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    # uncomment when needed
    # def score(self, X, y, sample_weight="default", metadata="default"):
    #     record_metadata_not_default(
    #         self, "score", sample_weight=sample_weight, metadata=metadata
    #     )
    #     return 1


class ConsumingTransformer(TransformerMixin, BaseEstimator):
    """A transformer which accepts metadata on fit and transform.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        """Register the transformer (if requested) and record passed metadata."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, "fit", sample_weight=sample_weight, metadata=metadata
        )
        return self

    def transform(self, X, sample_weight=None, metadata=None):
        """Record passed metadata and return ``X`` unchanged."""
        record_metadata(
            self, "transform", sample_weight=sample_weight, metadata=metadata
        )
        return X

    def fit_transform(self, X, y, sample_weight=None, metadata=None):
        """Record metadata, then call ``fit`` and ``transform`` with it."""
        # implementing ``fit_transform`` is necessary since
        # ``TransformerMixin.fit_transform`` doesn't route any metadata to
        # ``transform``, while here we want ``transform`` to receive
        # ``sample_weight`` and ``metadata``.
        record_metadata(
            self, "fit_transform", sample_weight=sample_weight, metadata=metadata
        )
        return self.fit(X, y, sample_weight=sample_weight, metadata=metadata).transform(
            X, sample_weight=sample_weight, metadata=metadata
        )

    def inverse_transform(self, X, sample_weight=None, metadata=None):
        """Record passed metadata and return ``X`` unchanged."""
        record_metadata(
            self, "inverse_transform", sample_weight=sample_weight, metadata=metadata
        )
        return X


class ConsumingNoFitTransformTransformer(BaseEstimator):
    """A metadata consuming transformer that doesn't inherit from
    TransformerMixin, and thus doesn't implement `fit_transform`.

    Note that TransformerMixin's `fit_transform` doesn't route metadata to
    `transform`.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        """Register the transformer (if requested) and record passed metadata."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)

        return self

    def transform(self, X, sample_weight=None, metadata=None):
        """Record passed metadata and return ``X`` unchanged."""
        record_metadata(
            self, "transform", sample_weight=sample_weight, metadata=metadata
        )
        return X


class ConsumingScorer(_Scorer):
    """A scorer which records the metadata passed to ``_score``.

    Parameters
    ----------
    registry : list, default=None
        If a list, the scorer appends itself to it each time ``_score`` is
        called.
    """

    def __init__(self, registry=None):
        super().__init__(
            score_func=mean_squared_error, sign=1, kwargs={}, response_method="predict"
        )
        self.registry = registry

    def _score(self, method_caller, clf, X, y, **kwargs):
        """Record ``kwargs`` under "score" and delegate to ``_Scorer._score``."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, "score", **kwargs)

        # Only ``sample_weight`` is forwarded to the parent implementation.
        sample_weight = kwargs.get("sample_weight", None)
        return super()._score(method_caller, clf, X, y, sample_weight=sample_weight)


class ConsumingSplitter(GroupsConsumerMixin, BaseCrossValidator):
    """A splitter which records the metadata passed to ``split``.

    Parameters
    ----------
    registry : list, default=None
        If a list, the splitter appends itself to it when ``split`` is
        iterated.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def split(self, X, y=None, groups="default", metadata="default"):
        """Record metadata and yield two splits over the two halves of ``X``.

        The first yielded pair is ``(second half, first half)`` and the second
        one is ``(first half, second half)``.
        """
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, "split", groups=groups, metadata=metadata)

        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        yield test_indices, train_indices
        yield train_indices, test_indices

    def get_n_splits(self, X=None, y=None, groups=None, metadata=None):
        """Return the number of splits, which is always 2."""
        return 2

    def _iter_test_indices(self, X=None, y=None, groups=None):
        """Yield the second half of the indices of ``X``, then the first half."""
        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        yield test_indices
        yield train_indices


class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is only a router."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        """Route ``fit_params`` and fit a clone of the sub-estimator."""
        params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        # Return ``self`` like the other ``fit`` methods in this module, so
        # that calls can be chained.
        return self

    def get_metadata_routing(self):
        """Return a router mapping this object's ``fit`` to the estimator's."""
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router


class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is also a consumer."""

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Record ``sample_weight``, route metadata and fit the sub-estimator."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, "fit", sample_weight=sample_weight)
        params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def predict(self, X, **predict_params):
        """Route ``predict_params`` and predict with the fitted sub-estimator."""
        params = process_routing(self, "predict", **predict_params)
        return self.estimator_.predict(X, **params.estimator.predict)

    def get_metadata_routing(self):
        """Return a router with a self request and fit/predict mappings."""
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router


class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """A meta-estimator which also consumes sample_weight itself in ``fit``."""

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **kwargs):
        """Record ``sample_weight``, route metadata and fit the sub-estimator."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, "fit", sample_weight=sample_weight)
        params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def get_metadata_routing(self):
        """Return a router with a self request and a fit-to-fit mapping."""
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router


class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
    """A simple meta-transformer."""

    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X, y=None, **fit_params):
        """Route ``fit_params`` and fit a clone of the sub-transformer."""
        params = process_routing(self, "fit", **fit_params)
        self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
        return self

    def transform(self, X, y=None, **transform_params):
        """Route ``transform_params`` and transform with the sub-transformer."""
        params = process_routing(self, "transform", **transform_params)
        return self.transformer_.transform(X, **params.transformer.transform)

    def get_metadata_routing(self):
        """Return a router with fit-to-fit and transform-to-transform mappings."""
        return MetadataRouter(owner=self.__class__.__name__).add(
            transformer=self.transformer,
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="transform", callee="transform"),
        )