639 lines
22 KiB
Python
639 lines
22 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""The TensorBoard metrics plugin."""
|
|
|
|
|
|
import collections
|
|
import imghdr
|
|
import json
|
|
|
|
from werkzeug import wrappers
|
|
|
|
from tensorboard import errors
|
|
from tensorboard import plugin_util
|
|
from tensorboard.backend import http_util
|
|
from tensorboard.data import provider
|
|
from tensorboard.plugins import base_plugin
|
|
from tensorboard.plugins.histogram import metadata as histogram_metadata
|
|
from tensorboard.plugins.image import metadata as image_metadata
|
|
from tensorboard.plugins.metrics import metadata
|
|
from tensorboard.plugins.scalar import metadata as scalar_metadata
|
|
|
|
|
|
# Image-type names (as returned by `imghdr.what` in `_image_data_impl`) mapped
# to the HTTP Content-Type used when serving that image.
# NOTE(review): `imghdr.what` does not appear to report "svg"; that entry may
# be unused — confirm before relying on it.
_IMGHDR_TO_MIMETYPE = {
    "bmp": "image/bmp",
    "gif": "image/gif",
    "jpeg": "image/jpeg",
    "png": "image/png",
    "svg": "image/svg+xml",
}

# Content-Type served when the image type is not found in the table above.
_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"

# Plugins whose time series requests must specify a `run` string.
_SINGLE_RUN_PLUGINS = frozenset(
    {histogram_metadata.PLUGIN_NAME, image_metadata.PLUGIN_NAME}
)

# Plugins whose time series requests must specify an integer `sample`.
_SAMPLED_PLUGINS = frozenset({image_metadata.PLUGIN_NAME})
|
|
|
|
|
|
def _get_tag_description_info(mapping):
|
|
"""Gets maps from tags to descriptions, and descriptions to runs.
|
|
|
|
Args:
|
|
mapping: a nested map `d` such that `d[run][tag]` is a time series
|
|
produced by DataProvider's `list_*` methods.
|
|
|
|
Returns:
|
|
A tuple containing
|
|
tag_to_descriptions: A map from tag strings to a set of description
|
|
strings.
|
|
description_to_runs: A map from description strings to a set of run
|
|
strings.
|
|
"""
|
|
tag_to_descriptions = collections.defaultdict(set)
|
|
description_to_runs = collections.defaultdict(set)
|
|
for (run, tag_to_content) in mapping.items():
|
|
for (tag, metadatum) in tag_to_content.items():
|
|
description = metadatum.description
|
|
if len(description):
|
|
tag_to_descriptions[tag].add(description)
|
|
description_to_runs[description].add(run)
|
|
|
|
return tag_to_descriptions, description_to_runs
|
|
|
|
|
|
def _build_combined_description(descriptions, description_to_runs):
|
|
"""Creates a single description from a set of descriptions.
|
|
|
|
Descriptions may be composites when a single tag has different descriptions
|
|
across multiple runs.
|
|
|
|
Args:
|
|
descriptions: A list of description strings.
|
|
description_to_runs: A map from description strings to a set of run
|
|
strings.
|
|
|
|
Returns:
|
|
The combined description string.
|
|
"""
|
|
prefixed_descriptions = []
|
|
for description in descriptions:
|
|
runs = sorted(description_to_runs[description])
|
|
run_or_runs = "runs" if len(runs) > 1 else "run"
|
|
run_header = "## For " + run_or_runs + ": " + ", ".join(runs)
|
|
description_html = run_header + "\n" + description
|
|
prefixed_descriptions.append(description_html)
|
|
|
|
header = "# Multiple descriptions\n"
|
|
return header + "\n".join(prefixed_descriptions)
|
|
|
|
|
|
def _get_tag_to_description(mapping):
    """Builds rendered (sanitized HTML) descriptions keyed by tag.

    A tag with one distinct description uses it directly; a tag whose runs
    disagree gets a combined description listing each variant with its runs.
    Either way the Markdown is rendered via `plugin_util.markdown_to_safe_html`.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A map from tag strings to description HTML strings. E.g.
      {
        "loss": "<h1>Multiple descriptions</h1><h2>For runs: test, train
          </h2><p>...</p>",
        "loss2": "<p>The lossy details</p>",
      }
    """
    tag_to_descriptions, description_to_runs = _get_tag_description_info(
        mapping
    )

    def _render(unique_descriptions):
        # Sorting makes the output independent of set iteration order.
        ordered = sorted(unique_descriptions)
        markdown = (
            ordered[0]
            if len(ordered) == 1
            else _build_combined_description(ordered, description_to_runs)
        )
        return plugin_util.markdown_to_safe_html(markdown)

    return {
        tag: _render(descriptions)
        for tag, descriptions in tag_to_descriptions.items()
    }
|
|
|
|
|
|
def _get_run_tag_info(mapping):
|
|
"""Returns a map of run names to a list of tag names.
|
|
|
|
Args:
|
|
mapping: a nested map `d` such that `d[run][tag]` is a time series
|
|
produced by DataProvider's `list_*` methods.
|
|
|
|
Returns:
|
|
A map from run strings to a list of tag strings. E.g.
|
|
{"loss001a": ["actor/loss", "critic/loss"], ...}
|
|
"""
|
|
return {run: sorted(mapping[run]) for run in mapping}
|
|
|
|
|
|
def _format_basic_mapping(mapping):
    """Packages a scalar or histogram listing into the client's tag format.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A dict with the following fields:
        runTagInfo: the return type of `_get_run_tag_info`
        tagDescriptions: the return type of `_get_tag_to_description`
    """
    formatted = {}
    formatted["runTagInfo"] = _get_run_tag_info(mapping)
    formatted["tagDescriptions"] = _get_tag_to_description(mapping)
    return formatted
|
|
|
|
|
|
def _format_image_blob_sequence_datum(sorted_datum_list, sample):
|
|
"""Formats image metadata from a list of BlobSequenceDatum's for clients.
|
|
|
|
This expects that frontend clients need to access images based on the
|
|
run+tag+sample.
|
|
|
|
Args:
|
|
sorted_datum_list: a list of DataProvider's `BlobSequenceDatum`, sorted by
|
|
step. This can be produced via DataProvider's `read_blob_sequences`.
|
|
sample: zero-indexed integer for the requested sample.
|
|
|
|
Returns:
|
|
A list of `ImageStepDatum` (see http_api.md).
|
|
"""
|
|
# For images, ignore the first 2 items of a BlobSequenceDatum's values, which
|
|
# correspond to width, height.
|
|
index = sample + 2
|
|
step_data = []
|
|
for datum in sorted_datum_list:
|
|
if len(datum.values) <= index:
|
|
continue
|
|
|
|
step_data.append(
|
|
{
|
|
"step": datum.step,
|
|
"wallTime": datum.wall_time,
|
|
"imageId": datum.values[index].blob_key,
|
|
}
|
|
)
|
|
return step_data
|
|
|
|
|
|
def _get_tag_run_image_info(mapping):
|
|
"""Returns a map of tag names to run information.
|
|
|
|
Args:
|
|
mapping: the result of DataProvider's `list_blob_sequences`.
|
|
|
|
Returns:
|
|
A nested map from run strings to tag string to image info, where image
|
|
info is an object of form {"maxSamplesPerStep": num}. For example,
|
|
{
|
|
"reshaped": {
|
|
"test": {"maxSamplesPerStep": 1},
|
|
"train": {"maxSamplesPerStep": 1}
|
|
},
|
|
"convolved": {"test": {"maxSamplesPerStep": 50}},
|
|
}
|
|
"""
|
|
tag_run_image_info = collections.defaultdict(dict)
|
|
for (run, tag_to_content) in mapping.items():
|
|
for (tag, metadatum) in tag_to_content.items():
|
|
tag_run_image_info[tag][run] = {
|
|
"maxSamplesPerStep": metadatum.max_length - 2 # width, height
|
|
}
|
|
return dict(tag_run_image_info)
|
|
|
|
|
|
def _format_image_mapping(mapping):
    """Prepares an image mapping for client consumption.

    Args:
      mapping: the result of DataProvider's `list_blob_sequences`.

    Returns:
      A dict with the following fields:
        tagDescriptions: the return type of `_get_tag_to_description`
        tagRunSampledInfo: the return type of `_get_tag_run_image_info`
    """
    return {
        "tagDescriptions": _get_tag_to_description(mapping),
        "tagRunSampledInfo": _get_tag_run_image_info(mapping),
    }
|
|
|
|
|
|
class MetricsPlugin(base_plugin.TBPlugin):
    """Metrics Plugin for TensorBoard.

    Serves the "Time Series" dashboard through the `/tags`, `/timeSeries` and
    `/imageData` routes, reading scalar, histogram and image data from the
    context's data provider.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates MetricsPlugin.

        Args:
          context: A base_plugin.TBContext instance. MetricsLoader checks that
            it contains a valid `data_provider`.
        """
        self._data_provider = context.data_provider

        # For histograms, use a round number + 1 since sampling includes both
        # start and end steps, so N+1 samples corresponds to dividing the step
        # sequence into N intervals.
        sampling_hints = context.sampling_hints or {}
        self._plugin_downsampling = {
            "scalars": sampling_hints.get(scalar_metadata.PLUGIN_NAME, 1000),
            "histograms": sampling_hints.get(
                histogram_metadata.PLUGIN_NAME, 51
            ),
            "images": sampling_hints.get(image_metadata.PLUGIN_NAME, 10),
        }
        # Per-kind checkers consulted by `_filter_by_version`; time series
        # whose metadata version a checker rejects are left out of `/tags`.
        self._scalar_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="scalar time series",
            latest_known_version=0,
        )
        self._histogram_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="histogram time series",
            latest_known_version=0,
        )
        self._image_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="image time series",
            latest_known_version=0,
        )

    def frontend_metadata(self):
        return base_plugin.FrontendMetadata(
            is_ng_component=True, tab_name="Time Series"
        )

    def get_plugin_apps(self):
        return {
            "/tags": self._serve_tags,
            "/timeSeries": self._serve_time_series,
            "/imageData": self._serve_image_data,
        }

    def data_plugin_names(self):
        return (
            scalar_metadata.PLUGIN_NAME,
            histogram_metadata.PLUGIN_NAME,
            image_metadata.PLUGIN_NAME,
        )

    def is_active(self):
        return False  # 'data_plugin_names' suffices.

    @wrappers.Request.application
    def _serve_tags(self, request):
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        index = self._tags_impl(ctx, experiment=experiment)
        return http_util.Respond(request, index, "application/json")

    def _tags_impl(self, ctx, experiment=None):
        """Returns tag metadata for a given experiment's logged metrics.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: optional string ID of the request's experiment.

        Returns:
          A nested dict 'd' with keys in ("scalars", "histograms", "images")
          and values being the return type of _format_*mapping.
        """
        scalar_mapping = self._data_provider.list_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
        )
        # Treat a `None` listing as "no data", as is done for histograms and
        # images below.
        if scalar_mapping is None:
            scalar_mapping = {}
        scalar_mapping = self._filter_by_version(
            scalar_mapping,
            scalar_metadata.parse_plugin_metadata,
            self._scalar_version_checker,
        )

        histogram_mapping = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
        )
        if histogram_mapping is None:
            histogram_mapping = {}
        histogram_mapping = self._filter_by_version(
            histogram_mapping,
            histogram_metadata.parse_plugin_metadata,
            self._histogram_version_checker,
        )

        image_mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
        )
        if image_mapping is None:
            image_mapping = {}
        image_mapping = self._filter_by_version(
            image_mapping,
            image_metadata.parse_plugin_metadata,
            self._image_version_checker,
        )

        result = {}
        result["scalars"] = _format_basic_mapping(scalar_mapping)
        result["histograms"] = _format_basic_mapping(histogram_mapping)
        result["images"] = _format_image_mapping(image_mapping)
        return result

    def _filter_by_version(self, mapping, parse_metadata, version_checker):
        """Filter `DataProvider.list_*` output by summary metadata version.

        Every run of `mapping` is kept (possibly with no tags); a tag is kept
        only when `version_checker.ok` accepts its parsed metadata version.
        """
        result = {run: {} for run in mapping}
        for (run, tag_to_content) in mapping.items():
            for (tag, metadatum) in tag_to_content.items():
                md = parse_metadata(metadatum.plugin_content)
                if not version_checker.ok(md.version, run, tag):
                    continue
                result[run][tag] = metadatum
        return result

    @wrappers.Request.application
    def _serve_time_series(self, request):
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        if request.method == "POST":
            series_requests_string = request.form.get("requests")
        else:
            series_requests_string = request.args.get("requests")
        if not series_requests_string:
            raise errors.InvalidArgumentError("Missing 'requests' field")
        try:
            series_requests = json.loads(series_requests_string)
        except ValueError:
            raise errors.InvalidArgumentError(
                "Unable to parse 'requests' as JSON"
            )
        # `_time_series_impl` iterates the requests and calls `.get` on each
        # one; reject other JSON shapes as a client error instead of letting
        # an AttributeError/TypeError escape.
        if not isinstance(series_requests, list) or not all(
            isinstance(series_request, dict)
            for series_request in series_requests
        ):
            raise errors.InvalidArgumentError(
                "'requests' must be a JSON array of objects"
            )

        response = self._time_series_impl(ctx, experiment, series_requests)
        return http_util.Respond(request, response, "application/json")

    def _time_series_impl(self, ctx, experiment, series_requests):
        """Constructs a list of responses from a list of series requests.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: string ID of the request's experiment.
          series_requests: a list of `TimeSeriesRequest` dicts (see http_api.md).

        Returns:
          A list of `TimeSeriesResponse` dicts (see http_api.md).
        """
        responses = [
            self._get_time_series(ctx, experiment, request)
            for request in series_requests
        ]
        return responses

    def _create_base_response(self, series_request):
        """Echoes the identifying fields of a request back into a response.

        `plugin` and `tag` are always copied; `run` only when it is a string
        and `sample` only when it is an int.
        """
        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")
        sample = series_request.get("sample")
        response = {"plugin": plugin, "tag": tag}
        if isinstance(run, str):
            response["run"] = run
        if isinstance(sample, int):
            response["sample"] = sample

        return response

    def _get_invalid_request_error(self, series_request):
        """Returns an error message for an invalid request, else None."""
        tag = series_request.get("tag")
        plugin = series_request.get("plugin")
        run = series_request.get("run")
        sample = series_request.get("sample")

        if not isinstance(tag, str):
            return "Missing tag"

        if (
            plugin != scalar_metadata.PLUGIN_NAME
            and plugin != histogram_metadata.PLUGIN_NAME
            and plugin != image_metadata.PLUGIN_NAME
        ):
            return "Invalid plugin"

        if plugin in _SINGLE_RUN_PLUGINS and not isinstance(run, str):
            return "Missing run"

        if plugin in _SAMPLED_PLUGINS and not isinstance(sample, int):
            return "Missing sample"

        return None

    def _get_time_series(self, ctx, experiment, series_request):
        """Returns time series data for a given tag, plugin.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: string ID of the request's experiment.
          series_request: a `TimeSeriesRequest` (see http_api.md).

        Returns:
          A `TimeSeriesResponse` dict (see http_api.md). Invalid requests get
          an "error" field instead of "runToSeries".
        """
        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")
        sample = series_request.get("sample")
        response = self._create_base_response(series_request)
        request_error = self._get_invalid_request_error(series_request)
        if request_error:
            response["error"] = request_error
            return response

        # No (or empty) run means "all runs" for the data provider filter.
        runs = [run] if run else None
        run_to_series = None
        if plugin == scalar_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_scalar_series(
                ctx, experiment, tag, runs
            )
        elif plugin == histogram_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_histogram_series(
                ctx, experiment, tag, runs
            )
        elif plugin == image_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_image_series(
                ctx, experiment, tag, sample, runs
            )

        response["runToSeries"] = run_to_series
        return response

    def _get_run_to_scalar_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-scalar-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          runs: optional list of run names as strings.

        Returns:
          A map from string run names to `ScalarStepDatum` (see http_api.md).
        """
        mapping = self._data_provider.read_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["scalars"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for (result_run, tag_data) in mapping.items():
            if tag not in tag_data:
                continue
            values = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "value": datum.value,
                }
                for datum in tag_data[tag]
            ]
            run_to_series[result_run] = values

        return run_to_series

    def _format_histogram_datum_bins(self, datum):
        """Formats a histogram datum's bins for client consumption.

        Args:
          datum: a DataProvider's TensorDatum.

        Returns:
          A list of `HistogramBin`s (see http_api.md).
        """
        # Each row of the tensor is read as (min, max, count).
        numpy_list = datum.numpy.tolist()
        bins = [{"min": x[0], "max": x[1], "count": x[2]} for x in numpy_list]
        return bins

    def _get_run_to_histogram_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-histogram-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          runs: optional list of run names as strings.

        Returns:
          A map from string run names to `HistogramStepDatum` (see http_api.md).
        """
        mapping = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["histograms"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for (result_run, tag_data) in mapping.items():
            if tag not in tag_data:
                continue
            values = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "bins": self._format_histogram_datum_bins(datum),
                }
                for datum in tag_data[tag]
            ]
            run_to_series[result_run] = values

        return run_to_series

    def _get_run_to_image_series(self, ctx, experiment, tag, sample, runs):
        """Builds a run-to-image-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          sample: zero-indexed integer for the requested sample.
          runs: optional list of run names as strings.

        Returns:
          A `RunToSeries` dict (see http_api.md). Runs with no image at the
          requested sample are omitted.
        """
        mapping = self._data_provider.read_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["images"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for (result_run, tag_data) in mapping.items():
            if tag not in tag_data:
                continue
            blob_sequence_datum_list = tag_data[tag]
            series = _format_image_blob_sequence_datum(
                blob_sequence_datum_list, sample
            )
            if series:
                run_to_series[result_run] = series

        return run_to_series

    @wrappers.Request.application
    def _serve_image_data(self, request):
        """Serves an individual image."""
        ctx = plugin_util.context(request.environ)
        # `.get` so a missing parameter reaches the explicit
        # InvalidArgumentError below instead of raising a KeyError.
        blob_key = request.args.get("imageId")
        if not blob_key:
            raise errors.InvalidArgumentError("Missing 'imageId' field")

        (data, content_type) = self._image_data_impl(ctx, blob_key)
        return http_util.Respond(request, data, content_type)

    def _image_data_impl(self, ctx, blob_key):
        """Gets the image data for a blob key.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          blob_key: a string identifier for a DataProvider blob.

        Returns:
          A tuple containing:
            data: a raw bytestring of the requested image's contents.
            content_type: a string HTTP content type.
        """
        data = self._data_provider.read_blob(ctx, blob_key=blob_key)
        # NOTE: `imghdr` is deprecated since Python 3.11 and removed in 3.13
        # (PEP 594); this sniffing will need a replacement there.
        image_type = imghdr.what(None, data)
        content_type = _IMGHDR_TO_MIMETYPE.get(
            image_type, _DEFAULT_IMAGE_MIMETYPE
        )
        return (data, content_type)
|