# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bridge from event multiplexer storage to generic data APIs."""
import base64
import collections
import json
import random
from tensorboard import errors
from tensorboard.compat.proto import summary_pb2
from tensorboard.data import provider
from tensorboard.util import tb_logging
from tensorboard.util import tensor_util
logger = tb_logging.get_logger()
class MultiplexerDataProvider(provider.DataProvider):
def __init__(self, multiplexer, logdir):
"""Trivial initializer.
Args:
multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note:
not a boring old `event_multiplexer.EventMultiplexer`).
logdir: The log directory from which data is being read. Only used
cosmetically. Should be a `str`.
"""
self._multiplexer = multiplexer
self._logdir = logdir
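    # Typical construction (a sketch only; the multiplexer setup calls are
    # assumed from `plugin_event_multiplexer.EventMultiplexer` and are not
    # defined in this module):
    #
    #   multiplexer = plugin_event_multiplexer.EventMultiplexer()
    #   multiplexer.AddRunsFromDirectory(logdir)
    #   multiplexer.Reload()
    #   data_provider = MultiplexerDataProvider(multiplexer, logdir)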
def __str__(self):
return "MultiplexerDataProvider(logdir=%r)" % self._logdir
def _validate_context(self, ctx):
if type(ctx).__name__ != "RequestContext":
raise TypeError("ctx must be a RequestContext; got: %r" % (ctx,))
def _validate_experiment_id(self, experiment_id):
# This data provider doesn't consume the experiment ID at all, but
# as a courtesy to callers we require that it be a valid string, to
# help catch usage errors.
if not isinstance(experiment_id, str):
raise TypeError(
"experiment_id must be %r, but got %r: %r"
% (str, type(experiment_id), experiment_id)
)
def _validate_downsample(self, downsample):
if downsample is None:
raise TypeError("`downsample` required but not given")
if isinstance(downsample, int):
return # OK
raise TypeError(
"`downsample` must be an int, but got %r: %r"
% (type(downsample), downsample)
)
def _test_run_tag(self, run_tag_filter, run, tag):
runs = run_tag_filter.runs
if runs is not None and run not in runs:
return False
tags = run_tag_filter.tags
if tags is not None and tag not in tags:
return False
return True
def _get_first_event_timestamp(self, run_name):
try:
return self._multiplexer.FirstEventTimestamp(run_name)
        except ValueError:
return None
def experiment_metadata(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
return provider.ExperimentMetadata(data_location=self._logdir)
def list_plugins(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
# Note: This result may include plugins that only have time
# series with `DATA_CLASS_UNKNOWN`, which will not actually be
        # accessible via `list_*` or `read_*`. This is inconsistent with
# the specification for `list_plugins`, but the bug should be
# mostly harmless.
return self._multiplexer.ActivePlugins()
def list_runs(self, ctx=None, *, experiment_id):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
return [
provider.Run(
run_id=run, # use names as IDs
run_name=run,
start_time=self._get_first_event_timestamp(run),
)
for run in self._multiplexer.Runs()
]
def list_scalars(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._list(provider.ScalarTimeSeries, index)
def read_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
return self._read(_convert_scalar_event, index, downsample)
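    # Example call shape (a sketch; the run/tag names and plugin name below are
    # illustrative):
    #
    #   data_provider.read_scalars(
    #       ctx,
    #       experiment_id="exp",
    #       plugin_name="scalars",
    #       run_tag_filter=provider.RunTagFilter(runs=["train"], tags=["loss"]),
    #       downsample=1000,
    #   )
    #
    # returns a mapping like `{"train": {"loss": [provider.ScalarDatum, ...]}}`.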
def read_last_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
)
run_tag_to_last_scalar_datum = collections.defaultdict(dict)
for (run, tags_for_run) in index.items():
for (tag, metadata) in tags_for_run.items():
events = self._multiplexer.Tensors(run, tag)
if events:
run_tag_to_last_scalar_datum[run][
tag
] = _convert_scalar_event(events[-1])
return run_tag_to_last_scalar_datum
def list_tensors(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._list(provider.TensorTimeSeries, index)
def read_tensors(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR
)
return self._read(_convert_tensor_event, index, downsample)
def _index(self, plugin_name, run_tag_filter, data_class_filter):
"""List time series and metadata matching the given filters.
This is like `_list`, but doesn't traverse `Tensors(...)` to
compute metadata that's not always needed.
Args:
plugin_name: A string plugin name filter (required).
          run_tag_filter: A `provider.RunTagFilter`, or `None`.
data_class_filter: A `summary_pb2.DataClass` filter (required).
Returns:
A nested dict `d` such that `d[run][tag]` is a
`SummaryMetadata` proto.
"""
if run_tag_filter is None:
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
runs = run_tag_filter.runs
tags = run_tag_filter.tags
# Optimization for a common case, reading a single time series.
if runs and len(runs) == 1 and tags and len(tags) == 1:
(run,) = runs
(tag,) = tags
try:
metadata = self._multiplexer.SummaryMetadata(run, tag)
except KeyError:
return {}
all_metadata = {run: {tag: metadata}}
else:
all_metadata = self._multiplexer.AllSummaryMetadata()
result = {}
for (run, tag_to_metadata) in all_metadata.items():
if runs is not None and run not in runs:
continue
result_for_run = {}
for (tag, metadata) in tag_to_metadata.items():
if tags is not None and tag not in tags:
continue
if metadata.data_class != data_class_filter:
continue
if metadata.plugin_data.plugin_name != plugin_name:
continue
result[run] = result_for_run
result_for_run[tag] = metadata
return result
def _list(self, construct_time_series, index):
"""Helper to list scalar or tensor time series.
Args:
construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
index: The result of `self._index(...)`.
Returns:
A list of objects of type given by `construct_time_series`,
suitable to be returned from `list_scalars` or `list_tensors`.
"""
result = {}
for (run, tag_to_metadata) in index.items():
result_for_run = {}
result[run] = result_for_run
for (tag, summary_metadata) in tag_to_metadata.items():
max_step = None
max_wall_time = None
for event in self._multiplexer.Tensors(run, tag):
if max_step is None or max_step < event.step:
max_step = event.step
if max_wall_time is None or max_wall_time < event.wall_time:
max_wall_time = event.wall_time
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
result_for_run[tag] = construct_time_series(
max_step=max_step,
max_wall_time=max_wall_time,
plugin_content=summary_metadata.plugin_data.content,
description=summary_metadata.summary_description,
display_name=summary_metadata.display_name,
)
return result
def _read(self, convert_event, index, downsample):
"""Helper to read scalar or tensor data from the multiplexer.
Args:
convert_event: Takes `plugin_event_accumulator.TensorEvent` to
either `provider.ScalarDatum` or `provider.TensorDatum`.
index: The result of `self._index(...)`.
downsample: Non-negative `int`; how many samples to return per
time series.
Returns:
A dict of dicts of values returned by `convert_event` calls,
suitable to be returned from `read_scalars` or `read_tensors`.
"""
result = {}
for (run, tags_for_run) in index.items():
result_for_run = {}
result[run] = result_for_run
for (tag, metadata) in tags_for_run.items():
events = self._multiplexer.Tensors(run, tag)
data = [convert_event(e) for e in events]
result_for_run[tag] = _downsample(data, downsample)
return result
def list_blob_sequences(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
for (run, tag_to_metadata) in index.items():
result_for_run = {}
result[run] = result_for_run
for (tag, metadata) in tag_to_metadata.items():
max_step = None
max_wall_time = None
max_length = None
for event in self._multiplexer.Tensors(run, tag):
if max_step is None or max_step < event.step:
max_step = event.step
if max_wall_time is None or max_wall_time < event.wall_time:
max_wall_time = event.wall_time
length = _tensor_size(event.tensor_proto)
if max_length is None or length > max_length:
max_length = length
result_for_run[tag] = provider.BlobSequenceTimeSeries(
max_step=max_step,
max_wall_time=max_wall_time,
max_length=max_length,
plugin_content=metadata.plugin_data.content,
description=metadata.summary_description,
display_name=metadata.display_name,
)
return result
def read_blob_sequences(
self,
ctx=None,
*,
experiment_id,
plugin_name,
downsample=None,
run_tag_filter=None,
):
self._validate_context(ctx)
self._validate_experiment_id(experiment_id)
self._validate_downsample(downsample)
index = self._index(
plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE
)
result = {}
for (run, tags) in index.items():
result_for_run = {}
result[run] = result_for_run
for tag in tags:
events = self._multiplexer.Tensors(run, tag)
data_by_step = {}
for event in events:
if event.step in data_by_step:
continue
data_by_step[event.step] = _convert_blob_sequence_event(
experiment_id, plugin_name, run, tag, event
)
data = [datum for (step, datum) in sorted(data_by_step.items())]
result_for_run[tag] = _downsample(data, downsample)
return result
def read_blob(self, ctx=None, *, blob_key):
self._validate_context(ctx)
(
unused_experiment_id,
plugin_name,
run,
tag,
step,
index,
) = _decode_blob_key(blob_key)
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
if summary_metadata.data_class != summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
raise errors.NotFoundError(blob_key)
tensor_events = self._multiplexer.Tensors(run, tag)
# In case of multiple events at this step, take first (arbitrary).
matching_step = next((e for e in tensor_events if e.step == step), None)
        if matching_step is None:
raise errors.NotFoundError("%s: no such step %r" % (blob_key, step))
tensor = tensor_util.make_ndarray(matching_step.tensor_proto)
return tensor[index]
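    # Note: the `blob_key` accepted here is the opaque string produced by
    # `_encode_blob_key` below; callers obtain it through the
    # `provider.BlobReference` values returned by `read_blob_sequences`.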
# TODO(davidsoergel): deduplicate with other implementations
def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index):
"""Generate a blob key: a short, URL-safe string identifying a blob.
A blob can be located using a set of integer and string fields; here we
serialize these to allow passing the data through a URL. Specifically, we
1) construct a tuple of the arguments in order; 2) represent that as an
ascii-encoded JSON string (without whitespace); and 3) take the URL-safe
base64 encoding of that, with no padding. For example:
1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0)
2) JSON: ["some_id","graphs","train","graph_def",2,0]
3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K
Args:
experiment_id: a string ID identifying an experiment.
plugin_name: string
run: string
tag: string
step: int
index: int
Returns:
A URL-safe base64-encoded string representing the provided arguments.
"""
# Encodes the blob key as a URL-safe string, as required by the
# `BlobReference` API in `tensorboard/data/provider.py`, because these keys
# may be used to construct URLs for retrieving blobs.
stringified = json.dumps(
(experiment_id, plugin_name, run, tag, step, index),
separators=(",", ":"),
)
bytesified = stringified.encode("ascii")
encoded = base64.urlsafe_b64encode(bytesified)
return encoded.decode("ascii").rstrip("=")
# Any changes to this function need not be backward-compatible, even though
# the current encoding was used to generate URLs. The reason is that the
# generated URLs are not considered permalinks: they need to be valid only
# within the context of the session that created them (via the matching
# `_encode_blob_key` function above).
def _decode_blob_key(key):
"""Decode a blob key produced by `_encode_blob_key` into component fields.
Args:
key: a blob key, as generated by `_encode_blob_key`.
Returns:
A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types
matching the arguments of `_encode_blob_key`.
"""
decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4.
stringified = decoded.decode("ascii")
(experiment_id, plugin_name, run, tag, step, index) = json.loads(
stringified
)
return (experiment_id, plugin_name, run, tag, step, index)
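# Round-trip sketch (hypothetical values): `_encode_blob_key` and
# `_decode_blob_key` are inverses, so
#   _decode_blob_key(_encode_blob_key("exp", "images", "train", "input", 7, 0))
# yields ("exp", "images", "train", "input", 7, 0) again.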
def _convert_scalar_event(event):
"""Helper for `read_scalars`."""
return provider.ScalarDatum(
step=event.step,
wall_time=event.wall_time,
value=tensor_util.make_ndarray(event.tensor_proto).item(),
)
def _convert_tensor_event(event):
"""Helper for `read_tensors`."""
return provider.TensorDatum(
step=event.step,
wall_time=event.wall_time,
numpy=tensor_util.make_ndarray(event.tensor_proto),
)
def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event):
"""Helper for `read_blob_sequences`."""
num_blobs = _tensor_size(event.tensor_proto)
values = tuple(
provider.BlobReference(
_encode_blob_key(
experiment_id,
plugin_name,
run,
tag,
event.step,
idx,
)
)
for idx in range(num_blobs)
)
return provider.BlobSequenceDatum(
wall_time=event.wall_time,
step=event.step,
values=values,
)
def _tensor_size(tensor_proto):
"""Compute the number of elements in a tensor.
This does not deserialize the full tensor contents.
Args:
tensor_proto: A `tensorboard.compat.proto.tensor_pb2.TensorProto`.
Returns:
A non-negative `int`.
"""
# This is the same logic that `tensor_util.make_ndarray` uses to
# compute the size, but without the actual buffer copies.
result = 1
for dim in tensor_proto.tensor_shape.dim:
result *= dim.size
return result
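# For example, a `tensor_proto` with shape [2, 3] yields 6; a rank-0 (scalar)
# tensor with an empty `tensor_shape` yields 1, the product over no dimensions.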
def _downsample(xs, k):
"""Downsample `xs` to at most `k` elements.
    If `k` is larger than `len(xs)`, then the contents of `xs` itself will be
    returned. If `k` is smaller than `len(xs)`, the last element of `xs` will
always be included (unless `k` is `0`) and the preceding elements
will be selected uniformly at random.
This differs from `random.sample` in that it returns a subsequence
(i.e., order is preserved) and that it permits `k > len(xs)`.
The random number generator will always be `random.Random(0)`, so
this function is deterministic (within a Python process).
Args:
xs: A sequence (`collections.abc.Sequence`).
k: A non-negative integer.
Returns:
A new list whose elements are a subsequence of `xs` of length
`min(k, len(xs))` and that is guaranteed to include the last
element of `xs`, uniformly selected among such subsequences.
"""
if k > len(xs):
return list(xs)
if k == 0:
return []
indices = random.Random(0).sample(range(len(xs) - 1), k - 1)
indices.sort()
indices += [len(xs) - 1]
return [xs[i] for i in indices]
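# Behavior sketch (the exact elements kept depend on the fixed seed):
#   _downsample(list(range(10)), 3)  -> three elements in original order,
#                                       always ending with 9, e.g. [3, 5, 9]
#   _downsample([1, 2], 5)           -> [1, 2] (k exceeds the length)
#   _downsample([1, 2, 3], 0)        -> []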