166 lines
5.7 KiB
Python
166 lines
5.7 KiB
Python
"""
|
|
The :mod:`sklearn.utils.metaestimators` module includes utilities for meta-estimators.
|
|
"""
|
|
|
|
# Author: Joel Nothman
|
|
# Andreas Mueller
|
|
# License: BSD
|
|
from abc import ABCMeta, abstractmethod
|
|
from contextlib import suppress
|
|
from typing import Any, List
|
|
|
|
import numpy as np
|
|
|
|
from ..base import BaseEstimator
|
|
from ..utils import _safe_indexing
|
|
from ..utils._tags import _safe_tags
|
|
from ._available_if import available_if
|
|
|
|
# Public API of this module: only `available_if` (imported from
# `._available_if`) is exported by `from ... import *`.
__all__ = ["available_if"]
|
|
|
|
|
|
class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
    """Handles parameter management for estimators composed of named estimators.

    Subclasses keep their sub-estimators in an attribute (its name is passed
    as ``attr`` to the helpers below) holding a list of ``(name, estimator)``
    pairs. The helpers expose and update those sub-estimators' parameters
    using the ``"<name>__<param>"`` naming scheme.
    """

    steps: List[Any]

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        """Return the parameters of this composition and its sub-estimators.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the list of ``(name, estimator)``
            pairs.

        deep : bool, default=True
            If False, only the parent class's ``get_params(deep=False)`` result
            is returned. If True, each sub-estimator is additionally added
            under its own name, and the deep parameters of every sub-estimator
            that has ``get_params`` are added as ``"<name>__<param>"``.

        Returns
        -------
        params : dict
            Mapping of parameter names to values.
        """
        out = super().get_params(deep=deep)
        if not deep:
            return out

        estimators = getattr(self, attr)
        try:
            out.update(estimators)
        except (TypeError, ValueError):
            # Ignore TypeError for cases where estimators is not a list of
            # (name, estimator) and ignore ValueError when the list is not
            # formatted correctly. This is to prevent errors when calling
            # `set_params`. `BaseEstimator.set_params` calls `get_params` which
            # can error for invalid values for `estimators`.
            return out

        for name, estimator in estimators:
            if hasattr(estimator, "get_params"):
                for key, value in estimator.get_params(deep=True).items():
                    out["%s__%s" % (name, key)] = value
        return out

    def _set_params(self, attr, **params):
        """Set parameters, including replacing whole sub-estimators by name.

        Parameters are applied in a strict order:

        1. ``attr`` itself (the whole list of sub-estimators), if given;
        2. keys without ``"__"`` that match a sub-estimator name replace that
           sub-estimator;
        3. everything remaining is passed to the parent class's
           ``set_params`` (nested ``"<name>__<param>"`` keys and other
           constructor arguments).

        Parameters
        ----------
        attr : str
            Name of the attribute holding the list of ``(name, estimator)``
            pairs.

        **params : dict
            Parameters to set.

        Returns
        -------
        self : object
            This instance.
        """
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))

        # 2. Replace items with estimators in params
        items = getattr(self, attr)
        if isinstance(items, list) and items:
            # Get item names used to identify valid names in params. A
            # malformed list is tolerated here instead of raising: `zip`
            # raises TypeError when an element is not iterable, and the
            # two-name unpacking raises ValueError when the elements are not
            # pairs (e.g. 3-tuples). These are the same exceptions that
            # `_get_params` ignores. Only this line is guarded so that errors
            # during replacement are never swallowed after a param was popped.
            try:
                item_names, _ = zip(*items)
            except (TypeError, ValueError):
                item_names = ()
            for name in list(params.keys()):
                if "__" not in name and name in item_names:
                    self._replace_estimator(attr, name, params.pop(name))

        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Replace the sub-estimator called ``name`` with ``new_val``.

        Assumes ``name`` is a valid estimator name. A new list is built and
        assigned to ``attr``, so the previous list object is not mutated. Only
        the first entry whose name matches is replaced.
        """
        new_estimators = list(getattr(self, attr))
        for i, (estimator_name, _) in enumerate(new_estimators):
            if estimator_name == name:
                new_estimators[i] = (name, new_val)
                break
        setattr(self, attr, new_estimators)

    def _validate_names(self, names):
        """Check that sub-estimator names are usable.

        Parameters
        ----------
        names : sized collection
            The sub-estimator names (presumably str).

        Raises
        ------
        ValueError
            If the names are not unique, if any name clashes with a parameter
            returned by ``get_params(deep=False)``, or if any name contains
            ``"__"``. That separator is used to build nested parameter keys.
        """
        if len(set(names)) != len(names):
            raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
        invalid_names = set(names).intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError(
                "Estimator names conflict with constructor arguments: {0!r}".format(
                    sorted(invalid_names)
                )
            )
        invalid_names = [name for name in names if "__" in name]
        if invalid_names:
            raise ValueError(
                "Estimator names must not contain __: got {0!r}".format(invalid_names)
            )
|
|
|
|
|
|
def _safe_split(estimator, X, y, indices, train_indices=None):
    """Create subset of dataset and properly handle kernels.

    Slice X, y according to indices for cross-validation, but take care of
    precomputed kernel-matrices or pairwise affinities / distances.

    If the estimator's ``pairwise`` tag is true (as reported by
    ``_safe_tags(estimator, key="pairwise")``), X needs to be square and
    we slice rows and columns. If ``train_indices`` is not None,
    we slice rows using ``indices`` (assumed the test set) and columns
    using ``train_indices``, indicating the training set.

    Labels y will always be indexed only along the first axis.

    Parameters
    ----------
    estimator : object
        Estimator to determine whether we should slice only rows or rows and
        columns.

    X : array-like, sparse matrix or iterable
        Data to be indexed. If the estimator is pairwise, this needs to be a
        square array-like or sparse matrix exposing a ``shape`` attribute.

    y : array-like, sparse matrix, iterable or None
        Targets to be indexed. If None, ``y_subset`` is None.

    indices : array of int
        Rows to select from X and y.
        If the estimator is pairwise and ``train_indices is None``
        then ``indices`` will also be used to slice columns.

    train_indices : array of int or None, default=None
        If the estimator is pairwise and ``train_indices is not None``,
        then ``train_indices`` will be used to slice the columns of X.

    Returns
    -------
    X_subset : array-like, sparse matrix or list
        Indexed data.

    y_subset : array-like, sparse matrix, list or None
        Indexed targets, or None when ``y`` is None.

    Raises
    ------
    ValueError
        If the estimator is pairwise and X has no ``shape`` attribute or is
        not square.
    """
    if _safe_tags(estimator, key="pairwise"):
        if not hasattr(X, "shape"):
            raise ValueError(
                "Precomputed kernels or affinity matrices have "
                "to be passed as arrays or sparse matrices."
            )
        # X is a precomputed square kernel matrix
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        # Rows are always the requested samples. Columns are the training
        # samples when given, otherwise the same samples as the rows.
        columns = indices if train_indices is None else train_indices
        # np.ix_ builds an open mesh so the rows x columns sub-block is taken.
        X_subset = X[np.ix_(indices, columns)]
    else:
        X_subset = _safe_indexing(X, indices)

    y_subset = None if y is None else _safe_indexing(y, indices)

    return X_subset, y_subset
|