Traktor/myenv/Lib/site-packages/sklearn/inspection/_pd_utils.py

65 lines
2.1 KiB
Python
Raw Normal View History

2024-05-26 05:12:46 +02:00
def _check_feature_names(X, feature_names=None):
    """Check feature names.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data. Only inspected when `feature_names` is None: its
        `columns` attribute is used when it exposes `tolist`, otherwise
        `X.shape[1]` determines how many generic names are generated.

    feature_names : None or array-like of shape (n_names,), dtype=str
        Feature names to check or `None`. Objects exposing `tolist` (e.g. a
        NumPy array or pandas Index) are converted with it; any other
        non-list iterable (tuple, generator, ...) is converted with `list`.
        A list is used as-is (it is not copied).

    Returns
    -------
    feature_names : list
        Feature names validated. If `feature_names` is `None`, then a list of
        feature names is provided, i.e. the column names of a pandas dataframe
        or a generic list of feature names (e.g. `["x0", "x1", ...]`) for a
        NumPy array.

    Raises
    ------
    ValueError
        If `feature_names` contains duplicate entries.
    """
    if feature_names is None:
        if hasattr(X, "columns") and hasattr(X.columns, "tolist"):
            # get the column names for a pandas dataframe
            feature_names = X.columns.tolist()
        else:
            # define a list of numbered indices for a numpy array
            feature_names = [f"x{i}" for i in range(X.shape[1])]
    elif hasattr(feature_names, "tolist"):
        # convert numpy array or pandas index to a list
        feature_names = feature_names.tolist()
    elif not isinstance(feature_names, list):
        # Materialize tuples, generators and other iterables so the result is
        # always a list as documented, and so a one-shot iterator is not
        # exhausted by `set(...)` below before `len(...)` is evaluated.
        feature_names = list(feature_names)

    # Names are later looked up by value, so they must be unique.
    if len(set(feature_names)) != len(feature_names):
        raise ValueError("feature_names should not contain duplicates.")

    return feature_names
def _get_feature_index(fx, feature_names=None):
    """Resolve a feature reference to its integer position.

    Parameters
    ----------
    fx : int or str
        Either a feature index (returned unchanged, without any bounds
        check) or a feature name to look up in `feature_names`.

    feature_names : list of str, default=None
        Names searched with `.index` when `fx` is a string.

    Returns
    -------
    idx : int
        Position of the feature.

    Raises
    ------
    ValueError
        If `fx` is a string and `feature_names` is None, or if `fx` is not
        one of `feature_names` (chained from the lookup error).
    """
    # Anything that is not a name is passed straight through as an index.
    if not isinstance(fx, str):
        return fx

    if feature_names is None:
        raise ValueError(
            f"Cannot plot partial dependence for feature {fx!r} since "
            "the list of feature names was not provided, neither as "
            "column names of a pandas data-frame nor via the feature_names "
            "parameter."
        )

    try:
        position = feature_names.index(fx)
    except ValueError as lookup_error:
        raise ValueError(f"Feature {fx!r} not in feature_names") from lookup_error
    return position