1112 lines
38 KiB
Python
1112 lines
38 KiB
Python
"""
|
|
Metadata Routing Utility Tests
|
|
"""
|
|
|
|
# Author: Adrin Jalali <adrin.jalali@gmail.com>
|
|
# License: BSD 3 clause
|
|
|
|
import re
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from sklearn import config_context
|
|
from sklearn.base import (
|
|
BaseEstimator,
|
|
clone,
|
|
)
|
|
from sklearn.exceptions import UnsetMetadataPassedError
|
|
from sklearn.linear_model import LinearRegression
|
|
from sklearn.pipeline import Pipeline
|
|
from sklearn.tests.metadata_routing_common import (
|
|
ConsumingClassifier,
|
|
ConsumingRegressor,
|
|
ConsumingTransformer,
|
|
MetaRegressor,
|
|
MetaTransformer,
|
|
NonConsumingClassifier,
|
|
WeightedMetaClassifier,
|
|
WeightedMetaRegressor,
|
|
_Registry,
|
|
assert_request_equal,
|
|
assert_request_is_empty,
|
|
check_recorded_metadata,
|
|
)
|
|
from sklearn.utils import metadata_routing
|
|
from sklearn.utils._metadata_requests import (
|
|
COMPOSITE_METHODS,
|
|
METHODS,
|
|
SIMPLE_METHODS,
|
|
MethodMetadataRequest,
|
|
MethodPair,
|
|
_MetadataRequester,
|
|
request_is_alias,
|
|
request_is_valid,
|
|
)
|
|
from sklearn.utils.metadata_routing import (
|
|
MetadataRequest,
|
|
MetadataRouter,
|
|
MethodMapping,
|
|
_RoutingNotSupportedMixin,
|
|
get_routing_for_object,
|
|
process_routing,
|
|
)
|
|
from sklearn.utils.validation import check_is_fitted
|
|
|
|
# Shared toy data used by the tests in this module. Everything is drawn from a
# single seeded RandomState, so the order of the statements below determines
# the exact values; do not reorder them.
rng = np.random.RandomState(42)
N, M = 100, 4  # N rows (samples) and M columns (features) of ``X``
X = rng.rand(N, M)
y = rng.randint(0, 2, size=N)  # binary targets in {0, 1}
my_groups = rng.randint(0, 10, size=N)  # integer labels in [0, 10)
my_weights = rng.rand(N)
my_other_weights = rng.rand(N)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def enable_slep006():
    """Run each test of this module with metadata routing (SLEP006) enabled.

    The setting lives only for the duration of the test: the context manager
    is exited once control comes back after ``yield``.
    """
    routing_enabled = config_context(enable_metadata_routing=True)
    with routing_enabled:
        yield
|
|
|
|
|
|
class SimplePipeline(BaseEstimator):
    """A very simple pipeline, assuming the last step is always a predictor.

    Parameters
    ----------
    steps : iterable of objects
        An iterable of transformers with the last step being a predictor.
        Intermediate steps are routed metadata under the names ``step_{i}``;
        the last step is routed metadata under the name ``predictor``.
    """

    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y, **fit_params):
        """Fit every transformer in turn, then the final predictor.

        Each intermediate step is cloned, fitted on the output of the previous
        step, and used to transform the data for the next one. Metadata in
        ``fit_params`` is dispatched to the steps through ``process_routing``.
        """
        self.steps_ = []
        params = process_routing(self, "fit", **fit_params)
        X_transformed = X
        for i, step in enumerate(self.steps[:-1]):
            step_params = params.get(f"step_{i}")
            transformer = clone(step).fit(X_transformed, y, **step_params.fit)
            self.steps_.append(transformer)
            X_transformed = transformer.transform(
                X_transformed, **step_params.transform
            )

        self.steps_.append(
            clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit)
        )
        return self

    def predict(self, X, **predict_params):
        """Transform ``X`` through the fitted steps and predict with the last one."""
        check_is_fitted(self)
        X_transformed = X
        params = process_routing(self, "predict", **predict_params)
        for i, step in enumerate(self.steps_[:-1]):
            # Feed each transformer the output of the previous one, so the
            # transformations chain exactly as they did during ``fit``.
            X_transformed = step.transform(
                X_transformed, **params.get(f"step_{i}").transform
            )

        return self.steps_[-1].predict(X_transformed, **params.predictor.predict)

    def get_metadata_routing(self):
        """Describe how metadata is routed to the steps.

        Intermediate steps (``step_{i}``) receive ``fit`` and ``transform``
        metadata during ``fit``, and ``transform`` metadata during ``predict``.
        The last step (``predictor``) receives ``fit`` metadata during ``fit``
        and ``predict`` metadata during ``predict``.
        """
        router = MetadataRouter(owner=self.__class__.__name__)
        for i, step in enumerate(self.steps[:-1]):
            router.add(
                **{f"step_{i}": step},
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )

        router.add(
            predictor=self.steps[-1],
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict"),
        )
        return router
