497 lines
18 KiB
Python
497 lines
18 KiB
Python
|
import html
|
||
|
import itertools
|
||
|
from contextlib import closing
|
||
|
from inspect import isclass
|
||
|
from io import StringIO
|
||
|
from pathlib import Path
|
||
|
from string import Template
|
||
|
|
||
|
from .. import __version__, config_context
|
||
|
from .fixes import parse_version
|
||
|
|
||
|
|
||
|
class _IDCounter:
|
||
|
"""Generate sequential ids with a prefix."""
|
||
|
|
||
|
def __init__(self, prefix):
|
||
|
self.prefix = prefix
|
||
|
self.count = 0
|
||
|
|
||
|
def get_id(self):
|
||
|
self.count += 1
|
||
|
return f"{self.prefix}-{self.count}"
|
||
|
|
||
|
|
||
|
def _get_css_style():
|
||
|
return Path(__file__).with_suffix(".css").read_text(encoding="utf-8")
|
||
|
|
||
|
|
||
|
_CONTAINER_ID_COUNTER = _IDCounter("sk-container-id")
|
||
|
_ESTIMATOR_ID_COUNTER = _IDCounter("sk-estimator-id")
|
||
|
_CSS_STYLE = _get_css_style()
|
||
|
|
||
|
|
||
|
class _VisualBlock:
|
||
|
"""HTML Representation of Estimator
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
kind : {'serial', 'parallel', 'single'}
|
||
|
kind of HTML block
|
||
|
|
||
|
estimators : list of estimators or `_VisualBlock`s or a single estimator
|
||
|
If kind != 'single', then `estimators` is a list of
|
||
|
estimators.
|
||
|
If kind == 'single', then `estimators` is a single estimator.
|
||
|
|
||
|
names : list of str, default=None
|
||
|
If kind != 'single', then `names` corresponds to estimators.
|
||
|
If kind == 'single', then `names` is a single string corresponding to
|
||
|
the single estimator.
|
||
|
|
||
|
name_details : list of str, str, or None, default=None
|
||
|
If kind != 'single', then `name_details` corresponds to `names`.
|
||
|
If kind == 'single', then `name_details` is a single string
|
||
|
corresponding to the single estimator.
|
||
|
|
||
|
dash_wrapped : bool, default=True
|
||
|
If true, wrapped HTML element will be wrapped with a dashed border.
|
||
|
Only active when kind != 'single'.
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self, kind, estimators, *, names=None, name_details=None, dash_wrapped=True
|
||
|
):
|
||
|
self.kind = kind
|
||
|
self.estimators = estimators
|
||
|
self.dash_wrapped = dash_wrapped
|
||
|
|
||
|
if self.kind in ("parallel", "serial"):
|
||
|
if names is None:
|
||
|
names = (None,) * len(estimators)
|
||
|
if name_details is None:
|
||
|
name_details = (None,) * len(estimators)
|
||
|
|
||
|
self.names = names
|
||
|
self.name_details = name_details
|
||
|
|
||
|
def _sk_visual_block_(self):
|
||
|
return self
|
||
|
|
||
|
|
||
|
def _write_label_html(
|
||
|
out,
|
||
|
name,
|
||
|
name_details,
|
||
|
outer_class="sk-label-container",
|
||
|
inner_class="sk-label",
|
||
|
checked=False,
|
||
|
doc_link="",
|
||
|
is_fitted_css_class="",
|
||
|
is_fitted_icon="",
|
||
|
):
|
||
|
"""Write labeled html with or without a dropdown with named details.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
out : file-like object
|
||
|
The file to write the HTML representation to.
|
||
|
name : str
|
||
|
The label for the estimator. It corresponds either to the estimator class name
|
||
|
for a simple estimator or in the case of a `Pipeline` and `ColumnTransformer`,
|
||
|
it corresponds to the name of the step.
|
||
|
name_details : str
|
||
|
The details to show as content in the dropdown part of the toggleable label. It
|
||
|
can contain information such as non-default parameters or column information for
|
||
|
`ColumnTransformer`.
|
||
|
outer_class : {"sk-label-container", "sk-item"}, default="sk-label-container"
|
||
|
The CSS class for the outer container.
|
||
|
inner_class : {"sk-label", "sk-estimator"}, default="sk-label"
|
||
|
The CSS class for the inner container.
|
||
|
checked : bool, default=False
|
||
|
Whether the dropdown is folded or not. With a single estimator, we intend to
|
||
|
unfold the content.
|
||
|
doc_link : str, default=""
|
||
|
The link to the documentation for the estimator. If an empty string, no link is
|
||
|
added to the diagram. This can be generated for an estimator if it uses the
|
||
|
`_HTMLDocumentationLinkMixin`.
|
||
|
is_fitted_css_class : {"", "fitted"}
|
||
|
The CSS class to indicate whether or not the estimator is fitted. The
|
||
|
empty string means that the estimator is not fitted and "fitted" means that the
|
||
|
estimator is fitted.
|
||
|
is_fitted_icon : str, default=""
|
||
|
The HTML representation to show the fitted information in the diagram. An empty
|
||
|
string means that no information is shown.
|
||
|
"""
|
||
|
# we need to add some padding to the left of the label to be sure it is centered
|
||
|
padding_label = " " if is_fitted_icon else "" # add padding for the "i" char
|
||
|
|
||
|
out.write(
|
||
|
f'<div class="{outer_class}"><div'
|
||
|
f' class="{inner_class} {is_fitted_css_class} sk-toggleable">'
|
||
|
)
|
||
|
name = html.escape(name)
|
||
|
|
||
|
if name_details is not None:
|
||
|
name_details = html.escape(str(name_details))
|
||
|
label_class = (
|
||
|
f"sk-toggleable__label {is_fitted_css_class} sk-toggleable__label-arrow"
|
||
|
)
|
||
|
|
||
|
checked_str = "checked" if checked else ""
|
||
|
est_id = _ESTIMATOR_ID_COUNTER.get_id()
|
||
|
|
||
|
if doc_link:
|
||
|
doc_label = "<span>Online documentation</span>"
|
||
|
if name is not None:
|
||
|
doc_label = f"<span>Documentation for {name}</span>"
|
||
|
doc_link = (
|
||
|
f'<a class="sk-estimator-doc-link {is_fitted_css_class}"'
|
||
|
f' rel="noreferrer" target="_blank" href="{doc_link}">?{doc_label}</a>'
|
||
|
)
|
||
|
padding_label += " " # add additional padding for the "?" char
|
||
|
|
||
|
fmt_str = (
|
||
|
'<input class="sk-toggleable__control sk-hidden--visually"'
|
||
|
f' id="{est_id}" '
|
||
|
f'type="checkbox" {checked_str}><label for="{est_id}" '
|
||
|
f'class="{label_class} {is_fitted_css_class}">{padding_label}{name}'
|
||
|
f"{doc_link}{is_fitted_icon}</label><div "
|
||
|
f'class="sk-toggleable__content {is_fitted_css_class}">'
|
||
|
f"<pre>{name_details}</pre></div> "
|
||
|
)
|
||
|
out.write(fmt_str)
|
||
|
else:
|
||
|
out.write(f"<label>{name}</label>")
|
||
|
out.write("</div></div>") # outer_class inner_class
|
||
|
|
||
|
|
||
|
def _get_visual_block(estimator):
|
||
|
"""Generate information about how to display an estimator."""
|
||
|
if hasattr(estimator, "_sk_visual_block_"):
|
||
|
try:
|
||
|
return estimator._sk_visual_block_()
|
||
|
except Exception:
|
||
|
return _VisualBlock(
|
||
|
"single",
|
||
|
estimator,
|
||
|
names=estimator.__class__.__name__,
|
||
|
name_details=str(estimator),
|
||
|
)
|
||
|
|
||
|
if isinstance(estimator, str):
|
||
|
return _VisualBlock(
|
||
|
"single", estimator, names=estimator, name_details=estimator
|
||
|
)
|
||
|
elif estimator is None:
|
||
|
return _VisualBlock("single", estimator, names="None", name_details="None")
|
||
|
|
||
|
# check if estimator looks like a meta estimator (wraps estimators)
|
||
|
if hasattr(estimator, "get_params") and not isclass(estimator):
|
||
|
estimators = [
|
||
|
(key, est)
|
||
|
for key, est in estimator.get_params(deep=False).items()
|
||
|
if hasattr(est, "get_params") and hasattr(est, "fit") and not isclass(est)
|
||
|
]
|
||
|
if estimators:
|
||
|
return _VisualBlock(
|
||
|
"parallel",
|
||
|
[est for _, est in estimators],
|
||
|
names=[f"{key}: {est.__class__.__name__}" for key, est in estimators],
|
||
|
name_details=[str(est) for _, est in estimators],
|
||
|
)
|
||
|
|
||
|
return _VisualBlock(
|
||
|
"single",
|
||
|
estimator,
|
||
|
names=estimator.__class__.__name__,
|
||
|
name_details=str(estimator),
|
||
|
)
|
||
|
|
||
|
|
||
|
def _write_estimator_html(
|
||
|
out,
|
||
|
estimator,
|
||
|
estimator_label,
|
||
|
estimator_label_details,
|
||
|
is_fitted_css_class,
|
||
|
is_fitted_icon="",
|
||
|
first_call=False,
|
||
|
):
|
||
|
"""Write estimator to html in serial, parallel, or by itself (single).
|
||
|
|
||
|
For multiple estimators, this function is called recursively.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
out : file-like object
|
||
|
The file to write the HTML representation to.
|
||
|
estimator : estimator object
|
||
|
The estimator to visualize.
|
||
|
estimator_label : str
|
||
|
The label for the estimator. It corresponds either to the estimator class name
|
||
|
for simple estimator or in the case of `Pipeline` and `ColumnTransformer`, it
|
||
|
corresponds to the name of the step.
|
||
|
estimator_label_details : str
|
||
|
The details to show as content in the dropdown part of the toggleable label.
|
||
|
It can contain information as non-default parameters or column information for
|
||
|
`ColumnTransformer`.
|
||
|
is_fitted_css_class : {"", "fitted"}
|
||
|
The CSS class to indicate whether or not the estimator is fitted or not. The
|
||
|
empty string means that the estimator is not fitted and "fitted" means that the
|
||
|
estimator is fitted.
|
||
|
is_fitted_icon : str, default=""
|
||
|
The HTML representation to show the fitted information in the diagram. An empty
|
||
|
string means that no information is shown. If the estimator to be shown is not
|
||
|
the first estimator (i.e. `first_call=False`), `is_fitted_icon` is always an
|
||
|
empty string.
|
||
|
first_call : bool, default=False
|
||
|
Whether this is the first time this function is called.
|
||
|
"""
|
||
|
if first_call:
|
||
|
est_block = _get_visual_block(estimator)
|
||
|
else:
|
||
|
is_fitted_icon = ""
|
||
|
with config_context(print_changed_only=True):
|
||
|
est_block = _get_visual_block(estimator)
|
||
|
# `estimator` can also be an instance of `_VisualBlock`
|
||
|
if hasattr(estimator, "_get_doc_link"):
|
||
|
doc_link = estimator._get_doc_link()
|
||
|
else:
|
||
|
doc_link = ""
|
||
|
if est_block.kind in ("serial", "parallel"):
|
||
|
dashed_wrapped = first_call or est_block.dash_wrapped
|
||
|
dash_cls = " sk-dashed-wrapped" if dashed_wrapped else ""
|
||
|
out.write(f'<div class="sk-item{dash_cls}">')
|
||
|
|
||
|
if estimator_label:
|
||
|
_write_label_html(
|
||
|
out,
|
||
|
estimator_label,
|
||
|
estimator_label_details,
|
||
|
doc_link=doc_link,
|
||
|
is_fitted_css_class=is_fitted_css_class,
|
||
|
is_fitted_icon=is_fitted_icon,
|
||
|
)
|
||
|
|
||
|
kind = est_block.kind
|
||
|
out.write(f'<div class="sk-{kind}">')
|
||
|
est_infos = zip(est_block.estimators, est_block.names, est_block.name_details)
|
||
|
|
||
|
for est, name, name_details in est_infos:
|
||
|
if kind == "serial":
|
||
|
_write_estimator_html(
|
||
|
out,
|
||
|
est,
|
||
|
name,
|
||
|
name_details,
|
||
|
is_fitted_css_class=is_fitted_css_class,
|
||
|
)
|
||
|
else: # parallel
|
||
|
out.write('<div class="sk-parallel-item">')
|
||
|
# wrap element in a serial visualblock
|
||
|
serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
|
||
|
_write_estimator_html(
|
||
|
out,
|
||
|
serial_block,
|
||
|
name,
|
||
|
name_details,
|
||
|
is_fitted_css_class=is_fitted_css_class,
|
||
|
)
|
||
|
out.write("</div>") # sk-parallel-item
|
||
|
|
||
|
out.write("</div></div>")
|
||
|
elif est_block.kind == "single":
|
||
|
_write_label_html(
|
||
|
out,
|
||
|
est_block.names,
|
||
|
est_block.name_details,
|
||
|
outer_class="sk-item",
|
||
|
inner_class="sk-estimator",
|
||
|
checked=first_call,
|
||
|
doc_link=doc_link,
|
||
|
is_fitted_css_class=is_fitted_css_class,
|
||
|
is_fitted_icon=is_fitted_icon,
|
||
|
)
|
||
|
|
||
|
|
||
|
def estimator_html_repr(estimator):
|
||
|
"""Build a HTML representation of an estimator.
|
||
|
|
||
|
Read more in the :ref:`User Guide <visualizing_composite_estimators>`.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
estimator : estimator object
|
||
|
The estimator to visualize.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
html: str
|
||
|
HTML representation of estimator.
|
||
|
|
||
|
Examples
|
||
|
--------
|
||
|
>>> from sklearn.utils._estimator_html_repr import estimator_html_repr
|
||
|
>>> from sklearn.linear_model import LogisticRegression
|
||
|
>>> estimator_html_repr(LogisticRegression())
|
||
|
'<style>...</div>'
|
||
|
"""
|
||
|
from sklearn.exceptions import NotFittedError
|
||
|
from sklearn.utils.validation import check_is_fitted
|
||
|
|
||
|
if not hasattr(estimator, "fit"):
|
||
|
status_label = "<span>Not fitted</span>"
|
||
|
is_fitted_css_class = ""
|
||
|
else:
|
||
|
try:
|
||
|
check_is_fitted(estimator)
|
||
|
status_label = "<span>Fitted</span>"
|
||
|
is_fitted_css_class = "fitted"
|
||
|
except NotFittedError:
|
||
|
status_label = "<span>Not fitted</span>"
|
||
|
is_fitted_css_class = ""
|
||
|
|
||
|
is_fitted_icon = (
|
||
|
f'<span class="sk-estimator-doc-link {is_fitted_css_class}">'
|
||
|
f"i{status_label}</span>"
|
||
|
)
|
||
|
with closing(StringIO()) as out:
|
||
|
container_id = _CONTAINER_ID_COUNTER.get_id()
|
||
|
style_template = Template(_CSS_STYLE)
|
||
|
style_with_id = style_template.substitute(id=container_id)
|
||
|
estimator_str = str(estimator)
|
||
|
|
||
|
# The fallback message is shown by default and loading the CSS sets
|
||
|
# div.sk-text-repr-fallback to display: none to hide the fallback message.
|
||
|
#
|
||
|
# If the notebook is trusted, the CSS is loaded which hides the fallback
|
||
|
# message. If the notebook is not trusted, then the CSS is not loaded and the
|
||
|
# fallback message is shown by default.
|
||
|
#
|
||
|
# The reverse logic applies to HTML repr div.sk-container.
|
||
|
# div.sk-container is hidden by default and the loading the CSS displays it.
|
||
|
fallback_msg = (
|
||
|
"In a Jupyter environment, please rerun this cell to show the HTML"
|
||
|
" representation or trust the notebook. <br />On GitHub, the"
|
||
|
" HTML representation is unable to render, please try loading this page"
|
||
|
" with nbviewer.org."
|
||
|
)
|
||
|
html_template = (
|
||
|
f"<style>{style_with_id}</style>"
|
||
|
f'<div id="{container_id}" class="sk-top-container">'
|
||
|
'<div class="sk-text-repr-fallback">'
|
||
|
f"<pre>{html.escape(estimator_str)}</pre><b>{fallback_msg}</b>"
|
||
|
"</div>"
|
||
|
'<div class="sk-container" hidden>'
|
||
|
)
|
||
|
|
||
|
out.write(html_template)
|
||
|
|
||
|
_write_estimator_html(
|
||
|
out,
|
||
|
estimator,
|
||
|
estimator.__class__.__name__,
|
||
|
estimator_str,
|
||
|
first_call=True,
|
||
|
is_fitted_css_class=is_fitted_css_class,
|
||
|
is_fitted_icon=is_fitted_icon,
|
||
|
)
|
||
|
out.write("</div></div>")
|
||
|
|
||
|
html_output = out.getvalue()
|
||
|
return html_output
|
||
|
|
||
|
|
||
|
class _HTMLDocumentationLinkMixin:
|
||
|
"""Mixin class allowing to generate a link to the API documentation.
|
||
|
|
||
|
This mixin relies on three attributes:
|
||
|
- `_doc_link_module`: it corresponds to the root module (e.g. `sklearn`). Using this
|
||
|
mixin, the default value is `sklearn`.
|
||
|
- `_doc_link_template`: it corresponds to the template used to generate the
|
||
|
link to the API documentation. Using this mixin, the default value is
|
||
|
`"https://scikit-learn.org/{version_url}/modules/generated/
|
||
|
{estimator_module}.{estimator_name}.html"`.
|
||
|
- `_doc_link_url_param_generator`: it corresponds to a function that generates the
|
||
|
parameters to be used in the template when the estimator module and name are not
|
||
|
sufficient.
|
||
|
|
||
|
The method :meth:`_get_doc_link` generates the link to the API documentation for a
|
||
|
given estimator.
|
||
|
|
||
|
This useful provides all the necessary states for
|
||
|
:func:`sklearn.utils.estimator_html_repr` to generate a link to the API
|
||
|
documentation for the estimator HTML diagram.
|
||
|
|
||
|
Examples
|
||
|
--------
|
||
|
If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
|
||
|
then you can override them:
|
||
|
>>> from sklearn.base import BaseEstimator
|
||
|
>>> estimator = BaseEstimator()
|
||
|
>>> estimator._doc_link_template = "https://website.com/{single_param}.html"
|
||
|
>>> def url_param_generator(estimator):
|
||
|
... return {"single_param": estimator.__class__.__name__}
|
||
|
>>> estimator._doc_link_url_param_generator = url_param_generator
|
||
|
>>> estimator._get_doc_link()
|
||
|
'https://website.com/BaseEstimator.html'
|
||
|
"""
|
||
|
|
||
|
_doc_link_module = "sklearn"
|
||
|
_doc_link_url_param_generator = None
|
||
|
|
||
|
@property
|
||
|
def _doc_link_template(self):
|
||
|
sklearn_version = parse_version(__version__)
|
||
|
if sklearn_version.dev is None:
|
||
|
version_url = f"{sklearn_version.major}.{sklearn_version.minor}"
|
||
|
else:
|
||
|
version_url = "dev"
|
||
|
return getattr(
|
||
|
self,
|
||
|
"__doc_link_template",
|
||
|
(
|
||
|
f"https://scikit-learn.org/{version_url}/modules/generated/"
|
||
|
"{estimator_module}.{estimator_name}.html"
|
||
|
),
|
||
|
)
|
||
|
|
||
|
@_doc_link_template.setter
|
||
|
def _doc_link_template(self, value):
|
||
|
setattr(self, "__doc_link_template", value)
|
||
|
|
||
|
def _get_doc_link(self):
|
||
|
"""Generates a link to the API documentation for a given estimator.
|
||
|
|
||
|
This method generates the link to the estimator's documentation page
|
||
|
by using the template defined by the attribute `_doc_link_template`.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
url : str
|
||
|
The URL to the API documentation for this estimator. If the estimator does
|
||
|
not belong to module `_doc_link_module`, the empty string (i.e. `""`) is
|
||
|
returned.
|
||
|
"""
|
||
|
if self.__class__.__module__.split(".")[0] != self._doc_link_module:
|
||
|
return ""
|
||
|
|
||
|
if self._doc_link_url_param_generator is None:
|
||
|
estimator_name = self.__class__.__name__
|
||
|
# Construct the estimator's module name, up to the first private submodule.
|
||
|
# This works because in scikit-learn all public estimators are exposed at
|
||
|
# that level, even if they actually live in a private sub-module.
|
||
|
estimator_module = ".".join(
|
||
|
itertools.takewhile(
|
||
|
lambda part: not part.startswith("_"),
|
||
|
self.__class__.__module__.split("."),
|
||
|
)
|
||
|
)
|
||
|
return self._doc_link_template.format(
|
||
|
estimator_module=estimator_module, estimator_name=estimator_name
|
||
|
)
|
||
|
return self._doc_link_template.format(
|
||
|
**self._doc_link_url_param_generator(self)
|
||
|
)
|