76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from sklearn.datasets import load_iris
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
|
|
|
from sklearn.metrics._plot.base import _get_response
|
|
|
|
|
|
@pytest.mark.parametrize(
    "estimator, err_msg, params",
    [
        # A regressor queried with the automatic response method.
        (
            DecisionTreeRegressor(),
            "Expected 'estimator' to be a binary classifier",
            {"response_method": "auto"},
        ),
        # A classifier queried with a `pos_label` that is not one of its classes.
        (
            DecisionTreeClassifier(),
            "The class provided by 'pos_label' is unknown.",
            {"response_method": "auto", "pos_label": "unknown"},
        ),
        # A classifier fitted on the full iris target, queried for probabilities.
        (
            DecisionTreeClassifier(),
            "fit on multiclass",
            {"response_method": "predict_proba"},
        ),
    ],
)
def test_get_response_error(estimator, err_msg, params):
    """Ensure `_get_response` rejects invalid estimator/parameter combinations.

    Each case fits ``estimator`` on the full iris dataset and then expects
    ``_get_response`` to raise a ``ValueError`` whose message matches the
    ``err_msg`` pattern when called with the keyword arguments in ``params``.
    """
    features, target = load_iris(return_X_y=True)
    estimator.fit(features, target)

    # `match` is applied as a regular-expression search on the error message.
    with pytest.raises(ValueError, match=err_msg):
        _get_response(features, estimator, **params)
|
|
|
|
|
|
def test_get_response_predict_proba():
    """Verify the output of `_get_response` when `predict_proba` is requested.

    A decision tree is fitted on the first 100 iris samples. The response
    returned by ``_get_response`` must equal column 1 of ``predict_proba``
    when no ``pos_label`` is passed (with a reported positive label of 1),
    and column 0 when ``pos_label=0`` is passed.
    """
    features, target = load_iris(return_X_y=True)
    # Restrict to the first 100 samples, which this test treats as a binary task.
    features_binary, target_binary = features[:100], target[:100]

    tree = DecisionTreeClassifier().fit(features_binary, target_binary)
    # Reference probabilities; each column is compared against one call below.
    reference_proba = tree.predict_proba(features_binary)

    # No explicit `pos_label`: expect the second probability column and label 1.
    response, label = _get_response(
        features_binary, tree, response_method="predict_proba"
    )
    np.testing.assert_allclose(response, reference_proba[:, 1])
    assert label == 1

    # Explicit `pos_label=0`: expect the first probability column and label 0.
    response, label = _get_response(
        features_binary, tree, response_method="predict_proba", pos_label=0
    )
    np.testing.assert_allclose(response, reference_proba[:, 0])
    assert label == 0
|
|
|
|
|
|
def test_get_response_decision_function():
    """Verify the output of `_get_response` when `decision_function` is requested.

    A logistic regression is fitted on the first 100 iris samples. The
    response returned by ``_get_response`` must equal the classifier's
    ``decision_function`` output when no ``pos_label`` is passed (with a
    reported positive label of 1), and the negated output when
    ``pos_label=0`` is passed.
    """
    features, target = load_iris(return_X_y=True)
    # Restrict to the first 100 samples, which this test treats as a binary task.
    features_binary, target_binary = features[:100], target[:100]

    model = LogisticRegression().fit(features_binary, target_binary)
    # Reference scores; the second check below compares against their negation.
    reference_scores = model.decision_function(features_binary)

    # No explicit `pos_label`: scores are returned as-is and the label is 1.
    scores, label = _get_response(
        features_binary, model, response_method="decision_function"
    )
    np.testing.assert_allclose(scores, reference_scores)
    assert label == 1

    # Explicit `pos_label=0`: scores are expected with the opposite sign.
    scores, label = _get_response(
        features_binary, model, response_method="decision_function", pos_label=0
    )
    np.testing.assert_allclose(scores, -reference_scores)
    assert label == 0
|