|
|
|
|
|
|
def test_assert_request_is_empty():
    """Exercise ``assert_request_is_empty`` on requests and on a router."""
    request = MetadataRequest(owner="test")
    assert_request_is_empty(request)

    # ``None`` is the default request value, so the request still counts as empty.
    request.fit.add_request(param="foo", alias=None)
    assert_request_is_empty(request)

    # An alias makes ``fit`` non-empty ...
    request.fit.add_request(param="bar", alias="value")
    with pytest.raises(AssertionError):
        assert_request_is_empty(request)
    # ... unless ``fit`` is excluded from the check.
    assert_request_is_empty(request, exclude="fit")

    # Once ``score`` requests something too, excluding ``fit`` alone fails,
    # while excluding both methods passes.
    request.score.add_request(param="carrot", alias=True)
    with pytest.raises(AssertionError):
        assert_request_is_empty(request, exclude="fit")
    assert_request_is_empty(request, exclude=["fit", "score"])

    # A router whose self request and child only hold defaults is empty as well.
    router = (
        MetadataRouter(owner="test")
        .add_self_request(WeightedMetaRegressor(estimator=None))
        .add(
            estimator=ConsumingRegressor(),
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
    )
    assert_request_is_empty(router)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "estimator",
    [
        ConsumingClassifier(registry=_Registry()),
        ConsumingRegressor(registry=_Registry()),
        ConsumingTransformer(registry=_Registry()),
        WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()),
        WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()),
    ],
)
def test_estimator_puts_self_in_registry(estimator):
    """Fitting must add the estimator to the registry it was built with."""
    estimator.fit(X, y)
    registry = estimator.registry
    assert estimator in registry
|
|
|
|
|
|
@pytest.mark.parametrize(
    "val, res",
    [
        (False, False),
        (True, False),
        (None, False),
        ("$UNUSED$", False),
        ("$WARN$", False),
        ("invalid-input", False),
        ("valid_arg", True),
    ],
)
def test_request_type_is_alias(val, res):
    """Booleans, None, the reserved markers and malformed names are not aliases."""
    is_alias = request_is_alias(val)
    assert is_alias == res
|
|
|
|
|
|
@pytest.mark.parametrize(
    "val, res",
    [
        (False, True),
        (True, True),
        (None, True),
        ("$UNUSED$", True),
        ("$WARN$", True),
        ("invalid-input", False),
        ("alias_arg", False),
    ],
)
def test_request_type_is_valid(val, res):
    """Only booleans, None and the reserved markers are valid request values."""
    is_valid = request_is_valid(val)
    assert is_valid == res
|
|
|
|
|
|
def test_default_requests():
    """Check the default requests reported for several estimators."""

    class OddEstimator(BaseEstimator):
        # A class-level default that differs from the usual ``None``.
        __metadata_request__fit = {"sample_weight": True}  # type: ignore

    odd_request = get_routing_for_object(OddEstimator())
    assert odd_request.fit.requests == {"sample_weight": True}

    # An estimator consuming no metadata reports no requests at all.
    non_consuming_fit = get_routing_for_object(NonConsumingClassifier()).fit
    assert len(non_consuming_fit.requests) == 0
    assert_request_is_empty(NonConsumingClassifier().get_metadata_routing())

    # Consumers list their metadata with the default ``None`` request value.
    trs_request = get_routing_for_object(ConsumingTransformer())
    assert trs_request.fit.requests == {"sample_weight": None, "metadata": None}
    assert trs_request.transform.requests == {"metadata": None, "sample_weight": None}
    assert_request_is_empty(trs_request)

    clf_request = get_routing_for_object(ConsumingClassifier())
    assert clf_request.fit.requests == {"sample_weight": None, "metadata": None}
    assert_request_is_empty(clf_request)
|
|
|
|
|
|
def test_default_request_override():
    """Check that a subclass overrides its parent's default requests whatever the
    ASCII order of the class names, i.e. for both lower- and upper-case starts.

    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28430
    """

    class Base(BaseEstimator):
        __metadata_request__split = {"groups": True}

    class class_1(Base):
        __metadata_request__split = {"groups": "sample_domain"}

    class Class_1(Base):
        __metadata_request__split = {"groups": "sample_domain"}

    expected = {"split": {"groups": "sample_domain"}}
    for klass in (class_1, Class_1):
        assert_request_equal(klass()._get_metadata_request(), expected)
|
|
|
|
|
|
def test_process_routing_invalid_method():
    """Routing for an unknown method name raises a TypeError."""
    consumer = ConsumingClassifier()
    with pytest.raises(TypeError, match="Can only route and process input"):
        process_routing(consumer, "invalid_method", groups=my_groups)
|
|
|
|
|
|
def test_process_routing_invalid_object():
    """Routing for an object without any routing support raises AttributeError."""

    class InvalidObject:
        pass

    obj = InvalidObject()
    with pytest.raises(AttributeError, match="either implement the routing method"):
        process_routing(obj, "fit", groups=my_groups)
