# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Bridge from event multiplexer storage to generic data APIs.""" import base64 import json import random from tensorboard import errors from tensorboard.compat.proto import summary_pb2 from tensorboard.data import provider from tensorboard.util import tb_logging from tensorboard.util import tensor_util logger = tb_logging.get_logger() class MultiplexerDataProvider(provider.DataProvider): def __init__(self, multiplexer, logdir): """Trivial initializer. Args: multiplexer: A `plugin_event_multiplexer.EventMultiplexer` (note: not a boring old `event_multiplexer.EventMultiplexer`). logdir: The log directory from which data is being read. Only used cosmetically. Should be a `str`. """ self._multiplexer = multiplexer self._logdir = logdir def __str__(self): return "MultiplexerDataProvider(logdir=%r)" % self._logdir def _validate_context(self, ctx): if type(ctx).__name__ != "RequestContext": raise TypeError("ctx must be a RequestContext; got: %r" % (ctx,)) def _validate_experiment_id(self, experiment_id): # This data provider doesn't consume the experiment ID at all, but # as a courtesy to callers we require that it be a valid string, to # help catch usage errors. if not isinstance(experiment_id, str): raise TypeError( "experiment_id must be %r, but got %r: %r" % (str, type(experiment_id), experiment_id) ) def _validate_downsample(self, downsample): if downsample is None: raise TypeError("`downsample` required but not given") if isinstance(downsample, int): return # OK raise TypeError( "`downsample` must be an int, but got %r: %r" % (type(downsample), downsample) ) def _test_run_tag(self, run_tag_filter, run, tag): runs = run_tag_filter.runs if runs is not None and run not in runs: return False tags = run_tag_filter.tags if tags is not None and tag not in tags: return False return True def _get_first_event_timestamp(self, run_name): try: return self._multiplexer.FirstEventTimestamp(run_name) except ValueError as e: return None def experiment_metadata(self, ctx=None, *, experiment_id): self._validate_context(ctx) self._validate_experiment_id(experiment_id) return provider.ExperimentMetadata(data_location=self._logdir) def list_plugins(self, ctx=None, *, experiment_id): self._validate_context(ctx) self._validate_experiment_id(experiment_id) # Note: This result may include plugins that only have time # series with `DATA_CLASS_UNKNOWN`, which will not actually be # accessible via `list_*` or read_*`. This is inconsistent with # the specification for `list_plugins`, but the bug should be # mostly harmless. return self._multiplexer.ActivePlugins() def list_runs(self, ctx=None, *, experiment_id): self._validate_context(ctx) self._validate_experiment_id(experiment_id) return [ provider.Run( run_id=run, # use names as IDs run_name=run, start_time=self._get_first_event_timestamp(run), ) for run in self._multiplexer.Runs() ] def list_scalars( self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR ) return self._list(provider.ScalarTimeSeries, index) def read_scalars( self, ctx=None, *, experiment_id, plugin_name, downsample=None, run_tag_filter=None, ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR ) return self._read(_convert_scalar_event, index, downsample) def list_tensors( self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR ) return self._list(provider.TensorTimeSeries, index) def read_tensors( self, ctx=None, *, experiment_id, plugin_name, downsample=None, run_tag_filter=None, ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_TENSOR ) return self._read(_convert_tensor_event, index, downsample) def _index(self, plugin_name, run_tag_filter, data_class_filter): """List time series and metadata matching the given filters. This is like `_list`, but doesn't traverse `Tensors(...)` to compute metadata that's not always needed. Args: plugin_name: A string plugin name filter (required). run_tag_filter: An `provider.RunTagFilter`, or `None`. data_class_filter: A `summary_pb2.DataClass` filter (required). Returns: A nested dict `d` such that `d[run][tag]` is a `SummaryMetadata` proto. """ if run_tag_filter is None: run_tag_filter = provider.RunTagFilter(runs=None, tags=None) runs = run_tag_filter.runs tags = run_tag_filter.tags # Optimization for a common case, reading a single time series. if runs and len(runs) == 1 and tags and len(tags) == 1: (run,) = runs (tag,) = tags try: metadata = self._multiplexer.SummaryMetadata(run, tag) except KeyError: return {} all_metadata = {run: {tag: metadata}} else: all_metadata = self._multiplexer.AllSummaryMetadata() result = {} for (run, tag_to_metadata) in all_metadata.items(): if runs is not None and run not in runs: continue result_for_run = {} for (tag, metadata) in tag_to_metadata.items(): if tags is not None and tag not in tags: continue if metadata.data_class != data_class_filter: continue if metadata.plugin_data.plugin_name != plugin_name: continue result[run] = result_for_run result_for_run[tag] = metadata return result def _list(self, construct_time_series, index): """Helper to list scalar or tensor time series. Args: construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`. index: The result of `self._index(...)`. Returns: A list of objects of type given by `construct_time_series`, suitable to be returned from `list_scalars` or `list_tensors`. """ result = {} for (run, tag_to_metadata) in index.items(): result_for_run = {} result[run] = result_for_run for (tag, summary_metadata) in tag_to_metadata.items(): max_step = None max_wall_time = None for event in self._multiplexer.Tensors(run, tag): if max_step is None or max_step < event.step: max_step = event.step if max_wall_time is None or max_wall_time < event.wall_time: max_wall_time = event.wall_time summary_metadata = self._multiplexer.SummaryMetadata(run, tag) result_for_run[tag] = construct_time_series( max_step=max_step, max_wall_time=max_wall_time, plugin_content=summary_metadata.plugin_data.content, description=summary_metadata.summary_description, display_name=summary_metadata.display_name, ) return result def _read(self, convert_event, index, downsample): """Helper to read scalar or tensor data from the multiplexer. Args: convert_event: Takes `plugin_event_accumulator.TensorEvent` to either `provider.ScalarDatum` or `provider.TensorDatum`. index: The result of `self._index(...)`. downsample: Non-negative `int`; how many samples to return per time series. Returns: A dict of dicts of values returned by `convert_event` calls, suitable to be returned from `read_scalars` or `read_tensors`. """ result = {} for (run, tags_for_run) in index.items(): result_for_run = {} result[run] = result_for_run for (tag, metadata) in tags_for_run.items(): events = self._multiplexer.Tensors(run, tag) data = [convert_event(e) for e in events] result_for_run[tag] = _downsample(data, downsample) return result def list_blob_sequences( self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE ) result = {} for (run, tag_to_metadata) in index.items(): result_for_run = {} result[run] = result_for_run for (tag, metadata) in tag_to_metadata.items(): max_step = None max_wall_time = None max_length = None for event in self._multiplexer.Tensors(run, tag): if max_step is None or max_step < event.step: max_step = event.step if max_wall_time is None or max_wall_time < event.wall_time: max_wall_time = event.wall_time length = _tensor_size(event.tensor_proto) if max_length is None or length > max_length: max_length = length result_for_run[tag] = provider.BlobSequenceTimeSeries( max_step=max_step, max_wall_time=max_wall_time, max_length=max_length, plugin_content=metadata.plugin_data.content, description=metadata.summary_description, display_name=metadata.display_name, ) return result def read_blob_sequences( self, ctx=None, *, experiment_id, plugin_name, downsample=None, run_tag_filter=None, ): self._validate_context(ctx) self._validate_experiment_id(experiment_id) self._validate_downsample(downsample) index = self._index( plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_BLOB_SEQUENCE ) result = {} for (run, tags) in index.items(): result_for_run = {} result[run] = result_for_run for tag in tags: events = self._multiplexer.Tensors(run, tag) data_by_step = {} for event in events: if event.step in data_by_step: continue data_by_step[event.step] = _convert_blob_sequence_event( experiment_id, plugin_name, run, tag, event ) data = [datum for (step, datum) in sorted(data_by_step.items())] result_for_run[tag] = _downsample(data, downsample) return result def read_blob(self, ctx=None, *, blob_key): self._validate_context(ctx) ( unused_experiment_id, plugin_name, run, tag, step, index, ) = _decode_blob_key(blob_key) summary_metadata = self._multiplexer.SummaryMetadata(run, tag) if summary_metadata.data_class != summary_pb2.DATA_CLASS_BLOB_SEQUENCE: raise errors.NotFoundError(blob_key) tensor_events = self._multiplexer.Tensors(run, tag) # In case of multiple events at this step, take first (arbitrary). matching_step = next((e for e in tensor_events if e.step == step), None) if not matching_step: raise errors.NotFoundError("%s: no such step %r" % (blob_key, step)) tensor = tensor_util.make_ndarray(matching_step.tensor_proto) return tensor[index] # TODO(davidsoergel): deduplicate with other implementations def _encode_blob_key(experiment_id, plugin_name, run, tag, step, index): """Generate a blob key: a short, URL-safe string identifying a blob. A blob can be located using a set of integer and string fields; here we serialize these to allow passing the data through a URL. Specifically, we 1) construct a tuple of the arguments in order; 2) represent that as an ascii-encoded JSON string (without whitespace); and 3) take the URL-safe base64 encoding of that, with no padding. For example: 1) Tuple: ("some_id", "graphs", "train", "graph_def", 2, 0) 2) JSON: ["some_id","graphs","train","graph_def",2,0] 3) base64: WyJzb21lX2lkIiwiZ3JhcGhzIiwidHJhaW4iLCJncmFwaF9kZWYiLDIsMF0K Args: experiment_id: a string ID identifying an experiment. plugin_name: string run: string tag: string step: int index: int Returns: A URL-safe base64-encoded string representing the provided arguments. """ # Encodes the blob key as a URL-safe string, as required by the # `BlobReference` API in `tensorboard/data/provider.py`, because these keys # may be used to construct URLs for retrieving blobs. stringified = json.dumps( (experiment_id, plugin_name, run, tag, step, index), separators=(",", ":"), ) bytesified = stringified.encode("ascii") encoded = base64.urlsafe_b64encode(bytesified) return encoded.decode("ascii").rstrip("=") # Any changes to this function need not be backward-compatible, even though # the current encoding was used to generate URLs. The reason is that the # generated URLs are not considered permalinks: they need to be valid only # within the context of the session that created them (via the matching # `_encode_blob_key` function above). def _decode_blob_key(key): """Decode a blob key produced by `_encode_blob_key` into component fields. Args: key: a blob key, as generated by `_encode_blob_key`. Returns: A tuple of `(experiment_id, plugin_name, run, tag, step, index)`, with types matching the arguments of `_encode_blob_key`. """ decoded = base64.urlsafe_b64decode(key + "==") # pad past a multiple of 4. stringified = decoded.decode("ascii") (experiment_id, plugin_name, run, tag, step, index) = json.loads( stringified ) return (experiment_id, plugin_name, run, tag, step, index) def _convert_scalar_event(event): """Helper for `read_scalars`.""" return provider.ScalarDatum( step=event.step, wall_time=event.wall_time, value=tensor_util.make_ndarray(event.tensor_proto).item(), ) def _convert_tensor_event(event): """Helper for `read_tensors`.""" return provider.TensorDatum( step=event.step, wall_time=event.wall_time, numpy=tensor_util.make_ndarray(event.tensor_proto), ) def _convert_blob_sequence_event(experiment_id, plugin_name, run, tag, event): """Helper for `read_blob_sequences`.""" num_blobs = _tensor_size(event.tensor_proto) values = tuple( provider.BlobReference( _encode_blob_key( experiment_id, plugin_name, run, tag, event.step, idx, ) ) for idx in range(num_blobs) ) return provider.BlobSequenceDatum( wall_time=event.wall_time, step=event.step, values=values, ) def _tensor_size(tensor_proto): """Compute the number of elements in a tensor. This does not deserialize the full tensor contents. Args: tensor_proto: A `tensorboard.compat.proto.tensor_pb2.TensorProto`. Returns: A non-negative `int`. """ # This is the same logic that `tensor_util.make_ndarray` uses to # compute the size, but without the actual buffer copies. result = 1 for dim in tensor_proto.tensor_shape.dim: result *= dim.size return result def _downsample(xs, k): """Downsample `xs` to at most `k` elements. If `k` is larger than `xs`, then the contents of `xs` itself will be returned. If `k` is smaller than `xs`, the last element of `xs` will always be included (unless `k` is `0`) and the preceding elements will be selected uniformly at random. This differs from `random.sample` in that it returns a subsequence (i.e., order is preserved) and that it permits `k > len(xs)`. The random number generator will always be `random.Random(0)`, so this function is deterministic (within a Python process). Args: xs: A sequence (`collections.abc.Sequence`). k: A non-negative integer. Returns: A new list whose elements are a subsequence of `xs` of length `min(k, len(xs))` and that is guaranteed to include the last element of `xs`, uniformly selected among such subsequences. """ if k > len(xs): return list(xs) if k == 0: return [] indices = random.Random(0).sample(range(len(xs) - 1), k - 1) indices.sort() indices += [len(xs) - 1] return [xs[i] for i in indices]