""" Metadata Routing Utility Tests """ # Author: Adrin Jalali # License: BSD 3 clause import re import numpy as np import pytest from sklearn import config_context from sklearn.base import ( BaseEstimator, clone, ) from sklearn.exceptions import UnsetMetadataPassedError from sklearn.linear_model import LinearRegression from sklearn.pipeline import Pipeline from sklearn.tests.metadata_routing_common import ( ConsumingClassifier, ConsumingRegressor, ConsumingTransformer, MetaRegressor, MetaTransformer, NonConsumingClassifier, WeightedMetaClassifier, WeightedMetaRegressor, _Registry, assert_request_equal, assert_request_is_empty, check_recorded_metadata, ) from sklearn.utils import metadata_routing from sklearn.utils._metadata_requests import ( COMPOSITE_METHODS, METHODS, SIMPLE_METHODS, MethodMetadataRequest, MethodPair, _MetadataRequester, request_is_alias, request_is_valid, ) from sklearn.utils.metadata_routing import ( MetadataRequest, MetadataRouter, MethodMapping, _RoutingNotSupportedMixin, get_routing_for_object, process_routing, ) from sklearn.utils.validation import check_is_fitted rng = np.random.RandomState(42) N, M = 100, 4 X = rng.rand(N, M) y = rng.randint(0, 2, size=N) my_groups = rng.randint(0, 10, size=N) my_weights = rng.rand(N) my_other_weights = rng.rand(N) @pytest.fixture(autouse=True) def enable_slep006(): """Enable SLEP006 for all tests.""" with config_context(enable_metadata_routing=True): yield class SimplePipeline(BaseEstimator): """A very simple pipeline, assuming the last step is always a predictor. Parameters ---------- steps : iterable of objects An iterable of transformers with the last step being a predictor. """ def __init__(self, steps): self.steps = steps def fit(self, X, y, **fit_params): self.steps_ = [] params = process_routing(self, "fit", **fit_params) X_transformed = X for i, step in enumerate(self.steps[:-1]): transformer = clone(step).fit( X_transformed, y, **params.get(f"step_{i}").fit ) self.steps_.append(transformer) X_transformed = transformer.transform( X_transformed, **params.get(f"step_{i}").transform ) self.steps_.append( clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit) ) return self def predict(self, X, **predict_params): check_is_fitted(self) X_transformed = X params = process_routing(self, "predict", **predict_params) for i, step in enumerate(self.steps_[:-1]): X_transformed = step.transform(X, **params.get(f"step_{i}").transform) return self.steps_[-1].predict(X_transformed, **params.predictor.predict) def get_metadata_routing(self): router = MetadataRouter(owner=self.__class__.__name__) for i, step in enumerate(self.steps[:-1]): router.add( **{f"step_{i}": step}, method_mapping=MethodMapping() .add(caller="fit", callee="fit") .add(caller="fit", callee="transform") .add(caller="predict", callee="transform"), ) router.add( predictor=self.steps[-1], method_mapping=MethodMapping() .add(caller="fit", callee="fit") .add(caller="predict", callee="predict"), ) return router def test_assert_request_is_empty(): requests = MetadataRequest(owner="test") assert_request_is_empty(requests) requests.fit.add_request(param="foo", alias=None) # this should still work, since None is the default value assert_request_is_empty(requests) requests.fit.add_request(param="bar", alias="value") with pytest.raises(AssertionError): # now requests is no more empty assert_request_is_empty(requests) # but one can exclude a method assert_request_is_empty(requests, exclude="fit") requests.score.add_request(param="carrot", alias=True) with pytest.raises(AssertionError): # excluding `fit` is not enough assert_request_is_empty(requests, exclude="fit") # and excluding both fit and score would avoid an exception assert_request_is_empty(requests, exclude=["fit", "score"]) # test if a router is empty assert_request_is_empty( MetadataRouter(owner="test") .add_self_request(WeightedMetaRegressor(estimator=None)) .add( estimator=ConsumingRegressor(), method_mapping=MethodMapping().add(caller="fit", callee="fit"), ) ) @pytest.mark.parametrize( "estimator", [ ConsumingClassifier(registry=_Registry()), ConsumingRegressor(registry=_Registry()), ConsumingTransformer(registry=_Registry()), WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()), WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()), ], ) def test_estimator_puts_self_in_registry(estimator): """Check that an estimator puts itself in the registry upon fit.""" estimator.fit(X, y) assert estimator in estimator.registry @pytest.mark.parametrize( "val, res", [ (False, False), (True, False), (None, False), ("$UNUSED$", False), ("$WARN$", False), ("invalid-input", False), ("valid_arg", True), ], ) def test_request_type_is_alias(val, res): # Test request_is_alias assert request_is_alias(val) == res @pytest.mark.parametrize( "val, res", [ (False, True), (True, True), (None, True), ("$UNUSED$", True), ("$WARN$", True), ("invalid-input", False), ("alias_arg", False), ], ) def test_request_type_is_valid(val, res): # Test request_is_valid assert request_is_valid(val) == res def test_default_requests(): class OddEstimator(BaseEstimator): __metadata_request__fit = { # set a different default request "sample_weight": True } # type: ignore odd_request = get_routing_for_object(OddEstimator()) assert odd_request.fit.requests == {"sample_weight": True} # check other test estimators assert not len(get_routing_for_object(NonConsumingClassifier()).fit.requests) assert_request_is_empty(NonConsumingClassifier().get_metadata_routing()) trs_request = get_routing_for_object(ConsumingTransformer()) assert trs_request.fit.requests == { "sample_weight": None, "metadata": None, } assert trs_request.transform.requests == {"metadata": None, "sample_weight": None} assert_request_is_empty(trs_request) est_request = get_routing_for_object(ConsumingClassifier()) assert est_request.fit.requests == { "sample_weight": None, "metadata": None, } assert_request_is_empty(est_request) def test_default_request_override(): """Test that default requests are correctly overridden regardless of the ASCII order of the class names, hence testing small and capital letter class name starts. Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28430 """ class Base(BaseEstimator): __metadata_request__split = {"groups": True} class class_1(Base): __metadata_request__split = {"groups": "sample_domain"} class Class_1(Base): __metadata_request__split = {"groups": "sample_domain"} assert_request_equal( class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} ) assert_request_equal( Class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} ) def test_process_routing_invalid_method(): with pytest.raises(TypeError, match="Can only route and process input"): process_routing(ConsumingClassifier(), "invalid_method", groups=my_groups) def test_process_routing_invalid_object(): class InvalidObject: pass with pytest.raises(AttributeError, match="either implement the routing method"): process_routing(InvalidObject(), "fit", groups=my_groups) @pytest.mark.parametrize("method", METHODS) @pytest.mark.parametrize("default", [None, "default", []]) def test_process_routing_empty_params_get_with_default(method, default): empty_params = {} routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params) # Behaviour should be an empty dictionary returned for each method when retrieved. params_for_method = routed_params[method] assert isinstance(params_for_method, dict) assert set(params_for_method.keys()) == set(METHODS) # No default to `get` should be equivalent to the default default_params_for_method = routed_params.get(method, default=default) assert default_params_for_method == params_for_method def test_simple_metadata_routing(): # Tests that metadata is properly routed # The underlying estimator doesn't accept or request metadata clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) clf.fit(X, y) # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying # estimator clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) clf.fit(X, y, sample_weight=my_weights) # If the estimator accepts the metadata but doesn't explicitly say it doesn't # need it, there's an error clf = WeightedMetaClassifier(estimator=ConsumingClassifier()) err_message = ( "[sample_weight] are passed but are not explicitly set as requested or" " not requested for ConsumingClassifier.fit" ) with pytest.raises(ValueError, match=re.escape(err_message)): clf.fit(X, y, sample_weight=my_weights) # Explicitly saying the estimator doesn't need it, makes the error go away, # because in this case `WeightedMetaClassifier` consumes `sample_weight`. If # there was no consumer of sample_weight, passing it would result in an # error. clf = WeightedMetaClassifier( estimator=ConsumingClassifier().set_fit_request(sample_weight=False) ) # this doesn't raise since WeightedMetaClassifier itself is a consumer, # and passing metadata to the consumer directly is fine regardless of its # metadata_request values. clf.fit(X, y, sample_weight=my_weights) check_recorded_metadata(clf.estimator_, "fit") # Requesting a metadata will make the meta-estimator forward it correctly clf = WeightedMetaClassifier( estimator=ConsumingClassifier().set_fit_request(sample_weight=True) ) clf.fit(X, y, sample_weight=my_weights) check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights) # And requesting it with an alias clf = WeightedMetaClassifier( estimator=ConsumingClassifier().set_fit_request( sample_weight="alternative_weight" ) ) clf.fit(X, y, alternative_weight=my_weights) check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights) def test_nested_routing(): # check if metadata is routed in a nested routing situation. pipeline = SimplePipeline( [ MetaTransformer( transformer=ConsumingTransformer() .set_fit_request(metadata=True, sample_weight=False) .set_transform_request(sample_weight=True, metadata=False) ), WeightedMetaRegressor( estimator=ConsumingRegressor() .set_fit_request(sample_weight="inner_weights", metadata=False) .set_predict_request(sample_weight=False) ).set_fit_request(sample_weight="outer_weights"), ] ) w1, w2, w3 = [1], [2], [3] pipeline.fit( X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 ) check_recorded_metadata( pipeline.steps_[0].transformer_, "fit", metadata=my_groups, sample_weight=None ) check_recorded_metadata( pipeline.steps_[0].transformer_, "transform", sample_weight=w1, metadata=None ) check_recorded_metadata(pipeline.steps_[1], "fit", sample_weight=w2) check_recorded_metadata(pipeline.steps_[1].estimator_, "fit", sample_weight=w3) pipeline.predict(X, sample_weight=w3) check_recorded_metadata( pipeline.steps_[0].transformer_, "transform", sample_weight=w3, metadata=None ) def test_nested_routing_conflict(): # check if an error is raised if there's a conflict between keys pipeline = SimplePipeline( [ MetaTransformer( transformer=ConsumingTransformer() .set_fit_request(metadata=True, sample_weight=False) .set_transform_request(sample_weight=True) ), WeightedMetaRegressor( estimator=ConsumingRegressor().set_fit_request(sample_weight=True) ).set_fit_request(sample_weight="outer_weights"), ] ) w1, w2 = [1], [2] with pytest.raises( ValueError, match=( re.escape( "In WeightedMetaRegressor, there is a conflict on sample_weight between" " what is requested for this estimator and what is requested by its" " children. You can resolve this conflict by using an alias for the" " child estimator(s) requested metadata." ) ), ): pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2) def test_invalid_metadata(): # check that passing wrong metadata raises an error trs = MetaTransformer( transformer=ConsumingTransformer().set_transform_request(sample_weight=True) ) with pytest.raises( TypeError, match=(re.escape("transform got unexpected argument(s) {'other_param'}")), ): trs.fit(X, y).transform(X, other_param=my_weights) # passing a metadata which is not requested by any estimator should also raise trs = MetaTransformer( transformer=ConsumingTransformer().set_transform_request(sample_weight=False) ) with pytest.raises( TypeError, match=(re.escape("transform got unexpected argument(s) {'sample_weight'}")), ): trs.fit(X, y).transform(X, sample_weight=my_weights) def test_get_metadata_routing(): class TestDefaultsBadMethodName(_MetadataRequester): __metadata_request__fit = { "sample_weight": None, "my_param": None, } __metadata_request__score = { "sample_weight": None, "my_param": True, "my_other_param": None, } # this will raise an error since we don't understand "other_method" as a method __metadata_request__other_method = {"my_param": True} class TestDefaults(_MetadataRequester): __metadata_request__fit = { "sample_weight": None, "my_other_param": None, } __metadata_request__score = { "sample_weight": None, "my_param": True, "my_other_param": None, } __metadata_request__predict = {"my_param": True} with pytest.raises( AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" ): TestDefaultsBadMethodName().get_metadata_routing() expected = { "score": { "my_param": True, "my_other_param": None, "sample_weight": None, }, "fit": { "my_other_param": None, "sample_weight": None, }, "predict": {"my_param": True}, } assert_request_equal(TestDefaults().get_metadata_routing(), expected) est = TestDefaults().set_score_request(my_param="other_param") expected = { "score": { "my_param": "other_param", "my_other_param": None, "sample_weight": None, }, "fit": { "my_other_param": None, "sample_weight": None, }, "predict": {"my_param": True}, } assert_request_equal(est.get_metadata_routing(), expected) est = TestDefaults().set_fit_request(sample_weight=True) expected = { "score": { "my_param": True, "my_other_param": None, "sample_weight": None, }, "fit": { "my_other_param": None, "sample_weight": True, }, "predict": {"my_param": True}, } assert_request_equal(est.get_metadata_routing(), expected) def test_setting_default_requests(): # Test _get_default_requests method test_cases = dict() class ExplicitRequest(BaseEstimator): # `fit` doesn't accept `props` explicitly, but we want to request it __metadata_request__fit = {"prop": None} def fit(self, X, y, **kwargs): return self test_cases[ExplicitRequest] = {"prop": None} class ExplicitRequestOverwrite(BaseEstimator): # `fit` explicitly accepts `props`, but we want to change the default # request value from None to True __metadata_request__fit = {"prop": True} def fit(self, X, y, prop=None, **kwargs): return self test_cases[ExplicitRequestOverwrite] = {"prop": True} class ImplicitRequest(BaseEstimator): # `fit` requests `prop` and the default None should be used def fit(self, X, y, prop=None, **kwargs): return self test_cases[ImplicitRequest] = {"prop": None} class ImplicitRequestRemoval(BaseEstimator): # `fit` (in this class or a parent) requests `prop`, but we don't want # it requested at all. __metadata_request__fit = {"prop": metadata_routing.UNUSED} def fit(self, X, y, prop=None, **kwargs): return self test_cases[ImplicitRequestRemoval] = {} for Klass, requests in test_cases.items(): assert get_routing_for_object(Klass()).fit.requests == requests assert_request_is_empty(Klass().get_metadata_routing(), exclude="fit") Klass().fit(None, None) # for coverage def test_removing_non_existing_param_raises(): """Test that removing a metadata using UNUSED which doesn't exist raises.""" class InvalidRequestRemoval(BaseEstimator): # `fit` (in this class or a parent) requests `prop`, but we don't want # it requested at all. __metadata_request__fit = {"prop": metadata_routing.UNUSED} def fit(self, X, y, **kwargs): return self with pytest.raises(ValueError, match="Trying to remove parameter"): InvalidRequestRemoval().get_metadata_routing() def test_method_metadata_request(): mmr = MethodMetadataRequest(owner="test", method="fit") with pytest.raises(ValueError, match="The alias you're setting for"): mmr.add_request(param="foo", alias=1.4) mmr.add_request(param="foo", alias=None) assert mmr.requests == {"foo": None} mmr.add_request(param="foo", alias=False) assert mmr.requests == {"foo": False} mmr.add_request(param="foo", alias=True) assert mmr.requests == {"foo": True} mmr.add_request(param="foo", alias="foo") assert mmr.requests == {"foo": True} mmr.add_request(param="foo", alias="bar") assert mmr.requests == {"foo": "bar"} assert mmr._get_param_names(return_alias=False) == {"foo"} assert mmr._get_param_names(return_alias=True) == {"bar"} def test_get_routing_for_object(): class Consumer(BaseEstimator): __metadata_request__fit = {"prop": None} assert_request_is_empty(get_routing_for_object(None)) assert_request_is_empty(get_routing_for_object(object())) mr = MetadataRequest(owner="test") mr.fit.add_request(param="foo", alias="bar") mr_factory = get_routing_for_object(mr) assert_request_is_empty(mr_factory, exclude="fit") assert mr_factory.fit.requests == {"foo": "bar"} mr = get_routing_for_object(Consumer()) assert_request_is_empty(mr, exclude="fit") assert mr.fit.requests == {"prop": None} def test_metadata_request_consumes_method(): """Test that MetadataRequest().consumes() method works as expected.""" request = MetadataRouter(owner="test") assert request.consumes(method="fit", params={"foo"}) == set() request = MetadataRequest(owner="test") request.fit.add_request(param="foo", alias=True) assert request.consumes(method="fit", params={"foo"}) == {"foo"} request = MetadataRequest(owner="test") request.fit.add_request(param="foo", alias="bar") assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"} def test_metadata_router_consumes_method(): """Test that MetadataRouter().consumes method works as expected.""" # having it here instead of parametrizing the test since `set_fit_request` # is not available while collecting the tests. cases = [ ( WeightedMetaRegressor( estimator=ConsumingRegressor().set_fit_request(sample_weight=True) ), {"sample_weight"}, {"sample_weight"}, ), ( WeightedMetaRegressor( estimator=ConsumingRegressor().set_fit_request( sample_weight="my_weights" ) ), {"my_weights", "sample_weight"}, {"my_weights"}, ), ] for obj, input, output in cases: assert obj.get_metadata_routing().consumes(method="fit", params=input) == output def test_metaestimator_warnings(): class WeightedMetaRegressorWarn(WeightedMetaRegressor): __metadata_request__fit = {"sample_weight": metadata_routing.WARN} with pytest.warns( UserWarning, match="Support for .* has recently been added to this class" ): WeightedMetaRegressorWarn( estimator=LinearRegression().set_fit_request(sample_weight=False) ).fit(X, y, sample_weight=my_weights) def test_estimator_warnings(): class ConsumingRegressorWarn(ConsumingRegressor): __metadata_request__fit = {"sample_weight": metadata_routing.WARN} with pytest.warns( UserWarning, match="Support for .* has recently been added to this class" ): MetaRegressor(estimator=ConsumingRegressorWarn()).fit( X, y, sample_weight=my_weights ) @pytest.mark.parametrize( "obj, string", [ ( MethodMetadataRequest(owner="test", method="fit").add_request( param="foo", alias="bar" ), "{'foo': 'bar'}", ), ( MetadataRequest(owner="test"), "{}", ), ( MetadataRouter(owner="test").add( estimator=ConsumingRegressor(), method_mapping=MethodMapping().add(caller="predict", callee="predict"), ), ( "{'estimator': {'mapping': [{'caller': 'predict', 'callee':" " 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':" " None}, 'partial_fit': {'sample_weight': None, 'metadata': None}," " 'predict': {'sample_weight': None, 'metadata': None}, 'score':" " {'sample_weight': None, 'metadata': None}}}}" ), ), ], ) def test_string_representations(obj, string): assert str(obj) == string @pytest.mark.parametrize( "obj, method, inputs, err_cls, err_msg", [ ( MethodMapping(), "add", {"caller": "fit", "callee": "invalid"}, ValueError, "Given callee", ), ( MethodMapping(), "add", {"caller": "invalid", "callee": "fit"}, ValueError, "Given caller", ), ( MetadataRouter(owner="test"), "add_self_request", {"obj": MetadataRouter(owner="test")}, ValueError, "Given `obj` is neither a `MetadataRequest` nor does it implement", ), ( ConsumingClassifier(), "set_fit_request", {"invalid": True}, TypeError, "Unexpected args", ), ], ) def test_validations(obj, method, inputs, err_cls, err_msg): with pytest.raises(err_cls, match=err_msg): getattr(obj, method)(**inputs) def test_methodmapping(): mm = ( MethodMapping() .add(caller="fit", callee="transform") .add(caller="fit", callee="fit") ) mm_list = list(mm) assert mm_list[0] == ("fit", "transform") assert mm_list[1] == ("fit", "fit") mm = MethodMapping() for method in METHODS: mm.add(caller=method, callee=method) assert MethodPair(method, method) in mm._routes assert len(mm._routes) == len(METHODS) mm = MethodMapping().add(caller="score", callee="score") assert repr(mm) == "[{'caller': 'score', 'callee': 'score'}]" def test_metadatarouter_add_self_request(): # adding a MetadataRequest as `self` adds a copy request = MetadataRequest(owner="nested") request.fit.add_request(param="param", alias=True) router = MetadataRouter(owner="test").add_self_request(request) assert str(router._self_request) == str(request) # should be a copy, not the same object assert router._self_request is not request # one can add an estimator as self est = ConsumingRegressor().set_fit_request(sample_weight="my_weights") router = MetadataRouter(owner="test").add_self_request(obj=est) assert str(router._self_request) == str(est.get_metadata_routing()) assert router._self_request is not est.get_metadata_routing() # adding a consumer+router as self should only add the consumer part est = WeightedMetaRegressor( estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights") ) router = MetadataRouter(owner="test").add_self_request(obj=est) # _get_metadata_request() returns the consumer part of the requests assert str(router._self_request) == str(est._get_metadata_request()) # get_metadata_routing() returns the complete request set, consumer and # router included. assert str(router._self_request) != str(est.get_metadata_routing()) # it should be a copy, not the same object assert router._self_request is not est._get_metadata_request() def test_metadata_routing_add(): # adding one with a string `method_mapping` router = MetadataRouter(owner="test").add( est=ConsumingRegressor().set_fit_request(sample_weight="weights"), method_mapping=MethodMapping().add(caller="fit", callee="fit"), ) assert ( str(router) == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'fit'}], 'router': {'fit':" " {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':" " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" " None, 'metadata': None}, 'score': {'sample_weight': None, 'metadata':" " None}}}}" ) # adding one with an instance of MethodMapping router = MetadataRouter(owner="test").add( method_mapping=MethodMapping().add(caller="fit", callee="score"), est=ConsumingRegressor().set_score_request(sample_weight=True), ) assert ( str(router) == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'score'}], 'router':" " {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':" " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" " None, 'metadata': None}, 'score': {'sample_weight': True, 'metadata':" " None}}}}" ) def test_metadata_routing_get_param_names(): router = ( MetadataRouter(owner="test") .add_self_request( WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request( sample_weight="self_weights" ) ) .add( trs=ConsumingTransformer().set_fit_request( sample_weight="transform_weights" ), method_mapping=MethodMapping().add(caller="fit", callee="fit"), ) ) assert ( str(router) == "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':" " {'sample_weight': None}}, 'trs': {'mapping': [{'caller': 'fit', 'callee':" " 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights'," " 'metadata': None}, 'transform': {'sample_weight': None, 'metadata': None}," " 'inverse_transform': {'sample_weight': None, 'metadata': None}}}}" ) assert router._get_param_names( method="fit", return_alias=True, ignore_self_request=False ) == {"transform_weights", "metadata", "self_weights"} # return_alias=False will return original names for "self" assert router._get_param_names( method="fit", return_alias=False, ignore_self_request=False ) == {"sample_weight", "metadata", "transform_weights"} # ignoring self would remove "sample_weight" assert router._get_param_names( method="fit", return_alias=False, ignore_self_request=True ) == {"metadata", "transform_weights"} # return_alias is ignored when ignore_self_request=True assert router._get_param_names( method="fit", return_alias=True, ignore_self_request=True ) == router._get_param_names( method="fit", return_alias=False, ignore_self_request=True ) def test_method_generation(): # Test if all required request methods are generated. # TODO: these test classes can be moved to sklearn.utils._testing once we # have a better idea of what the commonly used classes are. class SimpleEstimator(BaseEstimator): # This class should have no set_{method}_request def fit(self, X, y): pass # pragma: no cover def fit_transform(self, X, y): pass # pragma: no cover def fit_predict(self, X, y): pass # pragma: no cover def partial_fit(self, X, y): pass # pragma: no cover def predict(self, X): pass # pragma: no cover def predict_proba(self, X): pass # pragma: no cover def predict_log_proba(self, X): pass # pragma: no cover def decision_function(self, X): pass # pragma: no cover def score(self, X, y): pass # pragma: no cover def split(self, X, y=None): pass # pragma: no cover def transform(self, X): pass # pragma: no cover def inverse_transform(self, X): pass # pragma: no cover for method in METHODS: assert not hasattr(SimpleEstimator(), f"set_{method}_request") class SimpleEstimator(BaseEstimator): # This class should have every set_{method}_request def fit(self, X, y, sample_weight=None): pass # pragma: no cover def fit_transform(self, X, y, sample_weight=None): pass # pragma: no cover def fit_predict(self, X, y, sample_weight=None): pass # pragma: no cover def partial_fit(self, X, y, sample_weight=None): pass # pragma: no cover def predict(self, X, sample_weight=None): pass # pragma: no cover def predict_proba(self, X, sample_weight=None): pass # pragma: no cover def predict_log_proba(self, X, sample_weight=None): pass # pragma: no cover def decision_function(self, X, sample_weight=None): pass # pragma: no cover def score(self, X, y, sample_weight=None): pass # pragma: no cover def split(self, X, y=None, sample_weight=None): pass # pragma: no cover def transform(self, X, sample_weight=None): pass # pragma: no cover def inverse_transform(self, X, sample_weight=None): pass # pragma: no cover # composite methods shouldn't have a corresponding set method. for method in COMPOSITE_METHODS: assert not hasattr(SimpleEstimator(), f"set_{method}_request") # simple methods should have a corresponding set method. for method in SIMPLE_METHODS: assert hasattr(SimpleEstimator(), f"set_{method}_request") def test_composite_methods(): # Test the behavior and the values of methods (composite methods) whose # request values are a union of requests by other methods (simple methods). # fit_transform and fit_predict are the only composite methods we have in # scikit-learn. class SimpleEstimator(BaseEstimator): # This class should have every set_{method}_request def fit(self, X, y, foo=None, bar=None): pass # pragma: no cover def predict(self, X, foo=None, bar=None): pass # pragma: no cover def transform(self, X, other_param=None): pass # pragma: no cover est = SimpleEstimator() # Since no request is set for fit or predict or transform, the request for # fit_transform and fit_predict should also be empty. assert est.get_metadata_routing().fit_transform.requests == { "bar": None, "foo": None, "other_param": None, } assert est.get_metadata_routing().fit_predict.requests == {"bar": None, "foo": None} # setting the request on only one of them should raise an error est.set_fit_request(foo=True, bar="test") with pytest.raises(ValueError, match="Conflicting metadata requests for"): est.get_metadata_routing().fit_predict # setting the request on the other one should fail if not the same as the # first method est.set_predict_request(bar=True) with pytest.raises(ValueError, match="Conflicting metadata requests for"): est.get_metadata_routing().fit_predict # now the requests are consistent and getting the requests for fit_predict # shouldn't raise. est.set_predict_request(foo=True, bar="test") est.get_metadata_routing().fit_predict # setting the request for a none-overlapping parameter would merge them # together. est.set_transform_request(other_param=True) assert est.get_metadata_routing().fit_transform.requests == { "bar": "test", "foo": True, "other_param": True, } def test_no_feature_flag_raises_error(): """Test that when feature flag disabled, set_{method}_requests raises.""" with config_context(enable_metadata_routing=False): with pytest.raises(RuntimeError, match="This method is only available"): ConsumingClassifier().set_fit_request(sample_weight=True) def test_none_metadata_passed(): """Test that passing None as metadata when not requested doesn't raise""" MetaRegressor(estimator=ConsumingRegressor()).fit(X, y, sample_weight=None) def test_no_metadata_always_works(): """Test that when no metadata is passed, having a meta-estimator which does not yet support metadata routing works. Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28246 """ class Estimator(_RoutingNotSupportedMixin, BaseEstimator): def fit(self, X, y, metadata=None): return self # This passes since no metadata is passed. MetaRegressor(estimator=Estimator()).fit(X, y) # This fails since metadata is passed but Estimator() does not support it. with pytest.raises( NotImplementedError, match="Estimator has not implemented metadata routing yet." ): MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups) def test_unsetmetadatapassederror_correct(): """Test that UnsetMetadataPassedError raises the correct error message when set_{method}_request is not set in nested cases.""" weighted_meta = WeightedMetaClassifier(estimator=ConsumingClassifier()) pipe = SimplePipeline([weighted_meta]) msg = re.escape( "[metadata] are passed but are not explicitly set as requested or not requested" " for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit." " Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each" " metadata you want to request/ignore." ) with pytest.raises(UnsetMetadataPassedError, match=msg): pipe.fit(X, y, metadata="blah") def test_unsetmetadatapassederror_correct_for_composite_methods(): """Test that UnsetMetadataPassedError raises the correct error message when composite metadata request methods are not set in nested cases.""" consuming_transformer = ConsumingTransformer() pipe = Pipeline([("consuming_transformer", consuming_transformer)]) msg = re.escape( "[metadata] are passed but are not explicitly set as requested or not requested" " for ConsumingTransformer.fit_transform, which is used within" " Pipeline.fit_transform. Call" " `ConsumingTransformer.set_fit_request({metadata}=True/False)" ".set_transform_request({metadata}=True/False)`" " for each metadata you want to request/ignore." ) with pytest.raises(UnsetMetadataPassedError, match=msg): pipe.fit_transform(X, y, metadata="blah") def test_unbound_set_methods_work(): """Tests that if the set_{method}_request is unbound, it still works. Also test that passing positional arguments to the set_{method}_request fails with the right TypeError message. Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28632 """ class A(BaseEstimator): def fit(self, X, y, sample_weight=None): return self error_message = re.escape( "set_fit_request() takes 0 positional argument but 1 were given" ) # Test positional arguments error before making the descriptor method unbound. with pytest.raises(TypeError, match=error_message): A().set_fit_request(True) # This somehow makes the descriptor method unbound, which results in the `instance` # argument being None, and instead `self` being passed as a positional argument # to the descriptor method. A.set_fit_request = A.set_fit_request # This should pass as usual A().set_fit_request(sample_weight=True) # Test positional arguments error after making the descriptor method unbound. with pytest.raises(TypeError, match=error_message): A().set_fit_request(True)