|
|
|
|
|
|
@pytest.mark.parametrize("method", METHODS)
@pytest.mark.parametrize("default", [None, "default", []])
def test_process_routing_empty_params_get_with_default(method, default):
    """Check the routed params when no metadata at all is passed."""
    routed_params = process_routing(ConsumingClassifier(), "fit")

    # The entry for any method is a dict keyed by every method name.
    params_for_method = routed_params[method]
    assert isinstance(params_for_method, dict)
    assert set(params_for_method) == set(METHODS)

    # The key exists, so ``get`` returns the stored entry whatever ``default`` is.
    assert routed_params.get(method, default=default) == params_for_method
|
|
|
|
|
|
def test_simple_metadata_routing():
    """Check how a consuming meta-estimator forwards ``sample_weight``."""
    # No metadata is passed and the child accepts none.
    WeightedMetaClassifier(estimator=NonConsumingClassifier()).fit(X, y)

    # The meta-estimator consumes ``sample_weight`` itself; the child accepts
    # no metadata, so nothing has to be forwarded.
    WeightedMetaClassifier(estimator=NonConsumingClassifier()).fit(
        X, y, sample_weight=my_weights
    )

    # The child accepts ``sample_weight`` but has not stated whether it wants
    # it, which is an error.
    meta = WeightedMetaClassifier(estimator=ConsumingClassifier())
    expected_msg = (
        "[sample_weight] are passed but are not explicitly set as requested or"
        " not requested for ConsumingClassifier.fit"
    )
    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        meta.fit(X, y, sample_weight=my_weights)

    # Explicitly declining the metadata removes the error. The meta-estimator
    # still consumes ``sample_weight`` (passing metadata to a consumer directly
    # is always fine), and the child receives nothing.
    meta = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(sample_weight=False)
    )
    meta.fit(X, y, sample_weight=my_weights)
    check_recorded_metadata(meta.estimator_, "fit")

    # Requesting the metadata makes the meta-estimator forward it ...
    meta = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(sample_weight=True)
    )
    meta.fit(X, y, sample_weight=my_weights)
    check_recorded_metadata(meta.estimator_, "fit", sample_weight=my_weights)

    # ... and an alias lets the caller pass it under another name.
    meta = WeightedMetaClassifier(
        estimator=ConsumingClassifier().set_fit_request(
            sample_weight="alternative_weight"
        )
    )
    meta.fit(X, y, alternative_weight=my_weights)
    check_recorded_metadata(meta.estimator_, "fit", sample_weight=my_weights)
|
|
|
|
|
|
def test_nested_routing():
    """Check that metadata reaches the right consumer when routers are nested."""
    transformer = (
        ConsumingTransformer()
        .set_fit_request(metadata=True, sample_weight=False)
        .set_transform_request(sample_weight=True, metadata=False)
    )
    regressor = (
        ConsumingRegressor()
        .set_fit_request(sample_weight="inner_weights", metadata=False)
        .set_predict_request(sample_weight=False)
    )
    pipeline = SimplePipeline(
        [
            MetaTransformer(transformer=transformer),
            WeightedMetaRegressor(estimator=regressor).set_fit_request(
                sample_weight="outer_weights"
            ),
        ]
    )

    w1, w2, w3 = [1], [2], [3]
    pipeline.fit(
        X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3
    )

    fitted_transformer = pipeline.steps_[0].transformer_
    # ``fit`` wants only ``metadata``; ``transform`` wants only ``sample_weight``.
    check_recorded_metadata(
        fitted_transformer, "fit", metadata=my_groups, sample_weight=None
    )
    check_recorded_metadata(
        fitted_transformer, "transform", sample_weight=w1, metadata=None
    )
    # The outer estimator gets its alias, the inner estimator gets its own.
    check_recorded_metadata(pipeline.steps_[1], "fit", sample_weight=w2)
    check_recorded_metadata(pipeline.steps_[1].estimator_, "fit", sample_weight=w3)

    # At predict time, ``sample_weight`` goes to the transformer's ``transform``.
    pipeline.predict(X, sample_weight=w3)
    check_recorded_metadata(
        fitted_transformer, "transform", sample_weight=w3, metadata=None
    )
|
|
|
|
|
|
def test_nested_routing_conflict():
    """Check the error raised when a meta-estimator and its child both claim
    ``sample_weight``."""
    transformer = (
        ConsumingTransformer()
        .set_fit_request(metadata=True, sample_weight=False)
        .set_transform_request(sample_weight=True)
    )
    meta_regressor = WeightedMetaRegressor(
        estimator=ConsumingRegressor().set_fit_request(sample_weight=True)
    ).set_fit_request(sample_weight="outer_weights")
    pipeline = SimplePipeline([MetaTransformer(transformer=transformer), meta_regressor])

    w1, w2 = [1], [2]
    expected_msg = (
        "In WeightedMetaRegressor, there is a conflict on sample_weight between"
        " what is requested for this estimator and what is requested by its"
        " children. You can resolve this conflict by using an alias for the"
        " child estimator(s) requested metadata."
    )
    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2)
