1049 lines
37 KiB
Python
1049 lines
37 KiB
Python
"""
|
|
The :mod:`sklearn.pipeline` module implements utilities to build a composite
|
|
estimator, as a chain of transforms and estimators.
|
|
"""
|
|
# Author: Edouard Duchesnay
|
|
# Gael Varoquaux
|
|
# Virgile Fritsch
|
|
# Alexandre Gramfort
|
|
# Lars Buitinck
|
|
# License: BSD
|
|
|
|
from collections import defaultdict
|
|
from itertools import islice
|
|
import warnings
|
|
|
|
import numpy as np
|
|
from scipy import sparse
|
|
from joblib import Parallel, delayed
|
|
|
|
from .base import clone, TransformerMixin
|
|
from .utils.metaestimators import if_delegate_has_method
|
|
from .utils import Bunch, _print_elapsed_time
|
|
from .utils.validation import check_memory
|
|
|
|
from .utils.metaestimators import _BaseComposition
|
|
|
|
# Public API of this module: the names exported by ``from sklearn.pipeline import *``.
__all__ = ['Pipeline', 'FeatureUnion', 'make_pipeline', 'make_union']
|
|
|
|
|
|
class Pipeline(_BaseComposition):
    """
    Pipeline of transforms with a final estimator.

    Sequentially apply a list of transforms and a final estimator.
    Intermediate steps of the pipeline must be 'transforms', that is, they
    must implement fit and transform methods.
    The final estimator only needs to implement fit.
    The transformers in the pipeline can be cached using ``memory`` argument.

    The purpose of the pipeline is to assemble several steps that can be
    cross-validated together while setting different parameters.
    For this, it enables setting parameters of the various steps using their
    names and the parameter name separated by a '__', as in the example below.
    A step's estimator may be replaced entirely by setting the parameter
    with its name to another estimator, or a transformer removed by setting
    it to 'passthrough' or ``None``.

    Read more in the :ref:`User Guide <pipeline>`.

    .. versionadded:: 0.5

    Parameters
    ----------
    steps : list
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.

    memory : None, str or object with the joblib.Memory interface, optional
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Enabling caching triggers a clone of
        the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.

    verbose : bool, default=False
        If True, the time elapsed while fitting each step will be printed as it
        is completed.

    Attributes
    ----------
    named_steps : bunch object, a dictionary with attribute access
        Read-only attribute to access any step parameter by user given name.
        Keys are step names and values are steps parameters.

    See Also
    --------
    sklearn.pipeline.make_pipeline : Convenience function for simplified
        pipeline construction.

    Examples
    --------
    >>> from sklearn import svm
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.feature_selection import SelectKBest
    >>> from sklearn.feature_selection import f_regression
    >>> from sklearn.pipeline import Pipeline
    >>> # generate some data to play with
    >>> X, y = make_classification(
    ...     n_informative=5, n_redundant=0, random_state=42)
    >>> # ANOVA SVM-C
    >>> anova_filter = SelectKBest(f_regression, k=5)
    >>> clf = svm.SVC(kernel='linear')
    >>> anova_svm = Pipeline([('anova', anova_filter), ('svc', clf)])
    >>> # You can set the parameters using the names issued
    >>> # For instance, fit using a k of 10 in the SelectKBest
    >>> # and a parameter 'C' of the svm
    >>> anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)
    Pipeline(steps=[('anova', SelectKBest(...)), ('svc', SVC(...))])
    >>> prediction = anova_svm.predict(X)
    >>> anova_svm.score(X, y)
    0.83
    >>> # getting the selected features chosen by anova_filter
    >>> anova_svm['anova'].get_support()
    array([False, False,  True,  True, False, False,  True,  True, False,
           True, False,  True,  True, False,  True, False,  True,  True,
           False, False])
    >>> # Another way to get selected features chosen by anova_filter
    >>> anova_svm.named_steps.anova.get_support()
    array([False, False,  True,  True, False, False,  True,  True, False,
           True, False,  True,  True, False,  True, False,  True,  True,
           False, False])
    >>> # Indexing can also be used to extract a sub-pipeline.
    >>> sub_pipeline = anova_svm[:1]
    >>> sub_pipeline
    Pipeline(steps=[('anova', SelectKBest(...))])
    >>> coef = anova_svm[-1].coef_
    >>> anova_svm['svc'] is anova_svm[-1]
    True
    >>> coef.shape
    (1, 10)
    >>> sub_pipeline.inverse_transform(coef).shape
    (1, 20)
    """

    # BaseEstimator interface
    _required_parameters = ['steps']

    def __init__(self, steps, memory=None, verbose=False):
        self.steps = steps
        self.memory = memory
        self.verbose = verbose
        self._validate_steps()

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        return self._get_params('steps', deep=deep)

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        Valid parameter keys can be listed with ``get_params()``.

        Returns
        -------
        self
        """
        self._set_params('steps', **kwargs)
        return self

    def _validate_steps(self):
        """Check step names and that each step exposes the required methods.

        Raises
        ------
        TypeError
            If an intermediate step lacks fit/fit_transform or transform, or
            if the last step lacks fit (``None``/'passthrough' are allowed).
        """
        names, estimators = zip(*self.steps)

        # validate names
        self._validate_names(names)

        # validate estimators
        transformers = estimators[:-1]
        estimator = estimators[-1]

        for t in transformers:
            if t is None or t == 'passthrough':
                continue
            if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
                    hasattr(t, "transform")):
                raise TypeError("All intermediate steps should be "
                                "transformers and implement fit and transform "
                                "or be the string 'passthrough' "
                                "'%s' (type %s) doesn't" % (t, type(t)))

        # We allow last estimator to be None as an identity transformation
        if (estimator is not None and estimator != 'passthrough'
                and not hasattr(estimator, "fit")):
            raise TypeError(
                "Last step of Pipeline should implement fit "
                "or be the string 'passthrough'. "
                "'%s' (type %s) doesn't" % (estimator, type(estimator)))

    def _iter(self, with_final=True, filter_passthrough=True):
        """
        Generate (idx, name, trans) tuples from self.steps

        When with_final is False, the last step is skipped.
        When filter_passthrough is True, 'passthrough' and None transformers
        are filtered out.
        """
        stop = len(self.steps)
        if not with_final:
            stop -= 1

        for idx, (name, trans) in enumerate(islice(self.steps, 0, stop)):
            if not filter_passthrough:
                yield idx, name, trans
            elif trans is not None and trans != 'passthrough':
                yield idx, name, trans

    def __len__(self):
        """
        Returns the length of the Pipeline
        """
        return len(self.steps)

    def __getitem__(self, ind):
        """Returns a sub-pipeline or a single estimator in the pipeline

        Indexing with an integer will return an estimator; indexing with a
        string returns the step of that name; using a slice returns another
        Pipeline instance which copies a slice of this Pipeline and keeps
        this pipeline's ``memory`` and ``verbose`` settings. This copy is
        shallow: modifying (or fitting) estimators in the sub-pipeline will
        affect the larger pipeline and vice-versa.
        However, replacing a value in `step` will not affect a copy.
        """
        if isinstance(ind, slice):
            if ind.step not in (1, None):
                raise ValueError('Pipeline slicing only supports a step of 1')
            # Forward the caching/verbosity configuration so a sub-pipeline
            # behaves like the corresponding part of its parent.
            return self.__class__(self.steps[ind], memory=self.memory,
                                  verbose=self.verbose)
        try:
            name, est = self.steps[ind]
        except TypeError:
            # Not an int, try get step by name
            return self.named_steps[ind]
        return est

    @property
    def _estimator_type(self):
        return self.steps[-1][1]._estimator_type

    @property
    def named_steps(self):
        # Use Bunch object to improve autocomplete
        return Bunch(**dict(self.steps))

    @property
    def _final_estimator(self):
        # A None final step is normalized to the 'passthrough' marker.
        estimator = self.steps[-1][1]
        return 'passthrough' if estimator is None else estimator

    def _log_message(self, step_idx):
        """Return the progress message for step ``step_idx``, or None if quiet."""
        if not self.verbose:
            return None
        name = self.steps[step_idx][0]

        return '(step %d of %d) Processing %s' % (step_idx + 1,
                                                  len(self.steps),
                                                  name)

    # Estimator interface

    def _fit(self, X, y=None, **fit_params):
        """Fit (or load from cache) every step except the last one.

        Returns
        -------
        Xt : the output of the last intermediate transformer (X if none).
        fit_params_last_step : dict
            The ``stepname__param`` fit parameters routed to the final step
            (empty if the final step is 'passthrough').
        """
        # shallow copy of steps - this should really be steps_
        self.steps = list(self.steps)
        self._validate_steps()
        # Setup the memory
        memory = check_memory(self.memory)

        fit_transform_one_cached = memory.cache(_fit_transform_one)

        fit_params_steps = {name: {} for name, step in self.steps
                            if step is not None}
        for pname, pval in fit_params.items():
            if '__' not in pname:
                raise ValueError(
                    "Pipeline.fit does not accept the {} parameter. "
                    "You can pass parameters to specific steps of your "
                    "pipeline using the stepname__parameter format, e.g. "
                    "`Pipeline.fit(X, y, logisticregression__sample_weight"
                    "=sample_weight)`.".format(pname))
            step, param = pname.split('__', 1)
            fit_params_steps[step][param] = pval
        for (step_idx,
             name,
             transformer) in self._iter(with_final=False,
                                        filter_passthrough=False):
            if (transformer is None or transformer == 'passthrough'):
                # Still emit a (near-zero) timing line so verbose output
                # lists every step.
                with _print_elapsed_time('Pipeline',
                                         self._log_message(step_idx)):
                    continue

            if hasattr(memory, 'location'):
                # joblib >= 0.12
                if memory.location is None:
                    # we do not clone when caching is disabled to
                    # preserve backward compatibility
                    cloned_transformer = transformer
                else:
                    cloned_transformer = clone(transformer)
            elif hasattr(memory, 'cachedir'):
                # joblib < 0.11
                if memory.cachedir is None:
                    # we do not clone when caching is disabled to
                    # preserve backward compatibility
                    cloned_transformer = transformer
                else:
                    cloned_transformer = clone(transformer)
            else:
                cloned_transformer = clone(transformer)
            # Fit or load from cache the current transformer
            X, fitted_transformer = fit_transform_one_cached(
                cloned_transformer, X, y, None,
                message_clsname='Pipeline',
                message=self._log_message(step_idx),
                **fit_params_steps[name])
            # Replace the transformer of the step with the fitted
            # transformer. This is necessary when loading the transformer
            # from the cache.
            self.steps[step_idx] = (name, fitted_transformer)
        if self._final_estimator == 'passthrough':
            return X, {}
        return X, fit_params_steps[self.steps[-1][0]]

    def fit(self, X, y=None, **fit_params):
        """Fit the model

        Fit all the transforms one after the other and transform the
        data, then fit the transformed data using the final estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        self : Pipeline
            This estimator
        """
        Xt, fit_params = self._fit(X, y, **fit_params)
        with _print_elapsed_time('Pipeline',
                                 self._log_message(len(self.steps) - 1)):
            if self._final_estimator != 'passthrough':
                self._final_estimator.fit(Xt, y, **fit_params)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit the model and transform with the final estimator

        Fits all the transforms one after the other and transforms the
        data, then uses fit_transform on transformed data with the final
        estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        Xt : array-like of shape (n_samples, n_transformed_features)
            Transformed samples
        """
        last_step = self._final_estimator
        Xt, fit_params = self._fit(X, y, **fit_params)
        with _print_elapsed_time('Pipeline',
                                 self._log_message(len(self.steps) - 1)):
            if last_step == 'passthrough':
                return Xt
            if hasattr(last_step, 'fit_transform'):
                return last_step.fit_transform(Xt, y, **fit_params)
            else:
                return last_step.fit(Xt, y, **fit_params).transform(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict(self, X, **predict_params):
        """Apply transforms to the data, and predict with the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        **predict_params : dict of string -> object
            Parameters to the ``predict`` called at the end of all
            transformations in the pipeline. Note that while this may be
            used to return uncertainties from some models with return_std
            or return_cov, uncertainties that are generated by the
            transformations in the pipeline are not propagated to the
            final estimator.

        Returns
        -------
        y_pred : array-like
        """
        Xt = X
        for _, _, transform in self._iter(with_final=False):
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict(Xt, **predict_params)

    @if_delegate_has_method(delegate='_final_estimator')
    def fit_predict(self, X, y=None, **fit_params):
        """Applies fit_predict of last step in pipeline after transforms.

        Applies fit_transforms of a pipeline to the data, followed by the
        fit_predict method of the final estimator in the pipeline. Valid
        only if the final estimator implements fit_predict.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of
            the pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps
            of the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        y_pred : array-like
        """
        Xt, fit_params = self._fit(X, y, **fit_params)
        with _print_elapsed_time('Pipeline',
                                 self._log_message(len(self.steps) - 1)):
            y_pred = self.steps[-1][-1].fit_predict(Xt, y, **fit_params)
        return y_pred

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_proba(self, X):
        """Apply transforms, and predict_proba of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_proba : array-like of shape (n_samples, n_classes)
        """
        Xt = X
        for _, _, transform in self._iter(with_final=False):
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_proba(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def decision_function(self, X):
        """Apply transforms, and decision_function of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like of shape (n_samples, n_classes)
        """
        Xt = X
        for _, _, transform in self._iter(with_final=False):
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].decision_function(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def score_samples(self, X):
        """Apply transforms, and score_samples of the final estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : ndarray, shape (n_samples,)
        """
        Xt = X
        for _, _, transformer in self._iter(with_final=False):
            Xt = transformer.transform(Xt)
        return self.steps[-1][-1].score_samples(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_log_proba(self, X):
        """Apply transforms, and predict_log_proba of the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like of shape (n_samples, n_classes)
        """
        Xt = X
        for _, _, transform in self._iter(with_final=False):
            Xt = transform.transform(Xt)
        return self.steps[-1][-1].predict_log_proba(Xt)

    @property
    def transform(self):
        """Apply transforms, and transform with the final estimator

        This also works where final estimator is ``None``: all prior
        transformations are applied.

        Parameters
        ----------
        X : iterable
            Data to transform. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        Xt : array-like of shape (n_samples, n_transformed_features)
        """
        # _final_estimator is None or has transform, otherwise attribute error
        # XXX: Handling the None case means we can't use if_delegate_has_method
        if self._final_estimator != 'passthrough':
            self._final_estimator.transform
        return self._transform

    def _transform(self, X):
        Xt = X
        for _, _, transform in self._iter():
            Xt = transform.transform(Xt)
        return Xt

    @property
    def inverse_transform(self):
        """Apply inverse transformations in reverse order

        All estimators in the pipeline must support ``inverse_transform``.

        Parameters
        ----------
        Xt : array-like of shape (n_samples, n_transformed_features)
            Data samples, where ``n_samples`` is the number of samples and
            ``n_features`` is the number of features. Must fulfill
            input requirements of last step of pipeline's
            ``inverse_transform`` method.

        Returns
        -------
        Xt : array-like of shape (n_samples, n_features)
        """
        # raise AttributeError if necessary for hasattr behaviour
        # XXX: Handling the None case means we can't use if_delegate_has_method
        for _, _, transform in self._iter():
            transform.inverse_transform
        return self._inverse_transform

    def _inverse_transform(self, X):
        Xt = X
        reverse_iter = reversed(list(self._iter()))
        for _, _, transform in reverse_iter:
            Xt = transform.inverse_transform(Xt)
        return Xt

    @if_delegate_has_method(delegate='_final_estimator')
    def score(self, X, y=None, sample_weight=None):
        """Apply transforms, and score with the final estimator

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        y : iterable, default=None
            Targets used for scoring. Must fulfill label requirements for all
            steps of the pipeline.

        sample_weight : array-like, default=None
            If not None, this argument is passed as ``sample_weight`` keyword
            argument to the ``score`` method of the final estimator.

        Returns
        -------
        score : float
        """
        Xt = X
        for _, _, transform in self._iter(with_final=False):
            Xt = transform.transform(Xt)
        score_params = {}
        if sample_weight is not None:
            score_params['sample_weight'] = sample_weight
        return self.steps[-1][-1].score(Xt, y, **score_params)

    @property
    def classes_(self):
        return self.steps[-1][-1].classes_

    @property
    def _pairwise(self):
        # check if first estimator expects pairwise input
        return getattr(self.steps[0][1], '_pairwise', False)
|
|
|
|
|
|
def _name_estimators(estimators):
    """Generate names for estimators.

    A string entry is used as its own name; any other entry is named after
    its lowercased class name. When a name occurs more than once, each
    occurrence gets a ``-<k>`` suffix, numbered from 1 in order of
    appearance.

    Returns
    -------
    list of (name, estimator) tuples, in the input order.
    """
    labels = [est if isinstance(est, str) else type(est).__name__.lower()
              for est in estimators]

    totals = defaultdict(int)
    for label in labels:
        totals[label] += 1
    # Only colliding names need a numeric suffix.
    remaining = {label: total for label, total in totals.items() if total > 1}

    # Walk backwards so that the last occurrence receives the largest suffix
    # and the first occurrence ends up with "-1".
    for pos in range(len(estimators) - 1, -1, -1):
        label = labels[pos]
        if label not in remaining:
            continue
        labels[pos] += "-%d" % remaining[label]
        remaining[label] -= 1

    return list(zip(labels, estimators))
|
|
|
|
|
|
def make_pipeline(*steps, **kwargs):
    """Construct a Pipeline from the given estimators.

    This is a shorthand for the Pipeline constructor; it does not require, and
    does not permit, naming the estimators. Instead, their names will be set
    to the lowercase of their types automatically.

    Parameters
    ----------
    *steps : list of estimators.

    memory : None, str or object with the joblib.Memory interface, optional
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Enabling caching triggers a clone of
        the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.

    verbose : boolean, default=False
        If True, the time elapsed while fitting each step will be printed as it
        is completed.

    Returns
    -------
    p : Pipeline

    Raises
    ------
    TypeError
        If any keyword other than ``memory`` or ``verbose`` is given.

    See Also
    --------
    sklearn.pipeline.Pipeline : Class for creating a pipeline of
        transforms with a final estimator.

    Examples
    --------
    >>> from sklearn.naive_bayes import GaussianNB
    >>> from sklearn.preprocessing import StandardScaler
    >>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
    Pipeline(steps=[('standardscaler', StandardScaler()),
                    ('gaussiannb', GaussianNB())])
    """
    memory = kwargs.pop('memory', None)
    verbose = kwargs.pop('verbose', False)

    # Anything left over is unsupported; report the first offender.
    leftover = list(kwargs)
    if leftover:
        raise TypeError('Unknown keyword arguments: "{}"'.format(leftover[0]))

    named = _name_estimators(steps)
    return Pipeline(named, memory=memory, verbose=verbose)
|
|
|
|
|
|
def _transform_one(transformer, X, y, weight, **fit_params):
    """Return ``transformer.transform(X)``, multiplied by ``weight`` if given.

    ``y`` and ``fit_params`` are accepted but not used.
    """
    transformed = transformer.transform(X)
    # A weight of None means "leave the output unscaled".
    return transformed if weight is None else transformed * weight
|
|
|
|
|
|
def _fit_transform_one(transformer,
                       X,
                       y,
                       weight,
                       message_clsname='',
                       message=None,
                       **fit_params):
    """
    Fit ``transformer`` on ``X``/``y`` and return ``(result, transformer)``.

    ``fit_transform`` is used when the transformer provides it, otherwise
    ``fit`` followed by ``transform``. Only the fitting is timed (through
    ``_print_elapsed_time``). A non-None ``weight`` multiplies the result.
    """
    with _print_elapsed_time(message_clsname, message):
        if not hasattr(transformer, 'fit_transform'):
            output = transformer.fit(X, y, **fit_params).transform(X)
        else:
            output = transformer.fit_transform(X, y, **fit_params)

    if weight is not None:
        output = output * weight
    return output, transformer
|
|
|
|
|
|
def _fit_one(transformer,
             X,
             y,
             weight,
             message_clsname='',
             message=None,
             **fit_params):
    """
    Fit ``transformer`` on ``X``/``y`` inside a timing context.

    ``weight`` is accepted but not used. Returns whatever
    ``transformer.fit`` returns.
    """
    with _print_elapsed_time(message_clsname, message):
        fitted = transformer.fit(X, y, **fit_params)
    return fitted
|
|
|
|
|
|
class FeatureUnion(TransformerMixin, _BaseComposition):
    """Concatenates results of multiple transformer objects.

    This estimator applies a list of transformer objects in parallel to the
    input data, then concatenates the results. This is useful to combine
    several feature extraction mechanisms into a single transformer.

    Parameters of the transformers may be set using its name and the parameter
    name separated by a '__'. A transformer may be replaced entirely by
    setting the parameter with its name to another transformer,
    or removed by setting to 'drop'.

    Read more in the :ref:`User Guide <feature_union>`.

    .. versionadded:: 0.13

    Parameters
    ----------
    transformer_list : list of (string, transformer) tuples
        List of transformer objects to be applied to the data. The first
        half of each tuple is the name of the transformer.

        .. versionchanged:: 0.22
           Deprecated `None` as a transformer in favor of 'drop'.

    n_jobs : int or None, optional (default=None)
        Number of jobs to run in parallel.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    transformer_weights : dict, optional
        Multiplicative weights for features per transformer.
        Keys are transformer names, values the weights.

    verbose : bool, default=False
        If True, the time elapsed while fitting each transformer will be
        printed as it is completed.

    See Also
    --------
    sklearn.pipeline.make_union : Convenience function for simplified
        feature union construction.

    Examples
    --------
    >>> from sklearn.pipeline import FeatureUnion
    >>> from sklearn.decomposition import PCA, TruncatedSVD
    >>> union = FeatureUnion([("pca", PCA(n_components=1)),
    ...                       ("svd", TruncatedSVD(n_components=2))])
    >>> X = [[0., 1., 3], [2., 2., 5]]
    >>> union.fit_transform(X)
    array([[ 1.5       ,  3.0...,  0.8...],
           [-1.5       ,  5.7..., -0.4...]])
    """
    _required_parameters = ["transformer_list"]

    def __init__(self, transformer_list, n_jobs=None,
                 transformer_weights=None, verbose=False):
        self.transformer_list = transformer_list
        self.n_jobs = n_jobs
        self.transformer_weights = transformer_weights
        self.verbose = verbose
        self._validate_transformers()

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        return self._get_params('transformer_list', deep=deep)

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        Valid parameter keys can be listed with ``get_params()``.

        Returns
        -------
        self
        """
        self._set_params('transformer_list', **kwargs)
        return self

    def _validate_transformers(self):
        """Check names and that every non-dropped entry implements fit/transform.

        Warns (FutureWarning) for ``None`` entries and raises TypeError for
        objects lacking fit/fit_transform or transform.
        """
        names, transformers = zip(*self.transformer_list)

        # validate names
        self._validate_names(names)

        # validate estimators
        for t in transformers:
            # TODO: Remove in 0.24 when None is removed
            if t is None:
                warnings.warn("Using None as a transformer is deprecated "
                              "in version 0.22 and will be removed in "
                              "version 0.24. Please use 'drop' instead.",
                              FutureWarning)
                continue
            if t == 'drop':
                continue
            if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not
                    hasattr(t, "transform")):
                raise TypeError("All estimators should implement fit and "
                                "transform. '%s' (type %s) doesn't" %
                                (t, type(t)))

    def _iter(self):
        """
        Generate (name, trans, weight) tuples excluding None and
        'drop' transformers.
        """
        get_weight = (self.transformer_weights or {}).get
        return ((name, trans, get_weight(name))
                for name, trans in self.transformer_list
                if trans is not None and trans != 'drop')

    def get_feature_names(self):
        """Get feature names from all transformers.

        Returns
        -------
        feature_names : list of strings
            Names of the features produced by transform.
        """
        feature_names = []
        for name, trans, weight in self._iter():
            if not hasattr(trans, 'get_feature_names'):
                raise AttributeError("Transformer %s (type %s) does not "
                                     "provide get_feature_names."
                                     % (str(name), type(trans).__name__))
            feature_names.extend([name + "__" + f for f in
                                  trans.get_feature_names()])
        return feature_names

    def fit(self, X, y=None, **fit_params):
        """Fit all transformers using X.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data, used to fit transformers.

        y : array-like, shape (n_samples, ...), optional
            Targets for supervised learning.

        **fit_params : dict of string -> object
            Keyword arguments forwarded unchanged to every transformer's
            ``fit``.

        Returns
        -------
        self : FeatureUnion
            This estimator
        """
        transformers = self._parallel_func(X, y, fit_params, _fit_one)
        if not transformers:
            # All transformers are None
            return self

        self._update_transformer_list(transformers)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit all transformers, transform the data and concatenate results.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data to be transformed.

        y : array-like, shape (n_samples, ...), optional
            Targets for supervised learning.

        **fit_params : dict of string -> object
            Keyword arguments forwarded unchanged to every transformer's
            ``fit_transform`` (or ``fit``).

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        results = self._parallel_func(X, y, fit_params, _fit_transform_one)
        if not results:
            # All transformers are None
            return self._empty_output(X)

        Xs, transformers = zip(*results)
        self._update_transformer_list(transformers)

        return self._hstack(Xs)

    def _log_message(self, name, idx, total):
        """Return the progress message for a transformer, or None if quiet."""
        if not self.verbose:
            return None
        return '(step %d of %d) Processing %s' % (idx, total, name)

    def _parallel_func(self, X, y, fit_params, func):
        """Runs func in parallel on X and y"""
        self.transformer_list = list(self.transformer_list)
        self._validate_transformers()
        transformers = list(self._iter())

        return Parallel(n_jobs=self.n_jobs)(delayed(func)(
            transformer, X, y, weight,
            message_clsname='FeatureUnion',
            message=self._log_message(name, idx, len(transformers)),
            **fit_params) for idx, (name, transformer,
                                    weight) in enumerate(transformers, 1))

    def transform(self, X):
        """Transform X separately by each transformer, concatenate results.

        Parameters
        ----------
        X : iterable or array-like, depending on transformers
            Input data to be transformed.

        Returns
        -------
        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
            hstack of results of transformers. sum_n_components is the
            sum of n_components (output dimension) over transformers.
        """
        Xs = Parallel(n_jobs=self.n_jobs)(
            delayed(_transform_one)(trans, X, None, weight)
            for name, trans, weight in self._iter())
        if not Xs:
            # All transformers are None
            return self._empty_output(X)
        return self._hstack(Xs)

    @staticmethod
    def _empty_output(X):
        """Return an (n_samples, 0) array for when every transformer is dropped."""
        # Plain lists (the documented input type) have no ``shape``; fall
        # back to ``len`` for them. Arrays and sparse matrices use shape[0].
        n_samples = X.shape[0] if hasattr(X, 'shape') else len(X)
        return np.zeros((n_samples, 0))

    def _hstack(self, Xs):
        """Concatenate outputs column-wise; CSR if any output is sparse."""
        if any(sparse.issparse(f) for f in Xs):
            return sparse.hstack(Xs).tocsr()
        return np.hstack(Xs)

    def _update_transformer_list(self, transformers):
        """Write fitted transformers back in order, keeping None/'drop' entries."""
        transformers = iter(transformers)
        self.transformer_list[:] = [(name, old if old is None or old == 'drop'
                                     else next(transformers))
                                    for name, old in self.transformer_list]
|
|
|
|
|
def make_union(*transformers, **kwargs):
    """
    Construct a FeatureUnion from the given transformers.

    This is a shorthand for the FeatureUnion constructor; it does not require,
    and does not permit, naming the transformers. Instead, they will be given
    names automatically based on their types. It also does not allow weighting.

    Parameters
    ----------
    *transformers : list of estimators

    n_jobs : int or None, optional (default=None)
        Number of jobs to run in parallel.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    verbose : boolean, optional(default=False)
        If True, the time elapsed while fitting each transformer will be
        printed as it is completed.

    Returns
    -------
    f : FeatureUnion

    Raises
    ------
    TypeError
        If any keyword other than ``n_jobs`` or ``verbose`` is given.

    See Also
    --------
    sklearn.pipeline.FeatureUnion : Class for concatenating the results
        of multiple transformer objects.

    Examples
    --------
    >>> from sklearn.decomposition import PCA, TruncatedSVD
    >>> from sklearn.pipeline import make_union
    >>> make_union(PCA(), TruncatedSVD())
     FeatureUnion(transformer_list=[('pca', PCA()),
                                   ('truncatedsvd', TruncatedSVD())])
    """
    n_jobs = kwargs.pop('n_jobs', None)
    verbose = kwargs.pop('verbose', False)

    # `transformer_weights` is deliberately unsupported here because its
    # type spec in make_union may change; any leftover keyword is rejected.
    leftover = list(kwargs)
    if leftover:
        raise TypeError('Unknown keyword arguments: "{}"'.format(leftover[0]))

    named = _name_estimators(transformers)
    return FeatureUnion(named, n_jobs=n_jobs, verbose=verbose)
|