1008 lines
38 KiB
Python
1008 lines
38 KiB
Python
"""Stacking classifier and regressor."""
|
|
|
|
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
|
|
# License: BSD 3 clause
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
from copy import deepcopy
|
|
from numbers import Integral
|
|
|
|
import numpy as np
|
|
import scipy.sparse as sparse
|
|
|
|
from ..base import clone
|
|
from ..base import ClassifierMixin, RegressorMixin, TransformerMixin
|
|
from ..base import is_classifier, is_regressor
|
|
from ..exceptions import NotFittedError
|
|
from ..utils._estimator_html_repr import _VisualBlock
|
|
|
|
from ._base import _fit_single_estimator
|
|
from ._base import _BaseHeterogeneousEnsemble
|
|
|
|
from ..linear_model import LogisticRegression
|
|
from ..linear_model import RidgeCV
|
|
|
|
from ..model_selection import cross_val_predict
|
|
from ..model_selection import check_cv
|
|
|
|
from ..preprocessing import LabelEncoder
|
|
|
|
from ..utils import Bunch
|
|
from ..utils.multiclass import check_classification_targets, type_of_target
|
|
from ..utils.metaestimators import available_if
|
|
from ..utils.validation import check_is_fitted
|
|
from ..utils.validation import column_or_1d
|
|
from ..utils.parallel import delayed, Parallel
|
|
from ..utils._param_validation import HasMethods, StrOptions
|
|
from ..utils.validation import _check_feature_names_in
|
|
|
|
|
|
def _estimator_has(attr):
|
|
"""Check if we can delegate a method to the underlying estimator.
|
|
|
|
First, we check the first fitted final estimator if available, otherwise we
|
|
check the unfitted final estimator.
|
|
"""
|
|
return lambda self: (
|
|
hasattr(self.final_estimator_, attr)
|
|
if hasattr(self, "final_estimator_")
|
|
else hasattr(self.final_estimator, attr)
|
|
)
|
|
|
|
|
|
class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
    """Base class for stacking method."""

    _parameter_constraints: dict = {
        "estimators": [list],
        "final_estimator": [None, HasMethods("fit")],
        "cv": ["cv_object", StrOptions({"prefit"})],
        "n_jobs": [None, Integral],
        "passthrough": ["boolean"],
        "verbose": ["verbose"],
    }

    @abstractmethod
    def __init__(
        self,
        estimators,
        final_estimator=None,
        *,
        cv=None,
        stack_method="auto",
        n_jobs=None,
        verbose=0,
        passthrough=False,
    ):
        super().__init__(estimators=estimators)
        self.final_estimator = final_estimator
        self.cv = cv
        self.stack_method = stack_method
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.passthrough = passthrough

    def _clone_final_estimator(self, default):
        """Set `final_estimator_` to a clone of `final_estimator` or `default`.

        `default` is cloned only when the `final_estimator` parameter is None.
        """
        if self.final_estimator is not None:
            self.final_estimator_ = clone(self.final_estimator)
        else:
            self.final_estimator_ = clone(default)

    def _concatenate_predictions(self, X, predictions):
        """Concatenate the predictions of each first layer learner and
        possibly the input dataset `X`.

        If `X` is sparse and `self.passthrough` is False, the output of
        `transform` will be dense (the predictions). If `X` is sparse
        and `self.passthrough` is True, the output of `transform` will
        be sparse.

        This helper is in charge of ensuring the predictions are 2D arrays and
        it will drop one of the probability columns when using probabilities
        in the binary case. Indeed, the p(y|c=0) = 1 - p(y|c=1)

        When `y` type is `"multilabel-indicator"` and the method used is
        `predict_proba`, `preds` can be either a `ndarray` of shape
        `(n_samples, n_class)` or for some estimators a list of `ndarray`.
        This function will drop one of the probability columns in this
        situation as well.

        As a side effect, the number of output columns contributed by each
        estimator is stored in `self._n_feature_outs` (used by
        `get_feature_names_out`).
        """
        X_meta = []
        for est_idx, preds in enumerate(predictions):
            if isinstance(preds, list):
                # `preds` is here a list of `n_targets` 2D ndarrays of
                # `n_classes` columns. The k-th column contains the
                # probabilities of the samples belonging the k-th class.
                #
                # Since those probabilities must sum to one for each sample,
                # we can work with probabilities of `n_classes - 1` classes.
                # Hence we drop the first column.
                for pred in preds:
                    X_meta.append(pred[:, 1:])
            elif preds.ndim == 1:
                # Some estimator return a 1D array for predictions
                # which must be 2-dimensional arrays.
                X_meta.append(preds.reshape(-1, 1))
            elif (
                self.stack_method_[est_idx] == "predict_proba"
                and len(self.classes_) == 2
            ):
                # Remove the first column when using probabilities in
                # binary classification because both features `preds` are perfectly
                # collinear.
                X_meta.append(preds[:, 1:])
            else:
                X_meta.append(preds)

        self._n_feature_outs = [pred.shape[1] for pred in X_meta]
        if self.passthrough:
            X_meta.append(X)
            if sparse.issparse(X):
                return sparse.hstack(X_meta, format=X.format)

        return np.hstack(X_meta)

    @staticmethod
    def _method_name(name, estimator, method):
        """Resolve the prediction method to call on a base estimator.

        Returns None for a dropped estimator. With `method="auto"`, the first
        of `predict_proba`, `decision_function` and `predict` found on the
        estimator is used. An explicit method missing from the estimator
        raises a ValueError.
        """
        if estimator == "drop":
            return None
        if method == "auto":
            if getattr(estimator, "predict_proba", None):
                return "predict_proba"
            elif getattr(estimator, "decision_function", None):
                return "decision_function"
            else:
                return "predict"
        else:
            if not hasattr(estimator, method):
                raise ValueError(
                    "Underlying estimator {} does not implement the method {}.".format(
                        name, method
                    )
                )
            return method

    def fit(self, X, y, sample_weight=None):
        """Fit the estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

            .. versionchanged:: 0.23
               when not None, `sample_weight` is passed to all underlying
               estimators

        Returns
        -------
        self : object
        """

        self._validate_params()

        # all_estimators contains all estimators, the one to be fitted and the
        # 'drop' string.
        names, all_estimators = self._validate_estimators()
        self._validate_final_estimator()

        stack_method = [self.stack_method] * len(all_estimators)

        if self.cv == "prefit":
            self.estimators_ = []
            for estimator in all_estimators:
                if estimator != "drop":
                    check_is_fitted(estimator)
                    self.estimators_.append(estimator)
        else:
            # Fit the base estimators on the whole training data. Those
            # base estimators will be used in transform, predict, and
            # predict_proba. They are exposed publicly.
            self.estimators_ = Parallel(n_jobs=self.n_jobs)(
                delayed(_fit_single_estimator)(clone(est), X, y, sample_weight)
                for est in all_estimators
                if est != "drop"
            )

        self.named_estimators_ = Bunch()
        est_fitted_idx = 0
        for name_est, org_est in zip(names, all_estimators):
            if org_est != "drop":
                current_estimator = self.estimators_[est_fitted_idx]
                self.named_estimators_[name_est] = current_estimator
                est_fitted_idx += 1
                if hasattr(current_estimator, "feature_names_in_"):
                    self.feature_names_in_ = current_estimator.feature_names_in_
            else:
                self.named_estimators_[name_est] = "drop"

        self.stack_method_ = [
            self._method_name(name, est, meth)
            for name, est, meth in zip(names, all_estimators, stack_method)
        ]

        if self.cv == "prefit":
            # Generate predictions from prefit models
            predictions = [
                getattr(estimator, predict_method)(X)
                for estimator, predict_method in zip(all_estimators, self.stack_method_)
                if estimator != "drop"
            ]
        else:
            # To train the meta-classifier using the most data as possible, we use
            # a cross-validation to obtain the output of the stacked estimators.
            # To ensure that the data provided to each estimator are the same,
            # we need to set the random state of the cv if there is one and we
            # need to take a copy.
            cv = check_cv(self.cv, y=y, classifier=is_classifier(self))
            if hasattr(cv, "random_state") and cv.random_state is None:
                cv.random_state = np.random.RandomState()

            fit_params = (
                {"sample_weight": sample_weight} if sample_weight is not None else None
            )
            predictions = Parallel(n_jobs=self.n_jobs)(
                delayed(cross_val_predict)(
                    clone(est),
                    X,
                    y,
                    cv=deepcopy(cv),
                    method=meth,
                    n_jobs=self.n_jobs,
                    fit_params=fit_params,
                    verbose=self.verbose,
                )
                for est, meth in zip(all_estimators, self.stack_method_)
                if est != "drop"
            )

        # Only not None or not 'drop' estimators will be used in transform.
        # Remove the None from the method as well.
        self.stack_method_ = [
            meth
            for (meth, est) in zip(self.stack_method_, all_estimators)
            if est != "drop"
        ]

        X_meta = self._concatenate_predictions(X, predictions)
        _fit_single_estimator(
            self.final_estimator_, X_meta, y, sample_weight=sample_weight
        )

        return self

    @property
    def n_features_in_(self):
        """Number of features seen during :term:`fit`."""
        try:
            check_is_fitted(self)
        except NotFittedError as nfe:
            raise AttributeError(
                f"{self.__class__.__name__} object has no attribute n_features_in_"
            ) from nfe
        return self.estimators_[0].n_features_in_

    def _transform(self, X):
        """Concatenate and return the predictions of the estimators."""
        check_is_fitted(self)
        predictions = [
            getattr(est, meth)(X)
            for est, meth in zip(self.estimators_, self.stack_method_)
            if est != "drop"
        ]
        return self._concatenate_predictions(X, predictions)

    def get_feature_names_out(self, input_features=None):
        """Get output feature names for transformation.

        Parameters
        ----------
        input_features : array-like of str or None, default=None
            Input features. The input feature names are only used when `passthrough` is
            `True`.

            - If `input_features` is `None`, then `feature_names_in_` is
              used as feature names in. If `feature_names_in_` is not defined,
              then names are generated: `[x0, x1, ..., x(n_features_in_ - 1)]`.
            - If `input_features` is an array-like, then `input_features` must
              match `feature_names_in_` if `feature_names_in_` is defined.

            If `passthrough` is `False`, then only the names of `estimators` are used
            to generate the output feature names.

        Returns
        -------
        feature_names_out : ndarray of str objects
            Transformed feature names.

        Raises
        ------
        NotFittedError
            If the estimator has not been fitted yet.
        """
        # Fail with a clear NotFittedError rather than an AttributeError on
        # the private `_n_feature_outs` when called before `fit`.
        check_is_fitted(self, ["named_estimators_", "_n_feature_outs"])
        input_features = _check_feature_names_in(
            self, input_features, generate_names=self.passthrough
        )

        class_name = self.__class__.__name__.lower()
        # Use the estimators recorded at fit time so that the names stay
        # aligned with `_n_feature_outs` even if `estimators` was changed
        # afterwards through `set_params`.
        non_dropped_estimators = (
            name for name, est in self.named_estimators_.items() if est != "drop"
        )
        meta_names = []
        for est, n_features_out in zip(non_dropped_estimators, self._n_feature_outs):
            if n_features_out == 1:
                meta_names.append(f"{class_name}_{est}")
            else:
                meta_names.extend(
                    f"{class_name}_{est}{i}" for i in range(n_features_out)
                )

        if self.passthrough:
            return np.concatenate((meta_names, input_features))

        return np.asarray(meta_names, dtype=object)

    @available_if(_estimator_has("predict"))
    def predict(self, X, **predict_params):
        """Predict target for X.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        **predict_params : dict of str -> obj
            Parameters to the `predict` called by the `final_estimator`. Note
            that this may be used to return uncertainties from some estimators
            with `return_std` or `return_cov`. Be aware that it only
            accounts for uncertainty in the final estimator.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,) or (n_samples, n_output)
            Predicted targets.
        """

        check_is_fitted(self)
        return self.final_estimator_.predict(self.transform(X), **predict_params)

    def _sk_visual_block_with_final_estimator(self, final_estimator):
        """Build the HTML-repr block: base estimators then final estimator."""
        names, estimators = zip(*self.estimators)
        parallel = _VisualBlock("parallel", estimators, names=names, dash_wrapped=False)

        # final estimator is wrapped in a parallel block to show the label:
        # 'final_estimator' in the html repr
        final_block = _VisualBlock(
            "parallel", [final_estimator], names=["final_estimator"], dash_wrapped=False
        )
        return _VisualBlock("serial", (parallel, final_block), dash_wrapped=False)