|
|
|
|
|
|
def test_invalid_metadata():
    """Passing metadata that no consumer requests raises a TypeError."""
    # ``other_param`` is not metadata that any step knows about.
    meta = MetaTransformer(
        transformer=ConsumingTransformer().set_transform_request(sample_weight=True)
    )
    unknown_msg = re.escape("transform got unexpected argument(s) {'other_param'}")
    with pytest.raises(TypeError, match=unknown_msg):
        meta.fit(X, y).transform(X, other_param=my_weights)

    # ``sample_weight`` is known, but explicitly not requested by anyone.
    meta = MetaTransformer(
        transformer=ConsumingTransformer().set_transform_request(sample_weight=False)
    )
    unrequested_msg = re.escape(
        "transform got unexpected argument(s) {'sample_weight'}"
    )
    with pytest.raises(TypeError, match=unrequested_msg):
        meta.fit(X, y).transform(X, sample_weight=my_weights)
|
|
|
|
|
|
def test_get_metadata_routing():
    """Check ``get_metadata_routing`` built from class-level request defaults."""

    class TestDefaultsBadMethodName(_MetadataRequester):
        __metadata_request__fit = {
            "sample_weight": None,
            "my_param": None,
        }
        __metadata_request__score = {
            "sample_weight": None,
            "my_param": True,
            "my_other_param": None,
        }
        # "other_method" is not a method the routing machinery knows about,
        # so building the routing for this class fails.
        __metadata_request__other_method = {"my_param": True}

    class TestDefaults(_MetadataRequester):
        __metadata_request__fit = {
            "sample_weight": None,
            "my_other_param": None,
        }
        __metadata_request__score = {
            "sample_weight": None,
            "my_param": True,
            "my_other_param": None,
        }
        __metadata_request__predict = {"my_param": True}

    with pytest.raises(
        AttributeError, match="'MetadataRequest' object has no attribute 'other_method'"
    ):
        TestDefaultsBadMethodName().get_metadata_routing()

    defaults = {
        "score": {"my_param": True, "my_other_param": None, "sample_weight": None},
        "fit": {"my_other_param": None, "sample_weight": None},
        "predict": {"my_param": True},
    }
    assert_request_equal(TestDefaults().get_metadata_routing(), defaults)

    # Aliasing a ``score`` parameter only changes that one entry.
    est = TestDefaults().set_score_request(my_param="other_param")
    expected = {
        **defaults,
        "score": {**defaults["score"], "my_param": "other_param"},
    }
    assert_request_equal(est.get_metadata_routing(), expected)

    # Likewise, requesting ``sample_weight`` in ``fit`` only changes ``fit``.
    est = TestDefaults().set_fit_request(sample_weight=True)
    expected = {
        **defaults,
        "fit": {**defaults["fit"], "sample_weight": True},
    }
    assert_request_equal(est.get_metadata_routing(), expected)
|
|
|
|
|
|
def test_setting_default_requests():
    """Check how class-level ``__metadata_request__fit`` defaults combine with
    the parameters found in the ``fit`` signature."""

    class ExplicitRequest(BaseEstimator):
        # ``fit`` does not accept ``prop`` explicitly, but we still request it.
        __metadata_request__fit = {"prop": None}

        def fit(self, X, y, **kwargs):
            return self

    class ExplicitRequestOverwrite(BaseEstimator):
        # ``fit`` accepts ``prop``; its default request goes from None to True.
        __metadata_request__fit = {"prop": True}

        def fit(self, X, y, prop=None, **kwargs):
            return self

    class ImplicitRequest(BaseEstimator):
        # ``prop`` is picked up from the ``fit`` signature with the default None.
        def fit(self, X, y, prop=None, **kwargs):
            return self

    class ImplicitRequestRemoval(BaseEstimator):
        # ``fit`` accepts ``prop``, but UNUSED removes it from the requests.
        __metadata_request__fit = {"prop": metadata_routing.UNUSED}

        def fit(self, X, y, prop=None, **kwargs):
            return self

    expected_fit_requests = {
        ExplicitRequest: {"prop": None},
        ExplicitRequestOverwrite: {"prop": True},
        ImplicitRequest: {"prop": None},
        ImplicitRequestRemoval: {},
    }
    for klass, fit_requests in expected_fit_requests.items():
        assert get_routing_for_object(klass()).fit.requests == fit_requests
        assert_request_is_empty(klass().get_metadata_routing(), exclude="fit")
        klass().fit(None, None)  # for coverage
|
|
|
|
|
|
def test_removing_non_existing_param_raises():
    """Marking as UNUSED a parameter that ``fit`` never had raises."""

    class InvalidRequestRemoval(BaseEstimator):
        # ``fit`` has no ``prop`` parameter, so there is nothing to remove.
        __metadata_request__fit = {"prop": metadata_routing.UNUSED}

        def fit(self, X, y, **kwargs):
            return self

    est = InvalidRequestRemoval()
    with pytest.raises(ValueError, match="Trying to remove parameter"):
        est.get_metadata_routing()
