153 lines
4.5 KiB
Python
153 lines
4.5 KiB
Python
|
import pytest
|
||
|
|
||
|
from sklearn.base import ClassifierMixin, clone
|
||
|
from sklearn.compose import make_column_transformer
|
||
|
from sklearn.datasets import load_iris
|
||
|
from sklearn.exceptions import NotFittedError
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
from sklearn.pipeline import make_pipeline
|
||
|
from sklearn.preprocessing import StandardScaler
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
|
||
|
from sklearn.metrics import (
|
||
|
DetCurveDisplay,
|
||
|
PrecisionRecallDisplay,
|
||
|
RocCurveDisplay,
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def data():
    """Provide the full (multiclass) iris dataset as an ``(X, y)`` pair.

    Module-scoped so the dataset is loaded once and shared by every test in
    this module.
    """
    features, target = load_iris(return_X_y=True)
    return features, target
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def data_binary(data):
    """Provide a two-class subset of the ``data`` fixture.

    Only the samples whose label is below 2 are kept, giving the binary
    problem that the curve displays require.
    """
    features, target = data
    keep = target < 2  # boolean mask over samples
    return features[keep], target[keep]
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_non_binary(pyplot, data, Display):
    """Check that ``from_estimator`` rejects a non-binary classifier.

    The tree is fitted on the full (multiclass) iris data, so each curve
    display, which supports binary classification only, must raise a
    ``ValueError`` naming the offending estimator class.
    """
    features, target = data
    multiclass_tree = DecisionTreeClassifier().fit(features, target)

    expected = (
        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
    )
    with pytest.raises(ValueError, match=expected):
        Display.from_estimator(multiclass_tree, features, target)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "response method predict_proba is not defined in MyClassifier",
        ),
        (
            "decision_function",
            "response method decision_function is not defined in MyClassifier",
        ),
        (
            "auto",
            "response method decision_function or predict_proba is not "
            "defined in MyClassifier",
        ),
        (
            "bad_method",
            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
        ),
    ],
)
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_no_response(
    pyplot,
    data_binary,
    response_method,
    msg,
    Display,
):
    """Check that ``from_estimator`` raises a ``ValueError`` when the requested
    ``response_method`` is unavailable on the fitted classifier, or is not one
    of the accepted values."""
    features, target = data_binary

    # Bare-bones classifier: ``fit`` only records ``classes_``. The class defines
    # neither ``predict_proba`` nor ``decision_function``. Its name is part of the
    # expected error messages, so it must stay ``MyClassifier``.
    class MyClassifier(ClassifierMixin):
        def fit(self, X, y):
            self.classes_ = [0, 1]
            return self

    estimator = MyClassifier().fit(features, target)

    with pytest.raises(ValueError, match=msg):
        Display.from_estimator(
            estimator, features, target, response_method=response_method
        )
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_display_curve_estimator_name_multiple_calls(
    pyplot,
    data_binary,
    Display,
    constructor_name,
):
    """Check how ``name`` propagates to the display and its legend.

    A ``name`` given to the constructor is stored as ``estimator_name`` and used
    by a later bare ``plot()`` call. Passing ``name`` to ``plot`` then overrides
    it in the legend.
    """
    features, target = data_binary
    clf_name = "my hand-crafted name"
    model = LogisticRegression().fit(features, target)
    positive_scores = model.predict_proba(features)[:, 1]

    # safe guard for the binary if/else construction below
    assert constructor_name in ("from_estimator", "from_predictions")

    if constructor_name == "from_predictions":
        disp = Display.from_predictions(target, positive_scores, name=clf_name)
    else:
        disp = Display.from_estimator(model, features, target, name=clf_name)
    assert disp.estimator_name == clf_name

    # Re-plot without ``name``: the stored name must still reach the legend.
    pyplot.close("all")
    disp.plot()
    assert clf_name in disp.line_.get_label()

    # Re-plot with an explicit ``name``: it must replace the stored one.
    pyplot.close("all")
    override_name = "another_name"
    disp.plot(name=override_name)
    assert override_name in disp.line_.get_label()
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
    """Check that ``from_estimator`` raises ``NotFittedError`` for an unfitted
    classifier and succeeds once that classifier is fitted.

    After fitting, the estimator's class name must be both the legend label's
    source and the display's ``estimator_name``.
    """
    features, target = data_binary
    # The parametrized ``clf`` objects are created once and reused across every
    # ``Display`` case, so fit a fresh clone to be sure it starts out unfitted.
    estimator = clone(clf)

    with pytest.raises(NotFittedError):
        Display.from_estimator(estimator, features, target)

    estimator.fit(features, target)
    disp = Display.from_estimator(estimator, features, target)

    expected_name = type(estimator).__name__
    assert expected_name in disp.line_.get_label()
    assert disp.estimator_name == expected_name
|