Inzynierka/Lib/site-packages/sklearn/ensemble/_base.py
2023-06-02 12:51:02 +02:00

350 lines
11 KiB
Python

"""Base class for ensemble-based estimators."""
# Authors: Gilles Louppe
# License: BSD 3 clause
from abc import ABCMeta, abstractmethod
from typing import List
import warnings
import numpy as np
from joblib import effective_n_jobs
from ..base import clone
from ..base import is_classifier, is_regressor
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..tree import (
DecisionTreeRegressor,
BaseDecisionTree,
DecisionTreeClassifier,
)
from ..utils import Bunch, _print_elapsed_time, deprecated
from ..utils import check_random_state
from ..utils.metaestimators import _BaseComposition
def _fit_single_estimator(
    estimator, X, y, sample_weight=None, message_clsname=None, message=None
):
    """Fit one estimator; private helper meant to be run inside a job.

    Parameters
    ----------
    estimator : estimator instance
        ``estimator.fit`` is called on this object, which is then returned.
    X, y : object
        Training inputs and targets, forwarded unchanged to ``estimator.fit``.
    sample_weight : object or None, default=None
        When not None, passed to ``fit`` as the ``sample_weight`` keyword.
    message_clsname, message : str or None, default=None
        Forwarded to ``_print_elapsed_time``, which wraps the ``fit`` call
        (presumably to report elapsed time when a message is given).

    Returns
    -------
    estimator : estimator instance
        The same object that was passed in, after ``fit`` has run.

    Raises
    ------
    TypeError
        If ``fit`` does not accept a ``sample_weight`` keyword, a TypeError
        naming the estimator class is raised (chained to the original).
        Any other TypeError propagates unchanged.
    """
    if sample_weight is None:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y)
        return estimator

    try:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y, sample_weight=sample_weight)
    except TypeError as exc:
        # Only translate the specific "fit() got an unexpected keyword
        # argument 'sample_weight'" failure; anything else is a real bug.
        if "unexpected keyword argument 'sample_weight'" not in str(exc):
            raise
        raise TypeError(
            "Underlying estimator {} does not support sample weights.".format(
                estimator.__class__.__name__
            )
        ) from exc
    return estimator
def _set_random_states(estimator, random_state=None):
    """Seed every ``random_state`` parameter of an estimator with an integer.

    Every parameter name returned by ``estimator.get_params(deep=True)`` that
    is exactly ``random_state`` or ends with ``__random_state`` receives an
    integer drawn from ``random_state``. Names are visited in sorted order,
    so the assignment is reproducible for a fixed seed.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``. ``random_state``s not controlled include
    those belonging to:

        * cross-validation splitters
        * ``scipy.stats`` rvs
    """
    rng = check_random_state(random_state)
    upper_bound = np.iinfo(np.int32).max

    # The comprehension draws one integer per matching key, in sorted key
    # order, exactly like an explicit loop would.
    seeds = {
        name: rng.randint(upper_bound)
        for name in sorted(estimator.get_params(deep=True))
        if name == "random_state" or name.endswith("__random_state")
    }

    if seeds:
        estimator.set_params(**seeds)
class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    base_estimator : object, default="deprecated"
        Use `estimator` instead.

        .. deprecated:: 1.2
            `base_estimator` is deprecated and will be removed in 1.4.
            Use `estimator` instead.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

        .. deprecated:: 1.2
            `base_estimator_` is deprecated and will be removed in 1.4.
            Use `estimator_` instead.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # Overrides the `_required_parameters` inherited from MetaEstimatorMixin:
    # an ensemble has no mandatory constructor argument.
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        estimator_params=tuple(),
        base_estimator="deprecated",
    ):
        # Only record the constructor arguments. Sub-estimators are NOT built
        # here because the template's parameters may still change afterwards
        # (e.g. grid-searching with the nested-parameter syntax); derived
        # classes must populate `self.estimators_` during `fit`.
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params
        self.base_estimator = base_estimator

    def _validate_estimator(self, default=None):
        """Resolve the template estimator and store it as `estimator_`.

        Precedence: `estimator` if not None; otherwise the deprecated
        `base_estimator` (emitting a FutureWarning) if it is neither None nor
        the "deprecated" sentinel; otherwise `default`.

        Raises
        ------
        ValueError
            If `estimator` and `base_estimator` are both set.
        """
        legacy_given = self.base_estimator not in [None, "deprecated"]

        if self.estimator is not None:
            if legacy_given:
                raise ValueError(
                    "Both `estimator` and `base_estimator` were set. Only set `estimator`."
                )
            self.estimator_ = self.estimator
            return

        if not legacy_given:
            self.estimator_ = default
            return

        warnings.warn(
            "`base_estimator` was renamed to `estimator` in version 1.2 and "
            "will be removed in 1.4.",
            FutureWarning,
        )
        self.estimator_ = self.base_estimator

    # TODO(1.4): remove
    # mypy error: Decorated property not supported
    @deprecated(  # type: ignore
        "Attribute `base_estimator_` was deprecated in version 1.2 and will be removed "
        "in 1.4. Use `estimator_` instead."
    )
    @property
    def base_estimator_(self):
        """Estimator used to grow the ensemble."""
        return self.estimator_

    def _make_estimator(self, append=True, random_state=None):
        """Clone `estimator_` and configure the copy as a new sub-estimator.

        Warning: new sub-estimators should be instantiated through this
        method so that `estimator_params` and seeding are applied.

        Parameters
        ----------
        append : bool, default=True
            When true, the new estimator is appended to `self.estimators_`.

        random_state : int, RandomState instance or None, default=None
            When not None, every `random_state` parameter of the copy is
            seeded from it via `_set_random_states`.

        Returns
        -------
        estimator : object
            The newly cloned and configured estimator.
        """
        new_estimator = clone(self.estimator_)
        forwarded_params = {name: getattr(self, name) for name in self.estimator_params}
        new_estimator.set_params(**forwarded_params)

        # TODO(1.3): Remove
        # max_features = 'auto' would cause warnings in every call to
        # Tree.fit(..), so translate it once here to its explicit equivalent.
        is_auto_tree = isinstance(new_estimator, BaseDecisionTree) and (
            getattr(new_estimator, "max_features", None) == "auto"
        )
        if is_auto_tree:
            if isinstance(new_estimator, DecisionTreeClassifier):
                new_estimator.set_params(max_features="sqrt")
            elif isinstance(new_estimator, DecisionTreeRegressor):
                new_estimator.set_params(max_features=1.0)

        if random_state is not None:
            _set_random_states(new_estimator, random_state)

        if append:
            self.estimators_.append(new_estimator)

        return new_estimator

    def __len__(self):
        """Return how many sub-estimators `estimators_` holds."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return `estimators_[index]` (integer index or slice)."""
        return self.estimators_[index]

    def __iter__(self):
        """Iterate over the sub-estimators stored in `estimators_`."""
        return iter(self.estimators_)
