# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Bridge from event multiplexer storage to generic data APIs."""


import base64
import collections
import json
import random

from tensorboard import errors
from tensorboard.compat.proto import summary_pb2
from tensorboard.data import provider
from tensorboard.util import tb_logging
from tensorboard.util import tensor_util


logger = tb_logging.get_logger()
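

# A minimal usage sketch (hedged): construct the provider from an event
# multiplexer that has already loaded a logdir, then query it through the
# generic `DataProvider` API. The multiplexer setup, the `RequestContext`
# value `ctx`, and the plugin name "scalars" below are illustrative
# assumptions, not part of this module.
#
#     multiplexer = plugin_event_multiplexer.EventMultiplexer()
#     multiplexer.AddRunsFromDirectory(logdir)
#     multiplexer.Reload()
#     data_provider = MultiplexerDataProvider(multiplexer, logdir)
#     runs = data_provider.list_runs(ctx, experiment_id="")
#     scalars = data_provider.read_scalars(
#         ctx, experiment_id="", plugin_name="scalars", downsample=1000
#     )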


class MultiplexerDataProvider(provider.DataProvider):
    def __init__(self, multiplexer, logdir):
        """Trivial initializer.

        Args:
          multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note:
            not a boring old `event_multiplexer.EventMultiplexer`).
          logdir: The log directory from which data is being read. Only used
            cosmetically. Should be a `str`.
        """
        self._multiplexer = multiplexer
        self._logdir = logdir

    def __str__(self):
        return "MultiplexerDataProvider(logdir=%r)" % self._logdir

    def _validate_context(self, ctx):
        if type(ctx).__name__ != "RequestContext":
            raise TypeError("ctx must be a RequestContext; got: %r" % (ctx,))

    def _validate_experiment_id(self, experiment_id):
        # This data provider doesn't consume the experiment ID at all, but
        # as a courtesy to callers we require that it be a valid string, to
        # help catch usage errors.
        if not isinstance(experiment_id, str):
            raise TypeError(
                "experiment_id must be %r, but got %r: %r"
                % (str, type(experiment_id), experiment_id)
            )

    def _validate_downsample(self, downsample):
        if downsample is None:
            raise TypeError("`downsample` required but not given")
        if isinstance(downsample, int):
            return  # OK
        raise TypeError(
            "`downsample` must be an int, but got %r: %r"
            % (type(downsample), downsample)
        )

    def _test_run_tag(self, run_tag_filter, run, tag):
        runs = run_tag_filter.runs
        if runs is not None and run not in runs:
            return False
        tags = run_tag_filter.tags
        if tags is not None and tag not in tags:
            return False
        return True

    def _get_first_event_timestamp(self, run_name):
        try:
            return self._multiplexer.FirstEventTimestamp(run_name)
        except ValueError:
            return None

    def experiment_metadata(self, ctx=None, *, experiment_id):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        return provider.ExperimentMetadata(data_location=self._logdir)

    def list_plugins(self, ctx=None, *, experiment_id):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        # Note: This result may include plugins that only have time
        # series with `DATA_CLASS_UNKNOWN`, which will not actually be
        # accessible via `list_*` or `read_*`. This is inconsistent with
        # the specification for `list_plugins`, but the bug should be
        # mostly harmless.
        return self._multiplexer.ActivePlugins()

    def list_runs(self, ctx=None, *, experiment_id):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        return [
            provider.Run(
                run_id=run,  # use names as IDs
                run_name=run,
                start_time=self._get_first_event_timestamp(run),
            )
            for run in self._multiplexer.Runs()
        ]

    def list_scalars(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
        )
        return self._list(provider.ScalarTimeSeries, index)

    def read_scalars(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        self._validate_downsample(downsample)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
        )
        return self._read(_convert_scalar_event, index, downsample)

    def read_last_scalars(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
        )
        run_tag_to_last_scalar_datum = collections.defaultdict(dict)
        for (run, tags_for_run) in index.items():
            for (tag, metadata) in tags_for_run.items():
                events = self._multiplexer.Tensors(run, tag)
                if events:
                    run_tag_to_last_scalar_datum[run][
                        tag
                    ] = _convert_scalar_event(events[-1])

        return run_tag_to_last_scalar_datum

    def list_tensors(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
        )
        return self._list(provider.TensorTimeSeries, index)

    def read_tensors(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        self._validate_downsample(downsample)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
        )
        return self._read(_convert_tensor_event, index, downsample)

    def _index(self, plugin_name, run_tag_filter, data_class_filter):
        """List time series and metadata matching the given filters.

        This is like `_list`, but doesn't traverse `Tensors(...)` to
        compute metadata that's not always needed.

        Args:
          plugin_name: A string plugin name filter (required).
          run_tag_filter: A `provider.RunTagFilter`, or `None`.
          data_class_filter: A `summary_pb2.DataClass` filter (required).

        Returns:
          A nested dict `d` such that `d[run][tag]` is a
          `SummaryMetadata` proto.
        """
        if run_tag_filter is None:
            run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
        runs = run_tag_filter.runs
        tags = run_tag_filter.tags

        # Optimization for a common case, reading a single time series.
        if runs and len(runs) == 1 and tags and len(tags) == 1:
            (run,) = runs
            (tag,) = tags
            try:
                metadata = self._multiplexer.SummaryMetadata(run, tag)
            except KeyError:
                return {}
            all_metadata = {run: {tag: metadata}}
        else:
            all_metadata = self._multiplexer.AllSummaryMetadata()

        result = {}
        for (run, tag_to_metadata) in all_metadata.items():
            if runs is not None and run not in runs:
                continue
            result_for_run = {}
            for (tag, metadata) in tag_to_metadata.items():
                if tags is not None and tag not in tags:
                    continue
                if metadata.data_class != data_class_filter:
                    continue
                if metadata.plugin_data.plugin_name != plugin_name:
                    continue
                result[run] = result_for_run
                result_for_run[tag] = metadata

        return result
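
    # For illustration (hypothetical run and tag names), an `_index` result
    # for the scalars plugin might look like:
    #
    #     {
    #         "train": {"loss": <SummaryMetadata>, "accuracy": <SummaryMetadata>},
    #         "eval": {"loss": <SummaryMetadata>},
    #     }
    #
    # where each value is the `SummaryMetadata` proto for that time series.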

    def _list(self, construct_time_series, index):
        """Helper to list scalar or tensor time series.

        Args:
          construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
          index: The result of `self._index(...)`.

        Returns:
          A nested dict `d` such that `d[run][tag]` is an object of the type
          given by `construct_time_series`, suitable to be returned from
          `list_scalars` or `list_tensors`.
        """
        result = {}
        for (run, tag_to_metadata) in index.items():
            result_for_run = {}
            result[run] = result_for_run
            for (tag, summary_metadata) in tag_to_metadata.items():
                max_step = None
                max_wall_time = None
                for event in self._multiplexer.Tensors(run, tag):
                    if max_step is None or max_step < event.step:
                        max_step = event.step
                    if max_wall_time is None or max_wall_time < event.wall_time:
                        max_wall_time = event.wall_time
                summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
                result_for_run[tag] = construct_time_series(
                    max_step=max_step,
                    max_wall_time=max_wall_time,
                    plugin_content=summary_metadata.plugin_data.content,
                    description=summary_metadata.summary_description,
                    display_name=summary_metadata.display_name,
                )
        return result

    def _read(self, convert_event, index, downsample):
        """Helper to read scalar or tensor data from the multiplexer.

        Args:
          convert_event: Takes `plugin_event_accumulator.TensorEvent` to
            either `provider.ScalarDatum` or `provider.TensorDatum`.
          index: The result of `self._index(...)`.
          downsample: Non-negative `int`; how many samples to return per
            time series.

        Returns:
          A dict of dicts of values returned by `convert_event` calls,
          suitable to be returned from `read_scalars` or `read_tensors`.
        """
        result = {}
        for (run, tags_for_run) in index.items():
            result_for_run = {}
            result[run] = result_for_run
            for (tag, metadata) in tags_for_run.items():
                events = self._multiplexer.Tensors(run, tag)
                data = [convert_event(e) for e in events]
                result_for_run[tag] = _downsample(data, downsample)
        return result

    def list_blob_sequences(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
        )
        result = {}
        for (run, tag_to_metadata) in index.items():
            result_for_run = {}
            result[run] = result_for_run
            for (tag, metadata) in tag_to_metadata.items():
                max_step = None
                max_wall_time = None
                max_length = None
                for event in self._multiplexer.Tensors(run, tag):
                    if max_step is None or max_step < event.step:
                        max_step = event.step
                    if max_wall_time is None or max_wall_time < event.wall_time:
                        max_wall_time = event.wall_time
                    length = _tensor_size(event.tensor_proto)
                    if max_length is None or length > max_length:
                        max_length = length
                result_for_run[tag] = provider.BlobSequenceTimeSeries(
                    max_step=max_step,
                    max_wall_time=max_wall_time,
                    max_length=max_length,
                    plugin_content=metadata.plugin_data.content,
                    description=metadata.summary_description,
                    display_name=metadata.display_name,
                )
        return result

    def read_blob_sequences(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        self._validate_downsample(downsample)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
        )
        result = {}
        for (run, tags) in index.items():
            result_for_run = {}
            result[run] = result_for_run
            for tag in tags:
                events = self._multiplexer.Tensors(run, tag)
                data_by_step = {}
                for event in events:
                    if event.step in data_by_step:
                        continue
                    data_by_step[event.step] = _convert_blob_sequence_event(
                        experiment_id, plugin_name, run, tag, event
                    )
                data = [datum for (step, datum) in sorted(data_by_step.items())]
                result_for_run[tag] = _downsample(data, downsample)
        return result

    def read_blob(self, ctx=None, *, blob_key):
        self._validate_context(ctx)
        (
            unused_experiment_id,
            plugin_name,
            run,
            tag,
            step,
            index,
        ) = _decode_blob_key(blob_key)

        summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
        if summary_metadata.data_class != summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
            raise errors.NotFoundError(blob_key)
        tensor_events = self._multiplexer.Tensors(run, tag)
        # In case of multiple events at this step, take the first (arbitrary).
        matching_step = next((e for e in tensor_events if e.step == step), None)
        if matching_step is None:
            raise errors.NotFoundError("%s: no such step %r" % (blob_key, step))
        tensor = tensor_util.make_ndarray(matching_step.tensor_proto)
        return tensor[index]
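
    # A hedged end-to-end sketch (run, tag, and plugin names are
    # illustrative): blob keys are minted by `read_blob_sequences` via
    # `_encode_blob_key` and handed back here to fetch the raw bytes, e.g.:
    #
    #     seqs = data_provider.read_blob_sequences(
    #         ctx, experiment_id="", plugin_name="images", downsample=10
    #     )
    #     datum = seqs["train"]["input_images"][0]
    #     blob_bytes = data_provider.read_blob(
    #         ctx, blob_key=datum.values[0].blob_key
    #     )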


# TODO(davidsoergel): deduplicate with other implementations
def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
    """Generate a blob key: a short, URL-safe string identifying a blob.

    A blob can be located using a set of integer and string fields; here we
    serialize these to allow passing the data through a URL. Specifically, we
    1) construct a tuple of the arguments in order; 2) represent that as an
    ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
    base64 encoding of that, with no padding. For example:

        1)  Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
        2)   JSON: ["some_id","graphs","train","graph_def",2,0]
        3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0

    Args:
      experiment_id: a string ID identifying an experiment.
      plugin_name: string
      run: string
      tag: string
      step: int
      index: int

    Returns:
      A URL-safe base64-encoded string representing the provided arguments.
    """
    # Encodes the blob key as a URL-safe string, as required by the
    # `BlobReference` API in `tensorboard/data/provider.py`, because these
    # keys may be used to construct URLs for retrieving blobs.
    stringified = json.dumps(
        (experiment_id, plugin_name, run, tag, step, index),
        separators=(",", ":"),
    )
    bytesified = stringified.encode("ascii")
    encoded = base64.urlsafe_b64encode(bytesified)
    return encoded.decode("ascii").rstrip("=")


# Any changes to this function need not be backward-compatible, even though
# the current encoding was used to generate URLs. The reason is that the
# generated URLs are not considered permalinks: they need to be valid only
# within the context of the session that created them (via the matching
# `_encode_blob_key` function above).
def _decode_blob_key(key):
    """Decode a blob key produced by `_encode_blob_key` into component fields.

    Args:
      key: a blob key, as generated by `_encode_blob_key`.

    Returns:
      A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with
      types matching the arguments of `_encode_blob_key`.
    """
    decoded = base64.urlsafe_b64decode(key + "==")  # pad past a multiple of 4
    stringified = decoded.decode("ascii")
    (experiment_id, plugin_name, run, tag, step, index) = json.loads(
        stringified
    )
    return (experiment_id, plugin_name, run, tag, step, index)
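

# A round-trip sketch using the values from the `_encode_blob_key` docstring
# example above (illustrative only):
#
#     key = _encode_blob_key("some_id", "graphs", "train", "graph_def", 2, 0)
#     _decode_blob_key(key)
#     # => ("some_id", "graphs", "train", "graph_def", 2, 0)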


def _convert_scalar_event(event):
    """Helper for `read_scalars`."""
    return provider.ScalarDatum(
        step=event.step,
        wall_time=event.wall_time,
        value=tensor_util.make_ndarray(event.tensor_proto).item(),
    )


def _convert_tensor_event(event):
    """Helper for `read_tensors`."""
    return provider.TensorDatum(
        step=event.step,
        wall_time=event.wall_time,
        numpy=tensor_util.make_ndarray(event.tensor_proto),
    )


def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event):
    """Helper for `read_blob_sequences`."""
    num_blobs = _tensor_size(event.tensor_proto)
    values = tuple(
        provider.BlobReference(
            _encode_blob_key(
                experiment_id,
                plugin_name,
                run,
                tag,
                event.step,
                idx,
            )
        )
        for idx in range(num_blobs)
    )
    return provider.BlobSequenceDatum(
        wall_time=event.wall_time,
        step=event.step,
        values=values,
    )


def _tensor_size(tensor_proto):
    """Compute the number of elements in a tensor.

    This does not deserialize the full tensor contents.

    Args:
      tensor_proto: A `tensorboard.compat.proto.tensor_pb2.TensorProto`.

    Returns:
      A non-negative `int`.
    """
    # This is the same logic that `tensor_util.make_ndarray` uses to
    # compute the size, but without the actual buffer copies.
    result = 1
    for dim in tensor_proto.tensor_shape.dim:
        result *= dim.size
    return result
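

# For example (hedged illustration), a `TensorProto` whose `tensor_shape` is
# `[2, 3, 4]` has a `_tensor_size` of 24, and a rank-0 (scalar) tensor yields
# 1, since the empty product is 1.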


def _downsample(xs, k):
    """Downsample `xs` to at most `k` elements.

    If `k` is larger than `len(xs)`, then the contents of `xs` itself will
    be returned. If `k` is smaller than `len(xs)`, the last element of `xs`
    will always be included (unless `k` is `0`) and the preceding elements
    will be selected uniformly at random.

    This differs from `random.sample` in that it returns a subsequence
    (i.e., order is preserved) and that it permits `k > len(xs)`.

    The random number generator will always be `random.Random(0)`, so
    this function is deterministic (within a Python process).

    Args:
      xs: A sequence (`collections.abc.Sequence`).
      k: A non-negative integer.

    Returns:
      A new list whose elements are a subsequence of `xs` of length
      `min(k, len(xs))` and that is guaranteed to include the last
      element of `xs`, uniformly selected among such subsequences.
    """
    if k > len(xs):
        return list(xs)
    if k == 0:
        return []
    indices = random.Random(0).sample(range(len(xs) - 1), k - 1)
    indices.sort()
    indices += [len(xs) - 1]
    return [xs[i] for i in indices]
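

# A hedged illustration of the `_downsample` contract (outputs assume the
# fixed `random.Random(0)` seed; the exact sampled indices are not
# load-bearing):
#
#     _downsample(list(range(10)), 100)  # all ten elements, in order
#     _downsample(list(range(10)), 0)    # []
#     out = _downsample(list(range(10)), 4)
#     # len(out) == 4; `out` is a sorted subsequence; out[-1] == 9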