3RNN/Lib/site-packages/sklearn/neighbors/_nearest_centroid.py
2024-05-26 19:49:15 +02:00

217 lines
7.6 KiB
Python

"""
Nearest Centroid Classification
"""
# Author: Robert Layton <robertlayton@gmail.com>
# Olivier Grisel <olivier.grisel@ensta.org>
#
# License: BSD 3 clause
from numbers import Real
import numpy as np
from scipy import sparse as sp
from ..base import BaseEstimator, ClassifierMixin, _fit_context
from ..metrics.pairwise import pairwise_distances_argmin
from ..preprocessing import LabelEncoder
from ..utils._param_validation import Interval, StrOptions
from ..utils.multiclass import check_classification_targets
from ..utils.sparsefuncs import csc_median_axis_0
from ..utils.validation import check_is_fitted
class NearestCentroid(ClassifierMixin, BaseEstimator):
"""Nearest centroid classifier.
Each class is represented by its centroid, with test samples classified to
the class with the nearest centroid.
Read more in the :ref:`User Guide <nearest_centroid_classifier>`.
Parameters
----------
metric : {"euclidean", "manhattan"}, default="euclidean"
Metric to use for distance computation.
If `metric="euclidean"`, the centroid for the samples corresponding to each
class is the arithmetic mean, which minimizes the sum of squared L1 distances.
If `metric="manhattan"`, the centroid is the feature-wise median, which
minimizes the sum of L1 distances.
.. versionchanged:: 1.5
All metrics but `"euclidean"` and `"manhattan"` were deprecated and
now raise an error.
.. versionchanged:: 0.19
`metric='precomputed'` was deprecated and now raises an error
shrink_threshold : float, default=None
Threshold for shrinking centroids to remove features.
Attributes
----------
centroids_ : array-like of shape (n_classes, n_features)
Centroid of each class.
classes_ : array of shape (n_classes,)
The unique classes labels.
n_features_in_ : int
Number of features seen during :term:`fit`.
.. versionadded:: 0.24
feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Defined only when `X`
has feature names that are all strings.
.. versionadded:: 1.0
See Also
--------
KNeighborsClassifier : Nearest neighbors classifier.
Notes
-----
When used for text classification with tf-idf vectors, this classifier is
also known as the Rocchio classifier.
References
----------
Tibshirani, R., Hastie, T., Narasimhan, B., & Chu, G. (2002). Diagnosis of
multiple cancer types by shrunken centroids of gene expression. Proceedings
of the National Academy of Sciences of the United States of America,
99(10), 6567-6572. The National Academy of Sciences.
Examples
--------
>>> from sklearn.neighbors import NearestCentroid
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> y = np.array([1, 1, 1, 2, 2, 2])
>>> clf = NearestCentroid()
>>> clf.fit(X, y)
NearestCentroid()
>>> print(clf.predict([[-0.8, -1]]))
[1]
"""
_parameter_constraints: dict = {
"metric": [StrOptions({"manhattan", "euclidean"})],
"shrink_threshold": [Interval(Real, 0, None, closed="neither"), None],
}
def __init__(self, metric="euclidean", *, shrink_threshold=None):
self.metric = metric
self.shrink_threshold = shrink_threshold
@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y):
"""
Fit the NearestCentroid model according to the given training data.
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Training vector, where `n_samples` is the number of samples and
`n_features` is the number of features.
Note that centroid shrinking cannot be used with sparse matrices.
y : array-like of shape (n_samples,)
Target values.
Returns
-------
self : object
Fitted estimator.
"""
# If X is sparse and the metric is "manhattan", store it in a csc
# format is easier to calculate the median.
if self.metric == "manhattan":
X, y = self._validate_data(X, y, accept_sparse=["csc"])
else:
X, y = self._validate_data(X, y, accept_sparse=["csr", "csc"])
is_X_sparse = sp.issparse(X)
if is_X_sparse and self.shrink_threshold:
raise ValueError("threshold shrinking not supported for sparse input")
check_classification_targets(y)
n_samples, n_features = X.shape
le = LabelEncoder()
y_ind = le.fit_transform(y)
self.classes_ = classes = le.classes_
n_classes = classes.size
if n_classes < 2:
raise ValueError(
"The number of classes has to be greater than one; got %d class"
% (n_classes)
)
# Mask mapping each class to its members.
self.centroids_ = np.empty((n_classes, n_features), dtype=np.float64)
# Number of clusters in each class.
nk = np.zeros(n_classes)
for cur_class in range(n_classes):
center_mask = y_ind == cur_class
nk[cur_class] = np.sum(center_mask)
if is_X_sparse:
center_mask = np.where(center_mask)[0]
if self.metric == "manhattan":
# NumPy does not calculate median of sparse matrices.
if not is_X_sparse:
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
else:
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
else: # metric == "euclidean"
self.centroids_[cur_class] = X[center_mask].mean(axis=0)
if self.shrink_threshold:
if np.all(np.ptp(X, axis=0) == 0):
raise ValueError("All features have zero variance. Division by zero.")
dataset_centroid_ = np.mean(X, axis=0)
# m parameter for determining deviation
m = np.sqrt((1.0 / nk) - (1.0 / n_samples))
# Calculate deviation using the standard deviation of centroids.
variance = (X - self.centroids_[y_ind]) ** 2
variance = variance.sum(axis=0)
s = np.sqrt(variance / (n_samples - n_classes))
s += np.median(s) # To deter outliers from affecting the results.
mm = m.reshape(len(m), 1) # Reshape to allow broadcasting.
ms = mm * s
deviation = (self.centroids_ - dataset_centroid_) / ms
# Soft thresholding: if the deviation crosses 0 during shrinking,
# it becomes zero.
signs = np.sign(deviation)
deviation = np.abs(deviation) - self.shrink_threshold
np.clip(deviation, 0, None, out=deviation)
deviation *= signs
# Now adjust the centroids using the deviation
msd = ms * deviation
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd
return self
def predict(self, X):
"""Perform classification on an array of test vectors `X`.
The predicted class `C` for each sample in `X` is returned.
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Test samples.
Returns
-------
C : ndarray of shape (n_samples,)
The predicted classes.
"""
check_is_fitted(self)
X = self._validate_data(X, accept_sparse="csr", reset=False)
return self.classes_[
pairwise_distances_argmin(X, self.centroids_, metric=self.metric)
]