def _partition_estimators(n_estimators, n_jobs):
    """Split ``n_estimators`` as evenly as possible between parallel jobs.

    The job count is capped at ``n_estimators``. The first
    ``n_estimators % n_jobs`` jobs each receive one extra estimator.

    Parameters
    ----------
    n_estimators : int
        Total number of estimators to distribute. With 0 the job count
        becomes 0 (assuming ``effective_n_jobs`` returns a positive count)
        and the division below raises ZeroDivisionError.
    n_jobs : int or None
        Requested parallelism, resolved through ``effective_n_jobs``.

    Returns
    -------
    n_jobs : int
        Number of jobs actually used.
    estimators_per_job : list of int
        Number of estimators assigned to each job.
    starts : list of int
        Cumulative boundaries ``[0, c0, c0 + c1, ...]`` of length
        ``n_jobs + 1``, suitable for slicing job ``i`` as
        ``starts[i]:starts[i + 1]``.
    """
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    per_job, leftover = divmod(n_estimators, n_jobs)
    counts = np.full(n_jobs, per_job, dtype=int)
    counts[:leftover] += 1

    boundaries = [0] + np.cumsum(counts).tolist()
    return n_jobs, counts.tolist(), boundaries
class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensemble of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The ensemble of estimators to use in the ensemble. Each element of the
        list is defined as a tuple of string (i.e. name of the estimator) and
        an estimator instance. An estimator can be set to `'drop'` using
        `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the estimators parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access any fitted sub-estimators by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        return Bunch(**dict(self.estimators))

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        """Validate the `estimators` parameter and split it into names and estimators.

        Returns
        -------
        names : tuple of str
            The estimator names, in the order given.
        estimators : tuple
            The estimator objects (or the string ``"drop"``), aligned with
            ``names``.

        Raises
        ------
        ValueError
            If `estimators` is None, empty, or not a sequence of
            ``(str, estimator)`` pairs; if every estimator is ``"drop"``; or
            if an estimator's type (classifier/regressor) does not match the
            ensemble's own type. ``_validate_names`` may raise as well.
        """
        estimators_param = self.estimators
        # Check the shape of every item up front. Without this, a list of
        # bare estimators (names omitted) or non-pair items fail inside
        # ``zip(*...)`` below with an unrelated-looking TypeError, and
        # mixed-length tuples could be silently truncated by ``zip``.
        if (
            estimators_param is None
            or len(estimators_param) == 0
            or not all(
                isinstance(item, (tuple, list))
                and len(item) == 2
                and isinstance(item[0], str)
                for item in estimators_param
            )
        ):
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*estimators_param)
        # Inherited helper; NOTE(review): it is provided by a base class
        # (believed to be _BaseComposition rather than MetaEstimatorMixin).
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        # Every non-dropped member must be of the same kind as the ensemble:
        # classifiers for a classifier ensemble, regressors otherwise.
        is_estimator_type = is_classifier if is_classifier(self) else is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        # "is_classifier"[3:] -> "classifier", etc.
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the estimator, the individual estimators contained in
            `estimators` can also be set, or can be removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            Setting it to True gets the various estimators and the parameters
            of the estimators as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values or parameter
            names mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)