|
|
|
|
|
|
class StackingClassifier(ClassifierMixin, _BaseStacking):
    """Stack of estimators with a final classifier.

    Stacked generalization consists in stacking the output of individual
    estimators and using a classifier to compute the final prediction.
    Stacking allows using the strength of each individual estimator by using
    their output as input of a final estimator.

    Note that `estimators_` are fitted on the full `X` while `final_estimator_`
    is trained using cross-validated predictions of the base estimators using
    `cross_val_predict`.

    Read more in the :ref:`User Guide <stacking>`.

    .. versionadded:: 0.22

    Parameters
    ----------
    estimators : list of (str, estimator)
        Base estimators which will be stacked together. Each element of the
        list is defined as a tuple of string (i.e. name) and an estimator
        instance. An estimator can be set to 'drop' using `set_params`.

        The type of estimator is generally expected to be a classifier.
        However, one can pass a regressor for some use case (e.g. ordinal
        regression).

    final_estimator : estimator, default=None
        A classifier which will be used to combine the base estimators.
        The default classifier is a
        :class:`~sklearn.linear_model.LogisticRegression`.

    cv : int, cross-validation generator, iterable, or "prefit", default=None
        Determines the cross-validation splitting strategy used in
        `cross_val_predict` to train `final_estimator`. Possible inputs for
        cv are:

        * None, to use the default 5-fold cross validation,
        * integer, to specify the number of folds in a (Stratified) KFold,
        * An object to be used as a cross-validation generator,
        * An iterable yielding train, test splits,
        * `"prefit"` to assume the `estimators` are prefit. In this case, the
          estimators will not be refitted.

        For integer/None inputs, if the estimator is a classifier and y is
        either binary or multiclass,
        :class:`~sklearn.model_selection.StratifiedKFold` is used.
        In all other cases, :class:`~sklearn.model_selection.KFold` is used.
        These splitters are instantiated with `shuffle=False` so the splits
        will be the same across calls.

        Refer :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

        If "prefit" is passed, it is assumed that all `estimators` have
        been fitted already. The `final_estimator_` is trained on the `estimators`
        predictions on the full training set and are **not** cross validated
        predictions. Please note that if the models have been trained on the same
        data to train the stacking model, there is a very high risk of overfitting.

        .. versionadded:: 1.1
            The 'prefit' option was added in 1.1

        .. note::
           A larger number of splits will provide no benefits if the number
           of training samples is large enough. Indeed, the training time
           will increase. ``cv`` is not used for model evaluation but for
           prediction.

    stack_method : {'auto', 'predict_proba', 'decision_function', 'predict'}, \
            default='auto'
        Methods called for each base estimator. It can be:

        * if 'auto', it will try to invoke, for each estimator,
          `'predict_proba'`, `'decision_function'` or `'predict'` in that
          order.
        * otherwise, one of `'predict_proba'`, `'decision_function'` or
          `'predict'`. If the method is not implemented by the estimator, it
          will raise an error.

    n_jobs : int, default=None
        The number of jobs to run in parallel for `fit` of all `estimators`.
        `None` means 1 unless in a `joblib.parallel_backend` context. -1 means
        using all processors. See Glossary for more details.

    passthrough : bool, default=False
        When False, only the predictions of estimators will be used as
        training data for `final_estimator`. When True, the
        `final_estimator` is trained on the predictions as well as the
        original training data.

    verbose : int, default=0
        Verbosity level.

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,) or list of ndarray if `y` \
        is of type `"multilabel-indicator"`.
        Class labels.

    estimators_ : list of estimators
        The elements of the `estimators` parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it
        will not appear in `estimators_`. When `cv="prefit"`, `estimators_`
        is set to `estimators` and is not fitted again.

    named_estimators_ : :class:`~sklearn.utils.Bunch`
        Attribute to access any fitted sub-estimators by name.

    n_features_in_ : int
        Number of features seen during :term:`fit`. Only defined if the
        underlying classifier exposes such an attribute when fit.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Only defined if the
        underlying estimators expose such an attribute when fit.

        .. versionadded:: 1.0

    final_estimator_ : estimator
        The classifier which predicts given the output of `estimators_`.

    stack_method_ : list of str
        The method used by each base estimator.

    See Also
    --------
    StackingRegressor : Stack of estimators with a final regressor.

    Notes
    -----
    When `predict_proba` is used by each estimator (i.e. most of the time for
    `stack_method='auto'` or specifically for `stack_method='predict_proba'`),
    the first column predicted by each estimator will be dropped in the case
    of a binary classification problem. Indeed, both features will be perfectly
    collinear.

    In some cases (e.g. ordinal regression), one can pass regressors as the
    first layer of the :class:`StackingClassifier`. However, note that `y` will
    be internally encoded in a numerically increasing order or lexicographic
    order. If this ordering is not adequate, one should manually numerically
    encode the classes in the desired order.

    References
    ----------
    .. [1] Wolpert, David H. "Stacked generalization." Neural networks 5.2
       (1992): 241-259.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.svm import LinearSVC
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.preprocessing import StandardScaler
    >>> from sklearn.pipeline import make_pipeline
    >>> from sklearn.ensemble import StackingClassifier
    >>> X, y = load_iris(return_X_y=True)
    >>> estimators = [
    ...     ('rf', RandomForestClassifier(n_estimators=10, random_state=42)),
    ...     ('svr', make_pipeline(StandardScaler(),
    ...                           LinearSVC(random_state=42)))
    ... ]
    >>> clf = StackingClassifier(
    ...     estimators=estimators, final_estimator=LogisticRegression()
    ... )
    >>> from sklearn.model_selection import train_test_split
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, stratify=y, random_state=42
    ... )
    >>> clf.fit(X_train, y_train).score(X_test, y_test)
    0.9...
    """

    _parameter_constraints: dict = {
        **_BaseStacking._parameter_constraints,
        "stack_method": [
            StrOptions({"auto", "predict_proba", "decision_function", "predict"})
        ],
    }

    def __init__(
        self,
        estimators,
        final_estimator=None,
        *,
        cv=None,
        stack_method="auto",
        n_jobs=None,
        passthrough=False,
        verbose=0,
    ):
        super().__init__(
            estimators=estimators,
            final_estimator=final_estimator,
            cv=cv,
            stack_method=stack_method,
            n_jobs=n_jobs,
            passthrough=passthrough,
            verbose=verbose,
        )

    def _validate_final_estimator(self):
        """Clone the final estimator (default LogisticRegression) and make
        sure it is a classifier."""
        self._clone_final_estimator(default=LogisticRegression())
        if not is_classifier(self.final_estimator_):
            raise ValueError(
                "'final_estimator' parameter should be a classifier. Got {}".format(
                    self.final_estimator_
                )
            )

    def _validate_estimators(self):
        """Overload the method of `_BaseHeterogeneousEnsemble` to be more
        lenient towards the type of `estimators`.

        Regressors can be accepted for some cases such as ordinal regression.
        """
        if len(self.estimators) == 0:
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        return names, estimators

    def fit(self, X, y, sample_weight=None):
        """Fit the estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values. Note that `y` will be internally encoded in
            numerically increasing order or lexicographic order. If the order
            matters (e.g. for ordinal regression), one should numerically encode
            the target `y` before calling :term:`fit`.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

        Returns
        -------
        self : object
            Returns a fitted instance of estimator.
        """
        check_classification_targets(y)
        if type_of_target(y) == "multilabel-indicator":
            # One label encoder per output column.
            self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T]
            self.classes_ = [le.classes_ for le in self._label_encoder]
            y_encoded = np.array(
                [
                    self._label_encoder[target_idx].transform(target)
                    for target_idx, target in enumerate(y.T)
                ]
            ).T
        else:
            self._label_encoder = LabelEncoder().fit(y)
            self.classes_ = self._label_encoder.classes_
            y_encoded = self._label_encoder.transform(y)
        return super().fit(X, y_encoded, sample_weight)

    @available_if(_estimator_has("predict"))
    def predict(self, X, **predict_params):
        """Predict target for X.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        **predict_params : dict of str -> obj
            Parameters to the `predict` called by the `final_estimator`. Note
            that this may be used to return uncertainties from some estimators
            with `return_std` or `return_cov`. Be aware that it only
            accounts for uncertainty in the final estimator.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,) or (n_samples, n_output)
            Predicted targets.
        """
        y_pred = super().predict(X, **predict_params)
        if isinstance(self._label_encoder, list):
            # Handle the multilabel-indicator case
            y_pred = np.array(
                [
                    self._label_encoder[target_idx].inverse_transform(target)
                    for target_idx, target in enumerate(y_pred.T)
                ]
            ).T
        else:
            y_pred = self._label_encoder.inverse_transform(y_pred)
        return y_pred

    @available_if(_estimator_has("predict_proba"))
    def predict_proba(self, X):
        """Predict class probabilities for `X` using the final estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        probabilities : ndarray of shape (n_samples, n_classes) or \
            list of ndarray of shape (n_output,)
            The class probabilities of the input samples.
        """
        check_is_fitted(self)
        y_pred = self.final_estimator_.predict_proba(self.transform(X))

        if isinstance(self._label_encoder, list) and isinstance(y_pred, list):
            # Handle the multilabel-indicator case when the final estimator
            # returns one (n_samples, n_classes) array per output: keep one
            # column per output. A final estimator that already returns a
            # single (n_samples, n_outputs) ndarray is left untouched;
            # indexing its 1-D rows with `[:, 0]` would raise an IndexError.
            y_pred = np.array([preds[:, 0] for preds in y_pred]).T
        return y_pred

    @available_if(_estimator_has("decision_function"))
    def decision_function(self, X):
        """Decision function for samples in `X` using the final estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        decisions : ndarray of shape (n_samples,), (n_samples, n_classes), \
            or (n_samples, n_classes * (n_classes-1) / 2)
            The decision function computed by the final estimator.
        """
        check_is_fitted(self)
        return self.final_estimator_.decision_function(self.transform(X))

    def transform(self, X):
        """Return class labels or probabilities for X for each estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        y_preds : ndarray of shape (n_samples, n_estimators) or \
                (n_samples, n_classes * n_estimators)
            Prediction outputs for each estimator. If `passthrough` is True,
            the original features are appended as extra columns.
        """
        return self._transform(X)

    def _sk_visual_block_(self):
        # If final_estimator's default changes then this should be
        # updated.
        if self.final_estimator is None:
            final_estimator = LogisticRegression()
        else:
            final_estimator = self.final_estimator
        return super()._sk_visual_block_with_final_estimator(final_estimator)