|
|
|
|
|
|
def test_method_metadata_request():
    """Check ``MethodMetadataRequest.add_request`` and ``_get_param_names``."""
    mmr = MethodMetadataRequest(owner="test", method="fit")

    # A float is neither a valid request value nor a valid alias.
    with pytest.raises(ValueError, match="The alias you're setting for"):
        mmr.add_request(param="foo", alias=1.4)

    # Each call overwrites the previous request for ``foo``.
    steps = [
        (None, {"foo": None}),
        (False, {"foo": False}),
        (True, {"foo": True}),
        ("foo", {"foo": True}),  # aliasing a parameter to its own name means True
        ("bar", {"foo": "bar"}),
    ]
    for alias, expected in steps:
        mmr.add_request(param="foo", alias=alias)
        assert mmr.requests == expected

    assert mmr._get_param_names(return_alias=False) == {"foo"}
    assert mmr._get_param_names(return_alias=True) == {"bar"}
|
|
|
|
|
|
def test_get_routing_for_object():
    """Check ``get_routing_for_object`` on several kinds of input."""

    class Consumer(BaseEstimator):
        __metadata_request__fit = {"prop": None}

    # Objects carrying no routing information give an empty request.
    for obj in (None, object()):
        assert_request_is_empty(get_routing_for_object(obj))

    # For a MetadataRequest, its requests are reported back.
    request = MetadataRequest(owner="test")
    request.fit.add_request(param="foo", alias="bar")
    routing = get_routing_for_object(request)
    assert_request_is_empty(routing, exclude="fit")
    assert routing.fit.requests == {"foo": "bar"}

    # For an estimator, its class-level defaults are reported.
    routing = get_routing_for_object(Consumer())
    assert_request_is_empty(routing, exclude="fit")
    assert routing.fit.requests == {"prop": None}
|
|
|
|
|
|
def test_metadata_request_consumes_method():
    """Check ``consumes`` on an empty router and on ``MetadataRequest`` objects."""
    # An empty router consumes nothing.
    empty_router = MetadataRouter(owner="test")
    assert empty_router.consumes(method="fit", params={"foo"}) == set()

    # A parameter requested with True is consumed under its own name ...
    request = MetadataRequest(owner="test")
    request.fit.add_request(param="foo", alias=True)
    assert request.consumes(method="fit", params={"foo"}) == {"foo"}

    # ... while an aliased parameter is consumed under its alias only.
    request = MetadataRequest(owner="test")
    request.fit.add_request(param="foo", alias="bar")
    assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"}
|
|
|
|
|
|
def test_metadata_router_consumes_method():
    """Check ``MetadataRouter.consumes`` for requested and aliased metadata."""
    # The cases are built inside the test instead of via parametrize because
    # ``set_fit_request`` is only available once metadata routing is enabled,
    # which the autouse fixture does per test, not at collection time.
    cases = [
        (
            WeightedMetaRegressor(
                estimator=ConsumingRegressor().set_fit_request(sample_weight=True)
            ),
            {"sample_weight"},
            {"sample_weight"},
        ),
        (
            WeightedMetaRegressor(
                estimator=ConsumingRegressor().set_fit_request(
                    sample_weight="my_weights"
                )
            ),
            {"my_weights", "sample_weight"},
            {"my_weights"},
        ),
    ]

    for meta_est, passed_params, expected_consumed in cases:
        routing = meta_est.get_metadata_routing()
        assert routing.consumes(method="fit", params=passed_params) == expected_consumed
|
|
|
|
|
|
def test_metaestimator_warnings():
    """A WARN default in the meta-estimator's own request emits a UserWarning."""

    class WeightedMetaRegressorWarn(WeightedMetaRegressor):
        __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    meta = WeightedMetaRegressorWarn(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    )
    with pytest.warns(
        UserWarning, match="Support for .* has recently been added to this class"
    ):
        meta.fit(X, y, sample_weight=my_weights)
|
|
|
|
|
|
def test_estimator_warnings():
    """A WARN default in a child estimator's request emits a UserWarning."""

    class ConsumingRegressorWarn(ConsumingRegressor):
        __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    meta = MetaRegressor(estimator=ConsumingRegressorWarn())
    with pytest.warns(
        UserWarning, match="Support for .* has recently been added to this class"
    ):
        meta.fit(X, y, sample_weight=my_weights)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "obj, string",
    [
        (
            MethodMetadataRequest(owner="test", method="fit").add_request(
                param="foo", alias="bar"
            ),
            "{'foo': 'bar'}",
        ),
        (
            MetadataRequest(owner="test"),
            "{}",
        ),
        (
            MetadataRouter(owner="test").add(
                estimator=ConsumingRegressor(),
                method_mapping=MethodMapping().add(caller="predict", callee="predict"),
            ),
            (
                "{'estimator': {'mapping': [{'caller': 'predict', 'callee':"
                " 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':"
                " None}, 'partial_fit': {'sample_weight': None, 'metadata': None},"
                " 'predict': {'sample_weight': None, 'metadata': None}, 'score':"
                " {'sample_weight': None, 'metadata': None}}}}"
            ),
        ),
    ],
)
def test_string_representations(obj, string):
    """``str`` of each routing object matches its expected rendering."""
    rendered = str(obj)
    assert rendered == string
