1275 lines
45 KiB
Python
1275 lines
45 KiB
Python
|
"""Weight Boosting.
|
||
|
|
||
|
This module contains weight boosting estimators for both classification and
|
||
|
regression.
|
||
|
|
||
|
The module structure is the following:
|
||
|
|
||
|
- The `BaseWeightBoosting` base class implements a common ``fit`` method
|
||
|
for all the estimators in the module. Regression and classification
|
||
|
only differ from each other in the loss function that is optimized.
|
||
|
|
||
|
- :class:`~sklearn.ensemble.AdaBoostClassifier` implements adaptive boosting
|
||
|
(AdaBoost-SAMME) for classification problems.
|
||
|
|
||
|
- :class:`~sklearn.ensemble.AdaBoostRegressor` implements adaptive boosting
|
||
|
(AdaBoost.R2) for regression problems.
|
||
|
"""
|
||
|
|
||
|
# Authors: Noel Dawe <noel@dawe.me>
|
||
|
# Gilles Louppe <g.louppe@gmail.com>
|
||
|
# Hamzeh Alsalhi <ha258@cornell.edu>
|
||
|
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||
|
#
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
import warnings
|
||
|
from abc import ABCMeta, abstractmethod
|
||
|
from numbers import Integral, Real
|
||
|
|
||
|
import numpy as np
|
||
|
from scipy.special import xlogy
|
||
|
|
||
|
from ..base import (
|
||
|
ClassifierMixin,
|
||
|
RegressorMixin,
|
||
|
_fit_context,
|
||
|
is_classifier,
|
||
|
is_regressor,
|
||
|
)
|
||
|
from ..metrics import accuracy_score, r2_score
|
||
|
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||
|
from ..utils import _safe_indexing, check_random_state
|
||
|
from ..utils._param_validation import HasMethods, Interval, StrOptions
|
||
|
from ..utils.extmath import softmax, stable_cumsum
|
||
|
from ..utils.metadata_routing import (
|
||
|
_raise_for_unsupported_routing,
|
||
|
_RoutingNotSupportedMixin,
|
||
|
)
|
||
|
from ..utils.validation import (
|
||
|
_check_sample_weight,
|
||
|
_num_samples,
|
||
|
check_is_fitted,
|
||
|
has_fit_parameter,
|
||
|
)
|
||
|
from ._base import BaseEnsemble
|
||
|
|
||
|
# Public estimators exported by this module.
__all__ = ["AdaBoostClassifier", "AdaBoostRegressor"]
|
||
|
|
||
|
|
||
|
class BaseWeightBoosting(BaseEnsemble, metaclass=ABCMeta):
    """Base class for AdaBoost estimators.

    Subclasses implement :meth:`_boost`, which performs one boosting
    iteration. The shared :meth:`fit` loop normalizes the sample weights,
    keeps them away from underflow between iterations and handles early
    termination.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    _parameter_constraints: dict = {
        "estimator": [HasMethods(["fit", "predict"]), None],
        "n_estimators": [Interval(Integral, 1, None, closed="left")],
        "learning_rate": [Interval(Real, 0, None, closed="neither")],
        "random_state": ["random_state"],
    }

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=50,
        estimator_params=tuple(),
        learning_rate=1.0,
        random_state=None,
    ):
        super().__init__(
            estimator=estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
        )

        self.learning_rate = learning_rate
        self.random_state = random_state

    def _check_X(self, X):
        """Validate ``X`` for the non-fit methods.

        CSR/CSC sparse matrices and n-dimensional arrays are accepted and the
        input dtype is kept (``dtype=None``).
        """
        # Only called to validate X in non-fit methods, therefore reset=False
        return self._validate_data(
            X,
            accept_sparse=["csr", "csc"],
            ensure_2d=True,
            allow_nd=True,
            dtype=None,
            reset=False,
        )

    @_fit_context(
        # AdaBoost*.estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def fit(self, X, y, sample_weight=None):
        """Build a boosted classifier/regressor from the training set (X, y).

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Sparse matrix can be CSC, CSR, COO,
            DOK, or LIL. COO, DOK, and LIL are converted to CSR.

        y : array-like of shape (n_samples,)
            The target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, the sample weights are initialized to
            1 / n_samples. Weights must be non-negative and at least one of
            them must be positive.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If every sample weight is zero (in addition to the validation
            errors raised by the input checks).
        """
        _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
        X, y = self._validate_data(
            X,
            y,
            accept_sparse=["csr", "csc"],
            ensure_2d=True,
            allow_nd=True,
            dtype=None,
            y_numeric=is_regressor(self),
        )

        sample_weight = _check_sample_weight(
            sample_weight, X, np.float64, copy=True, only_non_negative=True
        )
        # The weights were checked to be non-negative, so a non-positive
        # total means they are all zero. Normalizing would then divide by
        # zero and hand NaN weights to the first weak learner.
        sample_weight_total = sample_weight.sum()
        if sample_weight_total <= 0:
            raise ValueError(
                "Sample weights must contain at least one positive value; "
                "got all-zero sample_weight."
            )
        sample_weight /= sample_weight_total

        # Check parameters
        self._validate_estimator()

        # Clear any previous fit results
        self.estimators_ = []
        self.estimator_weights_ = np.zeros(self.n_estimators, dtype=np.float64)
        self.estimator_errors_ = np.ones(self.n_estimators, dtype=np.float64)

        # Initialization of the random number instance that will be used to
        # generate a seed at each iteration
        random_state = check_random_state(self.random_state)
        epsilon = np.finfo(sample_weight.dtype).eps

        zero_weight_mask = sample_weight == 0.0
        for iboost in range(self.n_estimators):
            # avoid extremely small sample weight, for details see issue #20320
            sample_weight = np.clip(sample_weight, a_min=epsilon, a_max=None)
            # do not clip sample weights that were exactly zero originally
            sample_weight[zero_weight_mask] = 0.0

            # Boosting step
            sample_weight, estimator_weight, estimator_error = self._boost(
                iboost, X, y, sample_weight, random_state
            )

            # Early termination
            if sample_weight is None:
                break
            self.estimator_weights_[iboost] = estimator_weight
            self.estimator_errors_[iboost] = estimator_error

            # Stop if error is zero
            if estimator_error == 0:
                break

            sample_weight_sum = np.sum(sample_weight)

            if not np.isfinite(sample_weight_sum):
                warnings.warn(
                    (
                        "Sample weights have reached infinite values,"
                        f" at iteration {iboost}, causing overflow. "
                        "Iterations stopped. Try lowering the learning rate."
                    ),
                    stacklevel=2,
                )
                break

            # Stop if the sum of sample weights has become non-positive
            if sample_weight_sum <= 0:
                break

            if iboost < self.n_estimators - 1:
                # Normalize
                sample_weight /= sample_weight_sum

        return self

    @abstractmethod
    def _boost(self, iboost, X, y, sample_weight, random_state):
        """Implement a single boost.

        Warning: This method needs to be overridden by subclasses.

        Parameters
        ----------
        iboost : int
            The index of the current boost iteration.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Sparse matrix can be CSC, CSR, COO,
            DOK, or LIL. COO, DOK, and LIL are converted to CSR.

        y : array-like of shape (n_samples,)
            The target values (class labels or regression targets).

        sample_weight : array-like of shape (n_samples,)
            The current sample weights.

        random_state : RandomState
            The current random number generator

        Returns
        -------
        sample_weight : array-like of shape (n_samples,) or None
            The reweighted sample weights.
            If None then boosting has terminated early.

        estimator_weight : float
            The weight for the current boost.
            If None then boosting has terminated early.

        error : float
            The training error of the current boost (classification error
            for classifiers, loss-based error for regressors).
            If None then boosting has terminated early.
        """
        pass

    def staged_score(self, X, y, sample_weight=None):
        """Return staged scores for X, y.

        This generator method yields the ensemble score after each iteration of
        boosting and therefore allows monitoring, such as to determine the
        score on a test set after each boost.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Sparse matrix can be CSC, CSR, COO,
            DOK, or LIL. COO, DOK, and LIL are converted to CSR.

        y : array-like of shape (n_samples,)
            True targets for X: class labels for classifiers (scored with
            accuracy), real values for regressors (scored with R^2).

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Yields
        ------
        z : float
        """
        X = self._check_X(X)

        for y_pred in self.staged_predict(X):
            if is_classifier(self):
                yield accuracy_score(y, y_pred, sample_weight=sample_weight)
            else:
                yield r2_score(y, y_pred, sample_weight=sample_weight)

    @property
    def feature_importances_(self):
        """The impurity-based feature importances.

        The higher, the more important the feature.
        The importance of a feature is computed as the (normalized)
        total reduction of the criterion brought by that feature. It is also
        known as the Gini importance.

        The ensemble value is the average of the sub-estimators'
        ``feature_importances_`` weighted by ``estimator_weights_``.

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

        Returns
        -------
        feature_importances_ : ndarray of shape (n_features,)
            The feature importances.

        Raises
        ------
        ValueError
            If ``estimators_`` is None or empty.
        AttributeError
            If the sub-estimators do not expose ``feature_importances_``.
        """
        if self.estimators_ is None or len(self.estimators_) == 0:
            raise ValueError(
                "Estimator not fitted, call `fit` before `feature_importances_`."
            )

        try:
            norm = self.estimator_weights_.sum()
            return (
                sum(
                    weight * clf.feature_importances_
                    for weight, clf in zip(self.estimator_weights_, self.estimators_)
                )
                / norm
            )

        except AttributeError as e:
            raise AttributeError(
                "Unable to compute feature importances "
                "since estimator does not have a "
                "feature_importances_ attribute"
            ) from e
||
|
|
||
|
|
||
|
def _samme_proba(estimator, n_classes, X):
    """Calculate algorithm 4, step 2, equation c) of Zhu et al [1].

    The class probabilities predicted by ``estimator`` are floored at machine
    epsilon and log-transformed. From each row, ``1 / n_classes`` times the
    row sum of the logs is subtracted, and the result is scaled by
    ``n_classes - 1``.

    Parameters
    ----------
    estimator : object
        Fitted estimator exposing ``predict_proba``.
    n_classes : int
        Number of classes K.
    X : {array-like, sparse matrix}
        Input samples forwarded to ``estimator.predict_proba``.

    Returns
    -------
    ndarray
        Array with the same shape as ``estimator.predict_proba(X)``.

    Notes
    -----
    The array returned by ``predict_proba`` is clipped in place.

    References
    ----------
    .. [1] J. Zhu, H. Zou, S. Rosset, T. Hastie, "Multi-class AdaBoost", 2009.
    """
    probabilities = estimator.predict_proba(X)

    # Floor the probabilities at machine epsilon (in place) so the log is
    # finite. This also repairs negative entries, which may occur when the
    # weak learner was trained with negative sample weights.
    floor = np.finfo(probabilities.dtype).eps
    np.clip(probabilities, floor, None, out=probabilities)
    log_p = np.log(probabilities)

    row_term = (1.0 / n_classes) * log_p.sum(axis=1, keepdims=True)
    return (n_classes - 1) * (log_p - row_term)
|
||
|
|
||
|
|
||
|
class AdaBoostClassifier(
|
||
|
_RoutingNotSupportedMixin, ClassifierMixin, BaseWeightBoosting
|
||
|
):
|
||
|
"""An AdaBoost classifier.
|
||
|
|
||
|
An AdaBoost [1]_ classifier is a meta-estimator that begins by fitting a
|
||
|
classifier on the original dataset and then fits additional copies of the
|
||
|
classifier on the same dataset but where the weights of incorrectly
|
||
|
classified instances are adjusted such that subsequent classifiers focus
|
||
|
more on difficult cases.
|
||
|
|
||
|
This class implements the algorithm based on [2]_.
|
||
|
|
||
|
Read more in the :ref:`User Guide <adaboost>`.
|
||
|
|
||
|
.. versionadded:: 0.14
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
estimator : object, default=None
|
||
|
The base estimator from which the boosted ensemble is built.
|
||
|
Support for sample weighting is required, as well as proper
|
||
|
``classes_`` and ``n_classes_`` attributes. If ``None``, then
|
||
|
the base estimator is :class:`~sklearn.tree.DecisionTreeClassifier`
|
||
|
initialized with `max_depth=1`.
|
||
|
|
||
|
.. versionadded:: 1.2
|
||
|
`base_estimator` was renamed to `estimator`.
|
||
|
|
||
|
n_estimators : int, default=50
|
||
|
The maximum number of estimators at which boosting is terminated.
|
||
|
In case of perfect fit, the learning procedure is stopped early.
|
||
|
Values must be in the range `[1, inf)`.
|
||
|
|
||
|
learning_rate : float, default=1.0
|
||
|
Weight applied to each classifier at each boosting iteration. A higher
|
||
|
learning rate increases the contribution of each classifier. There is
|
||
|
a trade-off between the `learning_rate` and `n_estimators` parameters.
|
||
|
Values must be in the range `(0.0, inf)`.
|
||
|
|
||
|
algorithm : {'SAMME', 'SAMME.R'}, default='SAMME.R'
|
||
|
If 'SAMME.R' then use the SAMME.R real boosting algorithm.
|
||
|
``estimator`` must support calculation of class probabilities.
|
||
|
If 'SAMME' then use the SAMME discrete boosting algorithm.
|
||
|
The SAMME.R algorithm typically converges faster than SAMME,
|
||
|
achieving a lower test error with fewer boosting iterations.
|
||
|
|
||
|
.. deprecated:: 1.4
|
||
|
`"SAMME.R"` is deprecated and will be removed in version 1.6.
|
||
|
`"SAMME"` will become the default.
|
||
|
|
||
|
random_state : int, RandomState instance or None, default=None
|
||
|
Controls the random seed given at each `estimator` at each
|
||
|
boosting iteration.
|
||
|
Thus, it is only used when `estimator` exposes a `random_state`.
|
||
|
Pass an int for reproducible output across multiple function calls.
|
||
|
See :term:`Glossary <random_state>`.
|
||
|
|
||
|
Attributes
|
||
|
----------
|
||
|
estimator_ : estimator
|
||
|
The base estimator from which the ensemble is grown.
|
||
|
|
||
|
.. versionadded:: 1.2
|
||
|
`base_estimator_` was renamed to `estimator_`.
|
||
|
|
||
|
estimators_ : list of classifiers
|
||
|
The collection of fitted sub-estimators.
|
||
|
|
||
|
classes_ : ndarray of shape (n_classes,)
|
||
|
The class labels.
|
||
|
|
||
|
n_classes_ : int
|
||
|
The number of classes.
|
||
|
|
||
|
estimator_weights_ : ndarray of floats
|
||
|
Weights for each estimator in the boosted ensemble.
|
||
|
|
||
|
estimator_errors_ : ndarray of floats
|
||
|
Classification error for each estimator in the boosted
|
||
|
ensemble.
|
||
|
|
||
|
feature_importances_ : ndarray of shape (n_features,)
|
||
|
The impurity-based feature importances if supported by the
|
||
|
``estimator`` (when based on decision trees).
|
||
|
|
||
|
Warning: impurity-based feature importances can be misleading for
|
||
|
high cardinality features (many unique values). See
|
||
|
:func:`sklearn.inspection.permutation_importance` as an alternative.
|
||
|
|
||
|
n_features_in_ : int
|
||
|
Number of features seen during :term:`fit`.
|
||
|
|
||
|
.. versionadded:: 0.24
|
||
|
|
||
|
feature_names_in_ : ndarray of shape (`n_features_in_`,)
|
||
|
Names of features seen during :term:`fit`. Defined only when `X`
|
||
|
has feature names that are all strings.
|
||
|
|
||
|
.. versionadded:: 1.0
|
||
|
|
||
|
See Also
|
||
|
--------
|
||
|
AdaBoostRegressor : An AdaBoost regressor that begins by fitting a
|
||
|
regressor on the original dataset and then fits additional copies of
|
||
|
the regressor on the same dataset but where the weights of instances
|
||
|
are adjusted according to the error of the current prediction.
|
||
|
|
||
|
GradientBoostingClassifier : GB builds an additive model in a forward
|
||
|
stage-wise fashion. Regression trees are fit on the negative gradient
|
||
|
of the binomial or multinomial deviance loss function. Binary
|
||
|
classification is a special case where only a single regression tree is
|
||
|
induced.
|
||
|
|
||
|
sklearn.tree.DecisionTreeClassifier : A non-parametric supervised learning
|
||
|
method used for classification.
|
||
|
Creates a model that predicts the value of a target variable by
|
||
|
learning simple decision rules inferred from the data features.
|
||
|
|
||
|
References
|
||
|
----------
|
||
|
.. [1] Y. Freund, R. Schapire, "A Decision-Theoretic Generalization of
|
||
|
on-Line Learning and an Application to Boosting", 1995.
|
||
|
|
||
|
.. [2] :doi:`J. Zhu, H. Zou, S. Rosset, T. Hastie, "Multi-class adaboost."
|
||
|
Statistics and its Interface 2.3 (2009): 349-360.
|
||
|
<10.4310/SII.2009.v2.n3.a8>`
|
||
|
|
||
|
Examples
|
||
|
--------
|
||
|
>>> from sklearn.ensemble import AdaBoostClassifier
|
||
|
>>> from sklearn.datasets import make_classification
|
||
|
>>> X, y = make_classification(n_samples=1000, n_features=4,
|
||
|
... n_informative=2, n_redundant=0,
|
||
|
... random_state=0, shuffle=False)
|
||
|
>>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME", random_state=0)
|
||
|
>>> clf.fit(X, y)
|
||
|
AdaBoostClassifier(algorithm='SAMME', n_estimators=100, random_state=0)
|
||
|
>>> clf.predict([[0, 0, 0, 0]])
|
||
|
array([1])
|
||
|
>>> clf.score(X, y)
|
||
|
0.96...
|
||
|
|
||
|
For a detailed example of using AdaBoost to fit a sequence of DecisionTrees
|
||
|
as weak learners, please refer to
|
||
|
:ref:`sphx_glr_auto_examples_ensemble_plot_adaboost_multiclass.py`.
|
||
|
|
||
|
For a detailed example of using AdaBoost to fit a non-linearly separable
|
||
|
classification dataset composed of two Gaussian quantile clusters, please
|
||
|
refer to :ref:`sphx_glr_auto_examples_ensemble_plot_adaboost_twoclass.py`.
|
||
|
"""
|
||
|
|
||
|
# TODO(1.6): Modify _parameter_constraints for "algorithm" to only check
# for "SAMME"
# The shared boosting constraints, extended with the classifier-only
# ``algorithm`` option.
_parameter_constraints: dict = dict(
    BaseWeightBoosting._parameter_constraints,
    algorithm=[StrOptions({"SAMME", "SAMME.R"})],
)
|
||
|
|
||
|
# TODO(1.6): Change default "algorithm" value to "SAMME"
def __init__(
    self,
    estimator=None,
    *,
    n_estimators=50,
    learning_rate=1.0,
    algorithm="SAMME.R",
    random_state=None,
):
    # The shared boosting parameters are stored by BaseWeightBoosting;
    # only the classifier-specific ``algorithm`` is stored here.
    super().__init__(
        estimator,
        random_state=random_state,
        learning_rate=learning_rate,
        n_estimators=n_estimators,
    )
    self.algorithm = algorithm
|
||
|
|
||
|
def _validate_estimator(self):
    """Check the estimator and set the estimator_ attribute.

    Defaults to ``DecisionTreeClassifier(max_depth=1)`` when no estimator
    was given. Emits a FutureWarning whenever ``algorithm`` is not "SAMME",
    i.e. when the deprecated SAMME.R algorithm is selected.

    Raises
    ------
    TypeError
        If SAMME.R is selected and ``estimator_`` has no ``predict_proba``.
    ValueError
        If ``has_fit_parameter`` reports that ``estimator_`` does not accept
        ``sample_weight``.
    """
    super()._validate_estimator(default=DecisionTreeClassifier(max_depth=1))

    # TODO(1.6): Remove, as "SAMME.R" value for "algorithm" param will be
    # removed in 1.6
    uses_samme_r = self.algorithm != "SAMME"
    if uses_samme_r:
        warnings.warn(
            (
                "The SAMME.R algorithm (the default) is deprecated and will be"
                " removed in 1.6. Use the SAMME algorithm to circumvent this"
                " warning."
            ),
            FutureWarning,
        )
        # SAMME.R boosts on predicted class probabilities.
        supports_proba = hasattr(self.estimator_, "predict_proba")
        if not supports_proba:
            raise TypeError(
                "AdaBoostClassifier with algorithm='SAMME.R' requires "
                "that the weak learner supports the calculation of class "
                "probabilities with a predict_proba method.\n"
                "Please change the base estimator or set "
                "algorithm='SAMME' instead."
            )

    if has_fit_parameter(self.estimator_, "sample_weight"):
        return
    raise ValueError(
        f"{self.estimator.__class__.__name__} doesn't support sample_weight."
    )
|
||
|
|
||
|
# TODO(1.6): Redefine the scope of the `_boost` and `_boost_discrete`
# functions to be the same since SAMME will be the only value for the
# "algorithm" parameter in version 1.6. Thus, a distinguishing function is
# no longer needed. (Or adjust code here, if another algorithm, shall be
# used instead of SAMME.R.)
def _boost(self, iboost, X, y, sample_weight, random_state):
    """Implement a single boost.

    Dispatches to the real multi-class SAMME.R update or to the discrete
    SAMME update, depending on ``self.algorithm``, and returns the updated
    sample weights.

    Parameters
    ----------
    iboost : int
        The index of the current boost iteration.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples.

    y : array-like of shape (n_samples,)
        The target values (class labels).

    sample_weight : array-like of shape (n_samples,)
        The current sample weights.

    random_state : RandomState instance
        The RandomState instance used if the base estimator accepts a
        `random_state` attribute.

    Returns
    -------
    sample_weight : array-like of shape (n_samples,) or None
        The reweighted sample weights.
        If None then boosting has terminated early.

    estimator_weight : float
        The weight for the current boost.
        If None then boosting has terminated early.

    estimator_error : float
        The classification error for the current boost.
        If None then boosting has terminated early.
    """
    # Only the exact value "SAMME.R" selects the real algorithm; anything
    # else falls through to the discrete SAMME update.
    if self.algorithm != "SAMME.R":
        return self._boost_discrete(iboost, X, y, sample_weight, random_state)
    return self._boost_real(iboost, X, y, sample_weight, random_state)
|
||
|
|
||
|
# TODO(1.6): Remove function. The `_boost_real` function won't be used any
# longer, because the SAMME.R algorithm will be removed in 1.6.
def _boost_real(self, iboost, X, y, sample_weight, random_state):
    """Implement a single boost using the SAMME.R real algorithm.

    Returns
    -------
    sample_weight : ndarray
        The sample weights, updated in place unless this is the last
        iteration or the fit is perfect.
    estimator_weight : float
        Always 1.0 for SAMME.R.
    estimator_error : float
        Weighted misclassification rate of the new weak learner.
    """
    weak_learner = self._make_estimator(random_state=random_state)
    weak_learner.fit(X, y, sample_weight=sample_weight)
    proba = weak_learner.predict_proba(X)

    # The first weak learner determines the class set of the ensemble.
    if iboost == 0:
        self.classes_ = getattr(weak_learner, "classes_", None)
        self.n_classes_ = len(self.classes_)

    hard_predictions = self.classes_.take(np.argmax(proba, axis=1), axis=0)
    is_wrong = hard_predictions != y
    estimator_error = np.mean(np.average(is_wrong, weights=sample_weight, axis=0))

    if estimator_error <= 0:
        # Perfect fit: keep the weights as they are.
        return sample_weight, 1.0, 0.0

    # Symmetric class coding of Zhu et al. [2]:
    #   y_k = 1 if c == k else -1 / (K - 1)
    # with K == n_classes_ and c the column of the true class label.
    n_classes = self.n_classes_
    code_values = np.array([-1.0 / (n_classes - 1), 1.0])
    y_coding = code_values.take(self.classes_ == y[:, np.newaxis])

    # Floor the probabilities (in place) so the log inside xlogy is defined;
    # this also fixes negative entries produced by negative sample weights.
    np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba)

    # Per-sample boost factor of the multi-class SAMME.R algorithm.
    estimator_weight = (
        -1.0
        * self.learning_rate
        * ((n_classes - 1.0) / n_classes)
        * xlogy(y_coding, proba).sum(axis=1)
    )

    is_last_iteration = iboost == self.n_estimators - 1
    if not is_last_iteration:
        # Only boost positive weights (negative factors always apply).
        sample_weight *= np.exp(
            estimator_weight * ((sample_weight > 0) | (estimator_weight < 0))
        )

    return sample_weight, 1.0, estimator_error
|
||
|
|
||
|
def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
    """Implement a single boost using the SAMME discrete algorithm.

    Returns
    -------
    sample_weight : ndarray or None
        The reweighted sample weights, or None when the new weak learner is
        discarded for being no better than random guessing.
    estimator_weight : float or None
        The SAMME weight of the new weak learner (1.0 on a perfect fit).
    estimator_error : float or None
        The weighted misclassification rate of the new weak learner.

    Raises
    ------
    ValueError
        If the new weak learner is no better than random guessing and
        discarding it leaves ``estimators_`` empty.
    """
    estimator = self._make_estimator(random_state=random_state)

    estimator.fit(X, y, sample_weight=sample_weight)

    y_predict = estimator.predict(X)

    # The first weak learner determines the class set of the ensemble.
    if iboost == 0:
        self.classes_ = getattr(estimator, "classes_", None)
        self.n_classes_ = len(self.classes_)

    # Instances incorrectly classified
    incorrect = y_predict != y

    # Error fraction
    estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))

    # Stop if classification is perfect
    if estimator_error <= 0:
        return sample_weight, 1.0, 0.0

    n_classes = self.n_classes_

    # Stop if the error is at least as bad as random guessing. The learner
    # just built is dropped from ``estimators_`` (presumably appended there
    # by ``_make_estimator``).
    if estimator_error >= 1.0 - (1.0 / n_classes):
        self.estimators_.pop(-1)
        if len(self.estimators_) == 0:
            raise ValueError(
                "BaseClassifier in AdaBoostClassifier "
                "ensemble is worse than random, ensemble "
                "can not be fit."
            )
        return None, None, None

    # Boost weight using multi-class AdaBoost SAMME alg
    estimator_weight = self.learning_rate * (
        np.log((1.0 - estimator_error) / estimator_error) + np.log(n_classes - 1.0)
    )

    # Only boost the weights if it will fit again
    if not iboost == self.n_estimators - 1:
        # The update is done in log space. ``fit`` keeps originally-zero
        # weights at exactly 0: log(0) = -inf and exp(-inf) = 0, so they stay
        # zero. Silence the divide-by-zero RuntimeWarning np.log emits for
        # them, which would otherwise fire on every iteration.
        with np.errstate(divide="ignore"):
            log_sample_weight = np.log(sample_weight)
        # Only boost positive weights
        sample_weight = np.exp(
            log_sample_weight + estimator_weight * incorrect * (sample_weight > 0)
        )

    return sample_weight, estimator_weight, estimator_error
|
||
|
|
||
|
def predict(self, X):
    """Predict classes for X.

    The predicted class of an input sample is computed as the weighted mean
    prediction of the classifiers in the ensemble.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Returns
    -------
    y : ndarray of shape (n_samples,)
        The predicted classes.
    """
    decision = self.decision_function(X)

    # Binary decisions are 1-D and signed: positive means the second class.
    # Otherwise pick the column with the highest score.
    if self.n_classes_ == 2:
        class_index = decision > 0
    else:
        class_index = np.argmax(decision, axis=1)
    return self.classes_.take(class_index, axis=0)
|
||
|
|
||
|
def staged_predict(self, X):
    """Return staged predictions for X.

    The predicted class of an input sample is computed as the weighted mean
    prediction of the classifiers in the ensemble.

    This generator method yields the ensemble prediction after each
    iteration of boosting and therefore allows monitoring, such as to
    determine the prediction on a test set after each boost.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Yields
    ------
    y : generator of ndarray of shape (n_samples,)
        The predicted classes.

    Raises
    ------
    NotFittedError
        On the first iteration, if the estimator has not been fitted.
    """
    # Fail with NotFittedError (a subclass of AttributeError and ValueError)
    # before reading ``n_classes_``, like ``decision_function`` does.
    check_is_fitted(self)
    X = self._check_X(X)

    n_classes = self.n_classes_
    classes = self.classes_

    if n_classes == 2:
        for pred in self.staged_decision_function(X):
            yield np.array(classes.take(pred > 0, axis=0))

    else:
        for pred in self.staged_decision_function(X):
            yield np.array(classes.take(np.argmax(pred, axis=1), axis=0))
|
||
|
|
||
|
def decision_function(self, X):
    """Compute the decision function of ``X``.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Returns
    -------
    score : ndarray of shape of (n_samples, k)
        The decision function of the input samples. The order of
        outputs is the same as that of the :term:`classes_` attribute.
        Binary classification is a special case with ``k == 1``,
        otherwise ``k==n_classes``. For binary classification,
        values closer to -1 or 1 mean more like the first or second
        class in ``classes_``, respectively. If only one class was seen
        during fit, the scores are all zero with shape (n_samples, 1).
    """
    check_is_fitted(self)
    X = self._check_X(X)

    n_classes = self.n_classes_
    classes = self.classes_[:, np.newaxis]

    if n_classes == 1:
        # Single-class fit: there is nothing to discriminate. The SAMME
        # coding below divides by ``n_classes - 1`` (ZeroDivisionError), so
        # return the all-zero scores the SAMME.R branch yields in this case
        # (``_samme_proba`` multiplies by ``n_classes - 1 == 0``).
        return np.zeros((_num_samples(X), 1))

    # TODO(1.6): Remove, because "algorithm" param will be deprecated in 1.6
    if self.algorithm == "SAMME.R":
        # The weights are all 1. for SAMME.R
        pred = sum(
            _samme_proba(estimator, n_classes, X) for estimator in self.estimators_
        )
    else:  # self.algorithm == "SAMME"
        # Each learner votes ``w`` for its predicted class and
        # ``-w / (K - 1)`` for every other class.
        pred = sum(
            np.where(
                (estimator.predict(X) == classes).T,
                w,
                -1 / (n_classes - 1) * w,
            )
            for estimator, w in zip(self.estimators_, self.estimator_weights_)
        )

    pred /= self.estimator_weights_.sum()
    if n_classes == 2:
        # Collapse the two columns into a single signed score.
        pred[:, 0] *= -1
        return pred.sum(axis=1)
    return pred
|
||
|
|
||
|
def staged_decision_function(self, X):
    """Compute decision function of ``X`` for each boosting iteration.

    This method allows monitoring (i.e. determine error on testing set)
    after each boosting iteration.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Yields
    ------
    score : generator of ndarray of shape (n_samples, k)
        The decision function of the input samples. The order of
        outputs is the same as that of the :term:`classes_` attribute.
        Binary classification is a special case with ``k == 1``,
        otherwise ``k==n_classes``. For binary classification,
        values closer to -1 or 1 mean more like the first or second
        class in ``classes_``, respectively. If only one class was seen
        during fit, every yielded score is zero with shape (n_samples, 1).
    """
    check_is_fitted(self)
    X = self._check_X(X)

    n_classes = self.n_classes_
    classes = self.classes_[:, np.newaxis]

    if n_classes == 1:
        # Single-class fit: the SAMME coding below divides by
        # ``n_classes - 1`` (ZeroDivisionError). Yield the all-zero scores
        # that the SAMME.R path produces, one per boosting stage.
        n_samples = _num_samples(X)
        for _ in zip(self.estimator_weights_, self.estimators_):
            yield np.zeros((n_samples, 1))
        return

    pred = None
    norm = 0.0

    for weight, estimator in zip(self.estimator_weights_, self.estimators_):
        norm += weight

        # TODO(1.6): Remove, because "algorithm" param will be deprecated in
        # 1.6
        if self.algorithm == "SAMME.R":
            # The weights are all 1. for SAMME.R
            current_pred = _samme_proba(estimator, n_classes, X)
        else:  # elif self.algorithm == "SAMME":
            current_pred = np.where(
                (estimator.predict(X) == classes).T,
                weight,
                -1 / (n_classes - 1) * weight,
            )

        if pred is None:
            pred = current_pred
        else:
            pred += current_pred

        if n_classes == 2:
            # Copy so that the running sum ``pred`` is left untouched.
            tmp_pred = np.copy(pred)
            tmp_pred[:, 0] *= -1
            yield (tmp_pred / norm).sum(axis=1)
        else:
            yield pred / norm
|
||
|
|
||
|
@staticmethod
def _compute_proba_from_decision(decision, n_classes):
    """Compute probabilities from the decision function.

    This is based on eq. (15) of [1] where:
        p(y=c|X) = exp((1 / K-1) f_c(X)) / sum_k(exp((1 / K-1) f_k(X)))
                 = softmax((1 / K-1) * f(X))

    In the binary case ``decision`` is the 1-D signed score and is expanded
    to the two columns ``[-decision, decision] / 2``. Otherwise ``decision``
    is divided by ``n_classes - 1`` in place.

    References
    ----------
    .. [1] J. Zhu, H. Zou, S. Rosset, T. Hastie, "Multi-class AdaBoost",
       2009.
    """
    if n_classes != 2:
        decision /= n_classes - 1
    else:
        decision = np.vstack([-decision, decision]).T / 2
    return softmax(decision, copy=False)
|
||
|
|
||
|
def predict_proba(self, X):
    """Predict class probabilities for X.

    The predicted class probabilities of an input sample are obtained by
    applying a softmax to the ensemble decision function
    (:meth:`decision_function`), as implemented in
    ``_compute_proba_from_decision``.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Returns
    -------
    p : ndarray of shape (n_samples, n_classes)
        The class probabilities of the input samples. The order of
        outputs is the same of that of the :term:`classes_` attribute.
    """
    check_is_fitted(self)
    n_classes = self.n_classes_

    # With a single class, every sample gets probability 1 for that class.
    if n_classes == 1:
        return np.ones((_num_samples(X), 1))

    decision = self.decision_function(X)
    return self._compute_proba_from_decision(decision, n_classes)
|
||
|
|
||
|
def staged_predict_proba(self, X):
    """Predict class probabilities for X after each boosting iteration.

    At every stage, the predicted class probabilities of an input sample
    are obtained by applying a softmax to the staged decision function
    (:meth:`staged_decision_function`), as in :meth:`predict_proba`.

    This generator method yields the ensemble predicted class probabilities
    after each iteration of boosting and therefore allows monitoring, such
    as to determine the predicted class probabilities on a test set after
    each boost.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Yields
    ------
    p : generator of ndarray of shape (n_samples, n_classes)
        The class probabilities of the input samples. The order of
        outputs is the same of that of the :term:`classes_` attribute.
    """
    # Check the fitted state before touching fitted attributes, consistent
    # with predict_proba; otherwise reading ``self.n_classes_`` on an
    # unfitted estimator fails with a bare AttributeError.
    check_is_fitted(self)
    n_classes = self.n_classes_

    if n_classes == 1:
        # Mirror predict_proba: a single-class ensemble predicts probability 1
        # at every stage. Going through the decision function would end in a
        # division by ``n_classes - 1 == 0``.
        n_samples = _num_samples(X)
        for _ in zip(self.estimator_weights_, self.estimators_):
            yield np.ones((n_samples, 1))
        return

    for decision in self.staged_decision_function(X):
        yield self._compute_proba_from_decision(decision, n_classes)
|
||
|
|
||
|
def predict_log_proba(self, X):
    """Predict class log-probabilities for X.

    The predicted class log-probabilities of an input sample are the
    natural logarithm of the class probabilities returned by
    :meth:`predict_proba`.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. Sparse matrix can be CSC, CSR, COO,
        DOK, or LIL. COO, DOK, and LIL are converted to CSR.

    Returns
    -------
    p : ndarray of shape (n_samples, n_classes)
        The class log-probabilities of the input samples. The order of
        outputs is the same of that of the :term:`classes_` attribute.
    """
    return np.log(self.predict_proba(X))
|
||
|
|
||
|
|
||
|
class AdaBoostRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseWeightBoosting):
    """An AdaBoost regressor.

    An AdaBoost [1] regressor is a meta-estimator that begins by fitting a
    regressor on the original dataset and then fits additional copies of the
    regressor on the same dataset but where the weights of instances are
    adjusted according to the error of the current prediction. As such,
    subsequent regressors focus more on difficult cases.

    This class implements the algorithm known as AdaBoost.R2 [2].

    Read more in the :ref:`User Guide <adaboost>`.

    .. versionadded:: 0.14

    Parameters
    ----------
    estimator : object, default=None
        The base estimator from which the boosted ensemble is built.
        If ``None``, then the base estimator is
        :class:`~sklearn.tree.DecisionTreeRegressor` initialized with
        `max_depth=3`.

        .. versionadded:: 1.2
           `base_estimator` was renamed to `estimator`.

    n_estimators : int, default=50
        The maximum number of estimators at which boosting is terminated.
        In case of perfect fit, the learning procedure is stopped early.
        Values must be in the range `[1, inf)`.

    learning_rate : float, default=1.0
        Weight applied to each regressor at each boosting iteration. A higher
        learning rate increases the contribution of each regressor. There is
        a trade-off between the `learning_rate` and `n_estimators` parameters.
        Values must be in the range `(0.0, inf)`.

    loss : {'linear', 'square', 'exponential'}, default='linear'
        The loss function to use when updating the weights after each
        boosting iteration. Absolute errors are first divided by the largest
        absolute error (on samples with positive weight). 'linear' uses the
        normalized error as is, 'square' squares it, and 'exponential' uses
        ``1 - exp(-error)``.

    random_state : int, RandomState instance or None, default=None
        Controls the random seed given at each `estimator` at each
        boosting iteration.
        Thus, it is only used when `estimator` exposes a `random_state`.
        In addition, it controls the bootstrap of the weights used to train the
        `estimator` at each boosting iteration.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

        .. versionadded:: 1.2
           `base_estimator_` was renamed to `estimator_`.

    estimators_ : list of regressors
        The collection of fitted sub-estimators.

    estimator_weights_ : ndarray of floats
        Weights for each estimator in the boosted ensemble.

    estimator_errors_ : ndarray of floats
        Regression error for each estimator in the boosted ensemble.

    feature_importances_ : ndarray of shape (n_features,)
        The impurity-based feature importances if supported by the
        ``estimator`` (when based on decision trees).

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    AdaBoostClassifier : An AdaBoost classifier.
    GradientBoostingRegressor : Gradient Boosting Regression Tree.
    sklearn.tree.DecisionTreeRegressor : A decision tree regressor.

    References
    ----------
    .. [1] Y. Freund, R. Schapire, "A Decision-Theoretic Generalization of
           on-Line Learning and an Application to Boosting", 1995.

    .. [2] H. Drucker, "Improving Regressors using Boosting Techniques", 1997.

    Examples
    --------
    >>> from sklearn.ensemble import AdaBoostRegressor
    >>> from sklearn.datasets import make_regression
    >>> X, y = make_regression(n_features=4, n_informative=2,
    ...                        random_state=0, shuffle=False)
    >>> regr = AdaBoostRegressor(random_state=0, n_estimators=100)
    >>> regr.fit(X, y)
    AdaBoostRegressor(n_estimators=100, random_state=0)
    >>> regr.predict([[0, 0, 0, 0]])
    array([4.7972...])
    >>> regr.score(X, y)
    0.9771...
    """

    _parameter_constraints: dict = {
        **BaseWeightBoosting._parameter_constraints,
        "loss": [StrOptions({"linear", "square", "exponential"})],
    }

    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=50,
        learning_rate=1.0,
        loss="linear",
        random_state=None,
    ):
        super().__init__(
            estimator=estimator,
            n_estimators=n_estimators,
            learning_rate=learning_rate,
            random_state=random_state,
        )

        self.loss = loss
        self.random_state = random_state

    def _validate_estimator(self):
        """Check the estimator and set the estimator_ attribute."""
        super()._validate_estimator(default=DecisionTreeRegressor(max_depth=3))

    def _boost(self, iboost, X, y, sample_weight, random_state):
        """Implement a single boost for regression.

        Perform a single boost according to the AdaBoost.R2 algorithm and
        return the updated sample weights.

        Parameters
        ----------
        iboost : int
            The index of the current boost iteration.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples.

        y : array-like of shape (n_samples,)
            The target values (class labels in classification, real numbers in
            regression).

        sample_weight : array-like of shape (n_samples,)
            The current sample weights. They are used as the sampling
            probabilities of the bootstrap and are updated in place on
            samples with positive weight.

        random_state : RandomState
            The RandomState instance used if the base estimator accepts a
            `random_state` attribute.
            Controls also the bootstrap of the weights used to train the weak
            learner.

        Returns
        -------
        sample_weight : array-like of shape (n_samples,) or None
            The reweighted sample weights.
            If None then boosting has terminated early.

        estimator_weight : float
            The weight for the current boost. It is 1.0 when the fit is
            perfect (zero error).
            If None then boosting has terminated early.

        estimator_error : float
            The regression error for the current boost.
            If None then boosting has terminated early.
        """
        estimator = self._make_estimator(random_state=random_state)
        n_samples = _num_samples(X)

        # Weighted sampling of the training set with replacement
        bootstrap_idx = random_state.choice(
            np.arange(n_samples),
            size=n_samples,
            replace=True,
            p=sample_weight,
        )

        # Fit on the bootstrapped sample and obtain a prediction
        # for all samples in the training set
        X_ = _safe_indexing(X, bootstrap_idx)
        y_ = _safe_indexing(y, bootstrap_idx)
        estimator.fit(X_, y_)
        y_predict = estimator.predict(X)

        # Only samples with positive weight contribute to the loss and to the
        # weight update.
        error_vect = np.abs(y_predict - y)
        sample_mask = sample_weight > 0
        masked_sample_weight = sample_weight[sample_mask]
        masked_error_vector = error_vect[sample_mask]

        # Normalize the errors to [0, 1] before applying the loss.
        error_max = masked_error_vector.max()
        if error_max != 0:
            masked_error_vector /= error_max

        if self.loss == "square":
            masked_error_vector **= 2
        elif self.loss == "exponential":
            masked_error_vector = 1.0 - np.exp(-masked_error_vector)

        # Calculate the average loss
        estimator_error = (masked_sample_weight * masked_error_vector).sum()

        if estimator_error <= 0:
            # Stop if fit is perfect
            return sample_weight, 1.0, 0.0

        elif estimator_error >= 0.5:
            # Discard current estimator only if it isn't the only one
            if len(self.estimators_) > 1:
                self.estimators_.pop(-1)
            return None, None, None

        beta = estimator_error / (1.0 - estimator_error)

        # Boost weight using AdaBoost.R2 alg
        estimator_weight = self.learning_rate * np.log(1.0 / beta)

        # The weights are not needed after the last iteration.
        if not iboost == self.n_estimators - 1:
            sample_weight[sample_mask] *= np.power(
                beta, (1.0 - masked_error_vector) * self.learning_rate
            )

        return sample_weight, estimator_weight, estimator_error

    def _get_median_predict(self, X, limit):
        """Return the weighted median prediction of the first `limit` estimators.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples.

        limit : int
            Number of leading estimators of ``estimators_`` to use.

        Returns
        -------
        y : ndarray of shape (n_samples,)
            For each sample, the prediction of the estimator at which the
            cumulative ``estimator_weights_`` (in prediction order) first
            reaches half of the total weight.
        """
        # Evaluate predictions of all estimators: one column per estimator.
        predictions = np.array([est.predict(X) for est in self.estimators_[:limit]]).T

        # Sort the predictions
        sorted_idx = np.argsort(predictions, axis=1)

        # Find index of median prediction for each sample: the first position
        # in sorted order where the cumulative weight reaches half the total.
        weight_cdf = stable_cumsum(self.estimator_weights_[sorted_idx], axis=1)
        median_or_above = weight_cdf >= 0.5 * weight_cdf[:, -1][:, np.newaxis]
        median_idx = median_or_above.argmax(axis=1)

        rows = np.arange(_num_samples(X))
        median_estimators = sorted_idx[rows, median_idx]

        # Return median predictions
        return predictions[rows, median_estimators]

    def predict(self, X):
        """Predict regression value for X.

        The predicted regression value of an input sample is computed
        as the weighted median prediction of the regressors in the ensemble.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Sparse matrix can be CSC, CSR, COO,
            DOK, or LIL. COO, DOK, and LIL are converted to CSR.

        Returns
        -------
        y : ndarray of shape (n_samples,)
            The predicted regression values.
        """
        check_is_fitted(self)
        X = self._check_X(X)

        return self._get_median_predict(X, len(self.estimators_))

    def staged_predict(self, X):
        """Return staged predictions for X.

        The predicted regression value of an input sample is computed
        as the weighted median prediction of the regressors in the ensemble.

        This generator method yields the ensemble prediction after each
        iteration of boosting and therefore allows monitoring, such as to
        determine the prediction on a test set after each boost.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Sparse matrix can be CSC, CSR, COO,
            DOK, or LIL. COO, DOK, and LIL are converted to CSR.

        Yields
        ------
        y : generator of ndarray of shape (n_samples,)
            The predicted regression values.
        """
        check_is_fitted(self)
        X = self._check_X(X)

        for i, _ in enumerate(self.estimators_, 1):
            yield self._get_median_predict(X, limit=i)
|