|
|
|
|
|
|
class StackingRegressor(RegressorMixin, _BaseStacking):
    """Stack of estimators with a final regressor.

    Stacked generalization consists in stacking the output of individual
    estimators and using a regressor to compute the final prediction.
    Stacking allows using the strength of each individual estimator by using
    their output as input of a final estimator.

    Note that `estimators_` are fitted on the full `X` while `final_estimator_`
    is trained using cross-validated predictions of the base estimators using
    `cross_val_predict`.

    Read more in the :ref:`User Guide <stacking>`.

    .. versionadded:: 0.22

    Parameters
    ----------
    estimators : list of (str, estimator)
        Base estimators which will be stacked together. Each element of the
        list is defined as a tuple of string (i.e. name) and an estimator
        instance. An estimator can be set to 'drop' using `set_params`.

    final_estimator : estimator, default=None
        A regressor which will be used to combine the base estimators.
        The default regressor is a :class:`~sklearn.linear_model.RidgeCV`.

    cv : int, cross-validation generator, iterable, or "prefit", default=None
        Determines the cross-validation splitting strategy used in
        `cross_val_predict` to train `final_estimator`. Possible inputs for
        cv are:

        * None, to use the default 5-fold cross validation,
        * integer, to specify the number of folds in a (Stratified) KFold,
        * An object to be used as a cross-validation generator,
        * An iterable yielding train, test splits,
        * `"prefit"` to assume the `estimators` are prefit, and skip cross
          validation.

        For integer/None inputs, if the estimator is a classifier and y is
        either binary or multiclass,
        :class:`~sklearn.model_selection.StratifiedKFold` is used.
        In all other cases, :class:`~sklearn.model_selection.KFold` is used.
        These splitters are instantiated with `shuffle=False` so the splits
        will be the same across calls.

        Refer :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

        If "prefit" is passed, it is assumed that all `estimators` have
        been fitted already. The `final_estimator_` is trained on the `estimators`
        predictions on the full training set and are **not** cross validated
        predictions. Please note that if the models have been trained on the same
        data to train the stacking model, there is a very high risk of overfitting.

        .. versionadded:: 1.1
            The 'prefit' option was added in 1.1

        .. note::
           A larger number of splits will provide no benefits if the number
           of training samples is large enough. Indeed, the training time
           will increase. ``cv`` is not used for model evaluation but for
           prediction.

    n_jobs : int, default=None
        The number of jobs to run in parallel for `fit` of all `estimators`.
        `None` means 1 unless in a `joblib.parallel_backend` context. -1 means
        using all processors. See Glossary for more details.

    passthrough : bool, default=False
        When False, only the predictions of estimators will be used as
        training data for `final_estimator`. When True, the
        `final_estimator` is trained on the predictions as well as the
        original training data.

    verbose : int, default=0
        Verbosity level.

    Attributes
    ----------
    estimators_ : list of estimator
        The elements of the `estimators` parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it
        will not appear in `estimators_`. When `cv="prefit"`, `estimators_`
        is set to `estimators` and is not fitted again.

    named_estimators_ : :class:`~sklearn.utils.Bunch`
        Attribute to access any fitted sub-estimators by name.

    n_features_in_ : int
        Number of features seen during :term:`fit`. Only defined if the
        underlying regressor exposes such an attribute when fit.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Only defined if the
        underlying estimators expose such an attribute when fit.

        .. versionadded:: 1.0

    final_estimator_ : estimator
        The regressor fitted on the outputs of `estimators_` (and on the
        original features when `passthrough=True`) which produces the final
        predictions.

    stack_method_ : list of str
        The method used by each base estimator.

    See Also
    --------
    StackingClassifier : Stack of estimators with a final classifier.

    References
    ----------
    .. [1] Wolpert, David H. "Stacked generalization." Neural networks 5.2
       (1992): 241-259.

    Examples
    --------
    >>> from sklearn.datasets import load_diabetes
    >>> from sklearn.linear_model import RidgeCV
    >>> from sklearn.svm import LinearSVR
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> from sklearn.ensemble import StackingRegressor
    >>> X, y = load_diabetes(return_X_y=True)
    >>> estimators = [
    ...     ('lr', RidgeCV()),
    ...     ('svr', LinearSVR(random_state=42))
    ... ]
    >>> reg = StackingRegressor(
    ...     estimators=estimators,
    ...     final_estimator=RandomForestRegressor(n_estimators=10,
    ...                                           random_state=42)
    ... )
    >>> from sklearn.model_selection import train_test_split
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, random_state=42
    ... )
    >>> reg.fit(X_train, y_train).score(X_test, y_test)
    0.3...
    """

    def __init__(
        self,
        estimators,
        final_estimator=None,
        *,
        cv=None,
        n_jobs=None,
        passthrough=False,
        verbose=0,
    ):
        # Regressors always stack the output of `predict`.
        super().__init__(
            estimators=estimators,
            final_estimator=final_estimator,
            cv=cv,
            stack_method="predict",
            n_jobs=n_jobs,
            passthrough=passthrough,
            verbose=verbose,
        )

    def _validate_final_estimator(self):
        """Clone the final estimator (default RidgeCV) and make sure it is
        a regressor."""
        self._clone_final_estimator(default=RidgeCV())
        if not is_regressor(self.final_estimator_):
            raise ValueError(
                "'final_estimator' parameter should be a regressor. Got {}".format(
                    self.final_estimator_
                )
            )

    def fit(self, X, y, sample_weight=None):
        """Fit the estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

        Returns
        -------
        self : object
            Returns a fitted instance.
        """
        # Ravel column-vector targets (with a warning) to a 1D array.
        y = column_or_1d(y, warn=True)
        return super().fit(X, y, sample_weight)

    def transform(self, X):
        """Return the predictions for X for each estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        y_preds : ndarray of shape (n_samples, n_estimators)
            Prediction outputs for each estimator. If `passthrough` is True,
            the original features are appended as extra columns.
        """
        return self._transform(X)

    def fit_transform(self, X, y, sample_weight=None):
        """Fit the estimators and return the predictions for X for each estimator.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if all underlying estimators
            support sample weights.

        Returns
        -------
        y_preds : ndarray of shape (n_samples, n_estimators)
            Prediction outputs for each estimator. If `passthrough` is True,
            the original features are appended as extra columns.
        """
        return super().fit_transform(X, y, sample_weight=sample_weight)

    def _sk_visual_block_(self):
        # If final_estimator's default changes then this should be
        # updated.
        if self.final_estimator is None:
            final_estimator = RidgeCV()
        else:
            final_estimator = self.final_estimator
        return super()._sk_visual_block_with_final_estimator(final_estimator)
|