|
|
|
|
|
|
@pytest.mark.parametrize(
    "obj, method, inputs, err_cls, err_msg",
    [
        (
            MethodMapping(),
            "add",
            {"caller": "fit", "callee": "invalid"},
            ValueError,
            "Given callee",
        ),
        (
            MethodMapping(),
            "add",
            {"caller": "invalid", "callee": "fit"},
            ValueError,
            "Given caller",
        ),
        (
            MetadataRouter(owner="test"),
            "add_self_request",
            {"obj": MetadataRouter(owner="test")},
            ValueError,
            "Given `obj` is neither a `MetadataRequest` nor does it implement",
        ),
        (
            ConsumingClassifier(),
            "set_fit_request",
            {"invalid": True},
            TypeError,
            "Unexpected args",
        ),
    ],
)
def test_validations(obj, method, inputs, err_cls, err_msg):
    """Invalid arguments to the routing helpers raise informative errors."""
    with pytest.raises(err_cls, match=err_msg):
        bound_method = getattr(obj, method)
        bound_method(**inputs)
|
|
|
|
|
|
def test_methodmapping():
    """Check iteration, stored routes and ``repr`` of ``MethodMapping``."""
    # Iterating yields (caller, callee) pairs in the order they were added.
    mapping = (
        MethodMapping()
        .add(caller="fit", callee="transform")
        .add(caller="fit", callee="fit")
    )
    routes = list(mapping)
    assert routes[0] == ("fit", "transform")
    assert routes[1] == ("fit", "fit")

    # Every method can be mapped onto itself.
    mapping = MethodMapping()
    for method in METHODS:
        mapping.add(caller=method, callee=method)
        assert MethodPair(method, method) in mapping._routes
    assert len(mapping._routes) == len(METHODS)

    mapping = MethodMapping().add(caller="score", callee="score")
    assert repr(mapping) == "[{'caller': 'score', 'callee': 'score'}]"
|
|
|
|
|
|
def test_metadatarouter_add_self_request():
    """Check what ``MetadataRouter.add_self_request`` stores as the self request."""
    # A MetadataRequest is stored as an equal copy, not by reference.
    request = MetadataRequest(owner="nested")
    request.fit.add_request(param="param", alias=True)
    router = MetadataRouter(owner="test").add_self_request(request)
    assert str(router._self_request) == str(request)
    assert router._self_request is not request

    # An estimator can be given instead; its routing is stored as a copy.
    est = ConsumingRegressor().set_fit_request(sample_weight="my_weights")
    router = MetadataRouter(owner="test").add_self_request(obj=est)
    assert str(router._self_request) == str(est.get_metadata_routing())
    assert router._self_request is not est.get_metadata_routing()

    # For an estimator that is both a consumer and a router, only the consumer
    # part (``_get_metadata_request``) is kept, not the full routing returned
    # by ``get_metadata_routing``.
    est = WeightedMetaRegressor(
        estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights")
    )
    router = MetadataRouter(owner="test").add_self_request(obj=est)
    self_request = router._self_request
    assert str(self_request) == str(est._get_metadata_request())
    assert str(self_request) != str(est.get_metadata_routing())
    assert self_request is not est._get_metadata_request()
|
|
|
|
|
|
def test_metadata_routing_add():
    """Check the rendering of a router after ``add`` with a ``MethodMapping``."""
    # Map ``fit`` onto the child's ``fit``.
    router = MetadataRouter(owner="test").add(
        est=ConsumingRegressor().set_fit_request(sample_weight="weights"),
        method_mapping=MethodMapping().add(caller="fit", callee="fit"),
    )
    expected = (
        "{'est': {'mapping': [{'caller': 'fit', 'callee': 'fit'}], 'router': {'fit':"
        " {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':"
        " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':"
        " None, 'metadata': None}, 'score': {'sample_weight': None, 'metadata':"
        " None}}}}"
    )
    assert str(router) == expected

    # Map ``fit`` onto the child's ``score``; keyword order does not matter.
    router = MetadataRouter(owner="test").add(
        method_mapping=MethodMapping().add(caller="fit", callee="score"),
        est=ConsumingRegressor().set_score_request(sample_weight=True),
    )
    expected = (
        "{'est': {'mapping': [{'caller': 'fit', 'callee': 'score'}], 'router':"
        " {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':"
        " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':"
        " None, 'metadata': None}, 'score': {'sample_weight': True, 'metadata':"
        " None}}}}"
    )
    assert str(router) == expected
