"""Base class for ensemble-based estimators.""" # Authors: Gilles Louppe # License: BSD 3 clause from abc import ABCMeta, abstractmethod import numbers from typing import List import numpy as np from joblib import effective_n_jobs from ..base import clone from ..base import is_classifier, is_regressor from ..base import BaseEstimator from ..base import MetaEstimatorMixin from ..utils import Bunch, _print_elapsed_time from ..utils import check_random_state from ..utils.metaestimators import _BaseComposition def _fit_single_estimator(estimator, X, y, sample_weight=None, message_clsname=None, message=None): """Private function used to fit an estimator within a job.""" if sample_weight is not None: try: with _print_elapsed_time(message_clsname, message): estimator.fit(X, y, sample_weight=sample_weight) except TypeError as exc: if "unexpected keyword argument 'sample_weight'" in str(exc): raise TypeError( "Underlying estimator {} does not support sample weights." .format(estimator.__class__.__name__) ) from exc raise else: with _print_elapsed_time(message_clsname, message): estimator.fit(X, y) return estimator def _set_random_states(estimator, random_state=None): """Set fixed random_state parameters for an estimator. Finds all parameters ending ``random_state`` and sets them to integers derived from ``random_state``. Parameters ---------- estimator : estimator supporting get/set_params Estimator with potential randomness managed by random_state parameters. random_state : int, RandomState instance or None, default=None Pseudo-random number generator to control the generation of the random integers. Pass an int for reproducible output across multiple function calls. See :term:`Glossary `. Notes ----- This does not necessarily set *all* ``random_state`` attributes that control an estimator's randomness, only those accessible through ``estimator.get_params()``. ``random_state``s not controlled include those belonging to: * cross-validation splitters * ``scipy.stats`` rvs """ random_state = check_random_state(random_state) to_set = {} for key in sorted(estimator.get_params(deep=True)): if key == 'random_state' or key.endswith('__random_state'): to_set[key] = random_state.randint(np.iinfo(np.int32).max) if to_set: estimator.set_params(**to_set) class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta): """Base class for all ensemble classes. Warning: This class should not be used directly. Use derived classes instead. Parameters ---------- base_estimator : object The base estimator from which the ensemble is built. n_estimators : int, default=10 The number of estimators in the ensemble. estimator_params : list of str, default=tuple() The list of attributes to use as parameters when instantiating a new base estimator. If none are given, default parameters are used. Attributes ---------- base_estimator_ : estimator The base estimator from which the ensemble is grown. estimators_ : list of estimators The collection of fitted base estimators. """ # overwrite _required_parameters from MetaEstimatorMixin _required_parameters: List[str] = [] @abstractmethod def __init__(self, base_estimator, *, n_estimators=10, estimator_params=tuple()): # Set parameters self.base_estimator = base_estimator self.n_estimators = n_estimators self.estimator_params = estimator_params # Don't instantiate estimators now! Parameters of base_estimator might # still change. Eg., when grid-searching with the nested object syntax. # self.estimators_ needs to be filled by the derived classes in fit. def _validate_estimator(self, default=None): """Check the estimator and the n_estimator attribute. Sets the base_estimator_` attributes. """ if not isinstance(self.n_estimators, numbers.Integral): raise ValueError("n_estimators must be an integer, " "got {0}.".format(type(self.n_estimators))) if self.n_estimators <= 0: raise ValueError("n_estimators must be greater than zero, " "got {0}.".format(self.n_estimators)) if self.base_estimator is not None: self.base_estimator_ = self.base_estimator else: self.base_estimator_ = default if self.base_estimator_ is None: raise ValueError("base_estimator cannot be None") def _make_estimator(self, append=True, random_state=None): """Make and configure a copy of the `base_estimator_` attribute. Warning: This method should be used to properly instantiate new sub-estimators. """ estimator = clone(self.base_estimator_) estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params}) if random_state is not None: _set_random_states(estimator, random_state) if append: self.estimators_.append(estimator) return estimator def __len__(self): """Return the number of estimators in the ensemble.""" return len(self.estimators_) def __getitem__(self, index): """Return the index'th estimator in the ensemble.""" return self.estimators_[index] def __iter__(self): """Return iterator over estimators in the ensemble.""" return iter(self.estimators_) def _partition_estimators(n_estimators, n_jobs): """Private function used to partition estimators between jobs.""" # Compute the number of jobs n_jobs = min(effective_n_jobs(n_jobs), n_estimators) # Partition estimators between jobs n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int) n_estimators_per_job[:n_estimators % n_jobs] += 1 starts = np.cumsum(n_estimators_per_job) return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist() class _BaseHeterogeneousEnsemble(MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta): """Base class for heterogeneous ensemble of learners. Parameters ---------- estimators : list of (str, estimator) tuples The ensemble of estimators to use in the ensemble. Each element of the list is defined as a tuple of string (i.e. name of the estimator) and an estimator instance. An estimator can be set to `'drop'` using `set_params`. Attributes ---------- estimators_ : list of estimators The elements of the estimators parameter, having been fitted on the training data. If an estimator has been set to `'drop'`, it will not appear in `estimators_`. """ _required_parameters = ['estimators'] @property def named_estimators(self): return Bunch(**dict(self.estimators)) @abstractmethod def __init__(self, estimators): self.estimators = estimators def _validate_estimators(self): if self.estimators is None or len(self.estimators) == 0: raise ValueError( "Invalid 'estimators' attribute, 'estimators' should be a list" " of (string, estimator) tuples." ) names, estimators = zip(*self.estimators) # defined by MetaEstimatorMixin self._validate_names(names) has_estimator = any(est != 'drop' for est in estimators) if not has_estimator: raise ValueError( "All estimators are dropped. At least one is required " "to be an estimator." ) is_estimator_type = (is_classifier if is_classifier(self) else is_regressor) for est in estimators: if est != 'drop' and not is_estimator_type(est): raise ValueError( "The estimator {} should be a {}.".format( est.__class__.__name__, is_estimator_type.__name__[3:] ) ) return names, estimators def set_params(self, **params): """ Set the parameters of an estimator from the ensemble. Valid parameter keys can be listed with `get_params()`. Note that you can directly set the parameters of the estimators contained in `estimators`. Parameters ---------- **params : keyword arguments Specific parameters using e.g. `set_params(parameter_name=new_value)`. In addition, to setting the parameters of the estimator, the individual estimator of the estimators can also be set, or can be removed by setting them to 'drop'. """ super()._set_params('estimators', **params) return self def get_params(self, deep=True): """ Get the parameters of an estimator from the ensemble. Returns the parameters given in the constructor as well as the estimators contained within the `estimators` parameter. Parameters ---------- deep : bool, default=True Setting it to True gets the various estimators and the parameters of the estimators as well. """ return super()._get_params('estimators', deep=deep)