68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
|
import numpy as np
|
||
|
|
||
|
_DEFAULT_TAGS = {
|
||
|
'non_deterministic': False,
|
||
|
'requires_positive_X': False,
|
||
|
'requires_positive_y': False,
|
||
|
'X_types': ['2darray'],
|
||
|
'poor_score': False,
|
||
|
'no_validation': False,
|
||
|
'multioutput': False,
|
||
|
"allow_nan": False,
|
||
|
'stateless': False,
|
||
|
'multilabel': False,
|
||
|
'_skip_test': False,
|
||
|
'_xfail_checks': False,
|
||
|
'multioutput_only': False,
|
||
|
'binary_only': False,
|
||
|
'requires_fit': True,
|
||
|
'preserves_dtype': [np.float64],
|
||
|
'requires_y': False,
|
||
|
'pairwise': False,
|
||
|
}
|
||
|
|
||
|
|
||
|
def _safe_tags(estimator, key=None):
|
||
|
"""Safely get estimator tags.
|
||
|
|
||
|
:class:`~sklearn.BaseEstimator` provides the estimator tags machinery.
|
||
|
However, if an estimator does not inherit from this base class, we should
|
||
|
fall-back to the default tags.
|
||
|
|
||
|
For scikit-learn built-in estimators, we should still rely on
|
||
|
`self._get_tags()`. `_safe_tags(est)` should be used when we are not sure
|
||
|
where `est` comes from: typically `_safe_tags(self.base_estimator)` where
|
||
|
`self` is a meta-estimator, or in the common checks.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
estimator : estimator object
|
||
|
The estimator from which to get the tag.
|
||
|
|
||
|
key : str, default=None
|
||
|
Tag name to get. By default (`None`), all tags are returned.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
tags : dict or tag value
|
||
|
The estimator tags. A single value is returned if `key` is not None.
|
||
|
"""
|
||
|
if hasattr(estimator, "_get_tags"):
|
||
|
tags_provider = "_get_tags()"
|
||
|
tags = estimator._get_tags()
|
||
|
elif hasattr(estimator, "_more_tags"):
|
||
|
tags_provider = "_more_tags()"
|
||
|
tags = {**_DEFAULT_TAGS, **estimator._more_tags()}
|
||
|
else:
|
||
|
tags_provider = "_DEFAULT_TAGS"
|
||
|
tags = _DEFAULT_TAGS
|
||
|
|
||
|
if key is not None:
|
||
|
if key not in tags:
|
||
|
raise ValueError(
|
||
|
f"The key {key} is not defined in {tags_provider} for the "
|
||
|
f"class {estimator.__class__.__name__}."
|
||
|
)
|
||
|
return tags[key]
|
||
|
return tags
|