|
|
|
|
|
|
def test_metadata_routing_get_param_names():
    """Check ``MetadataRouter._get_param_names`` with and without the self request."""
    self_est = WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request(
        sample_weight="self_weights"
    )
    child = ConsumingTransformer().set_fit_request(sample_weight="transform_weights")
    router = (
        MetadataRouter(owner="test")
        .add_self_request(self_est)
        .add(
            trs=child,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
    )

    expected_repr = (
        "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':"
        " {'sample_weight': None}}, 'trs': {'mapping': [{'caller': 'fit', 'callee':"
        " 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights',"
        " 'metadata': None}, 'transform': {'sample_weight': None, 'metadata': None},"
        " 'inverse_transform': {'sample_weight': None, 'metadata': None}}}}"
    )
    assert str(router) == expected_repr

    def fit_param_names(return_alias, ignore_self_request):
        # Shorthand for querying the ``fit`` parameter names of ``router``.
        return router._get_param_names(
            method="fit",
            return_alias=return_alias,
            ignore_self_request=ignore_self_request,
        )

    # Including the self request, "self" contributes its alias ...
    assert fit_param_names(True, False) == {
        "transform_weights",
        "metadata",
        "self_weights",
    }
    # ... or its original name when return_alias=False.
    assert fit_param_names(False, False) == {
        "sample_weight",
        "metadata",
        "transform_weights",
    }
    # Ignoring the self request drops "sample_weight".
    assert fit_param_names(False, True) == {"metadata", "transform_weights"}
    # return_alias has no effect once the self request is ignored.
    assert fit_param_names(True, True) == fit_param_names(False, True)
|
|
|
|
|
|
def test_method_generation():
    """Check which ``set_{method}_request`` methods are generated on a class."""

    # TODO: these test classes can be moved to sklearn.utils._testing once we
    # have a better idea of what the commonly used classes are.
    class NoMetadataEstimator(BaseEstimator):
        # No method takes metadata, so no set_{method}_request should exist.
        def fit(self, X, y):
            pass  # pragma: no cover

        def fit_transform(self, X, y):
            pass  # pragma: no cover

        def fit_predict(self, X, y):
            pass  # pragma: no cover

        def partial_fit(self, X, y):
            pass  # pragma: no cover

        def predict(self, X):
            pass  # pragma: no cover

        def predict_proba(self, X):
            pass  # pragma: no cover

        def predict_log_proba(self, X):
            pass  # pragma: no cover

        def decision_function(self, X):
            pass  # pragma: no cover

        def score(self, X, y):
            pass  # pragma: no cover

        def split(self, X, y=None):
            pass  # pragma: no cover

        def transform(self, X):
            pass  # pragma: no cover

        def inverse_transform(self, X):
            pass  # pragma: no cover

    for method in METHODS:
        assert not hasattr(NoMetadataEstimator(), f"set_{method}_request")

    class SampleWeightEstimator(BaseEstimator):
        # Every method accepts ``sample_weight``.
        def fit(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def fit_transform(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def fit_predict(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def partial_fit(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def predict(self, X, sample_weight=None):
            pass  # pragma: no cover

        def predict_proba(self, X, sample_weight=None):
            pass  # pragma: no cover

        def predict_log_proba(self, X, sample_weight=None):
            pass  # pragma: no cover

        def decision_function(self, X, sample_weight=None):
            pass  # pragma: no cover

        def score(self, X, y, sample_weight=None):
            pass  # pragma: no cover

        def split(self, X, y=None, sample_weight=None):
            pass  # pragma: no cover

        def transform(self, X, sample_weight=None):
            pass  # pragma: no cover

        def inverse_transform(self, X, sample_weight=None):
            pass  # pragma: no cover

    # Composite methods never get a setter of their own ...
    for method in COMPOSITE_METHODS:
        assert not hasattr(SampleWeightEstimator(), f"set_{method}_request")

    # ... while every simple method does.
    for method in SIMPLE_METHODS:
        assert hasattr(SampleWeightEstimator(), f"set_{method}_request")
|
|
|
|
|
|
def test_composite_methods():
    """Check the requests of composite methods (``fit_transform``, ``fit_predict``).

    A composite method's request is the union of the requests of the simple
    methods it is made of; differing values for a shared parameter conflict.
    """

    class SimpleEstimator(BaseEstimator):
        def fit(self, X, y, foo=None, bar=None):
            pass  # pragma: no cover

        def predict(self, X, foo=None, bar=None):
            pass  # pragma: no cover

        def transform(self, X, other_param=None):
            pass  # pragma: no cover

    estimator = SimpleEstimator()
    conflict_msg = "Conflicting metadata requests for"

    # Nothing has been requested yet, so the composite requests list every
    # parameter of the underlying simple methods with the default ``None``.
    assert estimator.get_metadata_routing().fit_transform.requests == {
        "bar": None,
        "foo": None,
        "other_param": None,
    }
    assert estimator.get_metadata_routing().fit_predict.requests == {
        "bar": None,
        "foo": None,
    }

    # Setting requests on ``fit`` only makes it disagree with ``predict``.
    estimator.set_fit_request(foo=True, bar="test")
    with pytest.raises(ValueError, match=conflict_msg):
        estimator.get_metadata_routing().fit_predict

    # Setting ``predict`` to different values still conflicts.
    estimator.set_predict_request(bar=True)
    with pytest.raises(ValueError, match=conflict_msg):
        estimator.get_metadata_routing().fit_predict

    # Once both methods agree, building ``fit_predict`` succeeds.
    estimator.set_predict_request(foo=True, bar="test")
    estimator.get_metadata_routing().fit_predict

    # Parameters that appear in only one of the simple methods (non-overlapping)
    # are merged into the composite request.
    estimator.set_transform_request(other_param=True)
    assert estimator.get_metadata_routing().fit_transform.requests == {
        "bar": "test",
        "foo": True,
        "other_param": True,
    }
|
|
|
|
|
|
def test_no_feature_flag_raises_error():
    """Check that ``set_{method}_request`` raises while routing is disabled.

    With ``enable_metadata_routing=False`` the request setters are expected to
    raise a RuntimeError instead of recording the request.
    """
    routing_disabled = config_context(enable_metadata_routing=False)
    expected_error = pytest.raises(RuntimeError, match="This method is only available")
    # The config context is entered first, then the raises-check, as before.
    with routing_disabled, expected_error:
        ConsumingClassifier().set_fit_request(sample_weight=True)
|
|
|
|
|
|
def test_none_metadata_passed():
    """Passing ``None`` for metadata that was never requested must not raise."""
    meta_regressor = MetaRegressor(estimator=ConsumingRegressor())
    # ConsumingRegressor never requested sample_weight. An explicit None value
    # is nevertheless expected to be accepted without an error.
    meta_regressor.fit(X, y, sample_weight=None)
|
|
|
|
|
|
def test_no_metadata_always_works():
    """Test that a sub-estimator which does not yet support metadata routing
    can be used inside a routing meta-estimator, as long as no metadata is
    passed.

    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28246
    """

    class Estimator(_RoutingNotSupportedMixin, BaseEstimator):
        def fit(self, X, y, metadata=None):
            return self

    # This passes since no metadata is passed.
    MetaRegressor(estimator=Estimator()).fit(X, y)
    # This fails since metadata is passed but Estimator() does not support it.
    # The message is escaped so its trailing "." is matched literally rather
    # than as a regex wildcard.
    with pytest.raises(
        NotImplementedError,
        match=re.escape("Estimator has not implemented metadata routing yet."),
    ):
        MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups)
|
|
|
|
|
|
def test_unsetmetadatapassederror_correct():
    """Check the UnsetMetadataPassedError message for a nested, unset request.

    ConsumingClassifier is wrapped by WeightedMetaClassifier, which is itself a
    step of SimplePipeline. No ``set_fit_request`` is called, so passing
    metadata must produce a message naming the consumer and the
    meta-estimator that uses it.
    """
    pipe = SimplePipeline([WeightedMetaClassifier(estimator=ConsumingClassifier())])

    expected = (
        "[metadata] are passed but are not explicitly set as requested or not requested"
        " for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit."
        " Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each"
        " metadata you want to request/ignore."
    )
    # Escape the message: it contains brackets, braces and parentheses.
    with pytest.raises(UnsetMetadataPassedError, match=re.escape(expected)):
        pipe.fit(X, y, metadata="blah")
|
|
|
|
|
|
def test_unsetmetadatapassederror_correct_for_composite_methods():
    """Check the UnsetMetadataPassedError message for a composite method.

    When ``Pipeline.fit_transform`` receives metadata that its transformer
    never requested, the message must suggest setting the request on both
    simple methods, ``fit`` and ``transform``.
    """
    pipe = Pipeline([("consuming_transformer", ConsumingTransformer())])

    expected_message = (
        "[metadata] are passed but are not explicitly set as requested or not requested"
        " for ConsumingTransformer.fit_transform, which is used within"
        " Pipeline.fit_transform. Call"
        " `ConsumingTransformer.set_fit_request({metadata}=True/False)"
        ".set_transform_request({metadata}=True/False)`"
        " for each metadata you want to request/ignore."
    )
    with pytest.raises(UnsetMetadataPassedError, match=re.escape(expected_message)):
        pipe.fit_transform(X, y, metadata="blah")
|
|
|
|
|
|
def test_unbound_set_methods_work():
    """Check that ``set_{method}_request`` keeps working once it is unbound.

    Passing a positional argument to ``set_{method}_request`` must raise the
    same TypeError both before and after the method becomes unbound.

    Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28632
    """

    class A(BaseEstimator):
        def fit(self, X, y, sample_weight=None):
            return self

    positional_error = re.escape(
        "set_fit_request() takes 0 positional argument but 1 were given"
    )

    def _assert_positional_rejected():
        # set_fit_request only accepts keyword arguments.
        with pytest.raises(TypeError, match=positional_error):
            A().set_fit_request(True)

    # First check: the descriptor method is still in its original state.
    _assert_positional_rejected()

    # Re-assigning the attribute read from the class makes the descriptor
    # method unbound. The descriptor's `instance` argument is then None, and
    # `self` is passed to it as a positional argument instead.
    A.set_fit_request = A.set_fit_request

    # Keyword usage is expected to work as usual.
    A().set_fit_request(sample_weight=True)

    # Second check: positional arguments are still rejected after unbinding.
    _assert_positional_rejected()
|