# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Graphs plugin."""

import json

from werkzeug import wrappers

from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.backend import process_graph
from tensorboard.compat.proto import config_pb2
from tensorboard.compat.proto import graph_pb2
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.graph import graph_util
from tensorboard.plugins.graph import keras_util
from tensorboard.plugins.graph import metadata
from tensorboard.util import tb_logging

logger = tb_logging.get_logger()


class GraphsPlugin(base_plugin.TBPlugin):
    """Graphs Plugin for TensorBoard."""

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates GraphsPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        self._data_provider = context.data_provider

    def get_plugin_apps(self):
        """Returns the mapping from route path to request handler."""
        return {
            "/graph": self.graph_route,
            "/info": self.info_route,
            "/run_metadata": self.run_metadata_route,
        }

    def is_active(self):
        """Always returns False.

        Activity is not computed here: `list_plugins` as called by TB
        core suffices, based on `data_plugin_names`.
        """
        return False  # `list_plugins` as called by TB core suffices

    def data_plugin_names(self):
        """Returns the data plugin names whose data this plugin reads."""
        return (
            metadata.PLUGIN_NAME,
            metadata.PLUGIN_NAME_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            metadata.PLUGIN_NAME_KERAS_MODEL,
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )

    def frontend_metadata(self):
        """Returns the frontend metadata for the graph dashboard."""
        return base_plugin.FrontendMetadata(
            element_name="tf-graph-dashboard",
            # TODO(@chihuahua): Reconcile this setting with Health Pills.
            disable_reload=True,
        )

    def info_impl(self, ctx, experiment=None):
        """Returns a dict of all runs and their data availabilities.

        Args:
          ctx: Request context, forwarded to the data provider.
          experiment: Experiment ID, forwarded to the data provider as
            `experiment_id`.

        Returns:
          A dict keyed by run name. Each value is a dict with keys
          `run`, `tags` and `run_graph`; `tags` is keyed by tag name
          and each of its values is a dict with keys `tag`,
          `conceptual_graph`, `op_graph` and `profile`.
        """
        result = {}

        def add_row_item(run, tag=None):
            # Fetch (creating on first sight) the entry for `run` and,
            # when a tag is given, the entry for `tag` within that run.
            run_item = result.setdefault(
                run,
                {
                    "run": run,
                    "tags": {},
                    # A run-wide GraphDef of ops.
                    "run_graph": False,
                },
            )

            tag_item = None
            if tag:
                tag_item = run_item.get("tags").setdefault(
                    tag,
                    {
                        "tag": tag,
                        "conceptual_graph": False,
                        # A tagged GraphDef of ops.
                        "op_graph": False,
                        "profile": False,
                    },
                )
            return (run_item, tag_item)

        # Sources whose plugin content is a bare version marker. Each
        # entry is (plugin name, kind of data for log messages, flags to
        # set on the tag item). The order is significant: it determines
        # the insertion order of runs and tags in `result`.
        versioned_sources = (
            (
                metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
                "RunMetadata",
                ("op_graph",),
            ),
            # Tensors associated with plugin name
            # metadata.PLUGIN_NAME_RUN_METADATA contain both op graph
            # and profile information.
            (
                metadata.PLUGIN_NAME_RUN_METADATA,
                "RunMetadata",
                ("profile", "op_graph"),
            ),
            # Tensors associated with plugin name
            # metadata.PLUGIN_NAME_KERAS_MODEL contain serialized Keras
            # model in JSON format.
            (
                metadata.PLUGIN_NAME_KERAS_MODEL,
                "Keras model",
                ("conceptual_graph",),
            ),
        )
        for plugin_name, description, flags in versioned_sources:
            mapping = self._data_provider.list_blob_sequences(
                ctx,
                experiment_id=experiment,
                plugin_name=plugin_name,
            )
            for run_name, tags in mapping.items():
                for tag, tag_data in tags.items():
                    # The Summary op is defined in TensorFlow and does not use a stringified proto
                    # as a content of plugin data. It contains single string that denotes a version.
                    # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790
                    if tag_data.plugin_content != b"1":
                        logger.warning(
                            "Ignoring unrecognizable version of %s.",
                            description,
                        )
                        continue
                    _, tag_item = add_row_item(run_name, tag)
                    for flag in flags:
                        tag_item[flag] = True

        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        for run_name, tags in mapping.items():
            if metadata.RUN_GRAPH_NAME in tags:
                run_item, _ = add_row_item(run_name, None)
                run_item["run_graph"] = True

        # Top level `Event.tagged_run_metadata` represents profile data only.
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )
        for run_name, tags in mapping.items():
            for tag in tags:
                _, tag_item = add_row_item(run_name, tag)
                tag_item["profile"] = True

        return result

    def _read_blob(self, ctx, experiment, plugin_names, run, tag):
        """Reads the first blob found for `(run, tag)` among some plugins.

        Args:
          ctx: Request context, forwarded to the data provider.
          experiment: Experiment ID, forwarded to the data provider as
            `experiment_id`.
          plugin_names: Plugin names to try, in order of preference.
          run: Run name to look up.
          tag: Tag name to look up.

        Returns:
          The result of the data provider's `read_blob` for the first
          value of the first datum found.

        Raises:
          errors.NotFoundError: If none of the plugins has data for the
            given run and tag.
        """
        for plugin_name in plugin_names:
            blob_sequences = self._data_provider.read_blob_sequences(
                ctx,
                experiment_id=experiment,
                plugin_name=plugin_name,
                run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
                downsample=1,
            )
            blob_sequence_data = blob_sequences.get(run, {}).get(tag, ())
            try:
                blob_ref = blob_sequence_data[0].values[0]
            except IndexError:
                # No datum (or no value) under this plugin; try the next.
                continue
            return self._data_provider.read_blob(
                ctx, blob_key=blob_ref.blob_key
            )
        raise errors.NotFoundError()

    def graph_impl(
        self,
        ctx,
        run,
        tag,
        is_conceptual,
        experiment=None,
        limit_attr_size=None,
        large_attrs_key=None,
    ):
        """Result of the form `(body, mime_type)`; may raise `NotFound`.

        Args:
          ctx: Request context, forwarded to the data provider.
          run: Run name whose graph is requested.
          tag: Tag name, or `None` to request the run-level graph.
          is_conceptual: If true, build the graph from the Keras model
            config stored under `tag` rather than from an op graph.
          experiment: Experiment ID, forwarded to the data provider.
          limit_attr_size: Forwarded to
            `process_graph.prepare_graph_for_ui`.
          large_attrs_key: Forwarded to
            `process_graph.prepare_graph_for_ui`.

        Returns:
          A tuple `(body, mime_type)` where `body` is the text form of
          the graph and `mime_type` is `"text/x-protobuf"`.

        Raises:
          errors.NotFoundError: If no matching data exists.
          ValueError: If the limit parameters are invalid.
        """
        if is_conceptual:
            keras_model_config = json.loads(
                self._read_blob(
                    ctx,
                    experiment,
                    [metadata.PLUGIN_NAME_KERAS_MODEL],
                    run,
                    tag,
                )
            )
            graph = keras_util.keras_model_to_graph_def(keras_model_config)

        elif tag is None:
            graph_raw = self._read_blob(
                ctx,
                experiment,
                [metadata.PLUGIN_NAME],
                run,
                metadata.RUN_GRAPH_NAME,
            )
            graph = graph_pb2.GraphDef.FromString(graph_raw)

        else:
            # Op graph: could be either of two plugins. (Cf. `info_impl`.)
            plugins = [
                metadata.PLUGIN_NAME_RUN_METADATA,
                metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            ]
            raw_run_metadata = self._read_blob(
                ctx, experiment, plugins, run, tag
            )
            run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
            graph = graph_util.merge_graph_defs(
                [
                    func_graph.pre_optimization_graph
                    for func_graph in run_metadata.function_graphs
                ]
            )

        # This next line might raise a ValueError if the limit parameters
        # are invalid (size is negative, size present but key absent, etc.).
        process_graph.prepare_graph_for_ui(
            graph, limit_attr_size, large_attrs_key
        )
        return (str(graph), "text/x-protobuf")  # pbtxt

    def run_metadata_impl(self, ctx, experiment, run, tag):
        """Result of the form `(body, mime_type)`; may raise `NotFound`.

        Args:
          ctx: Request context, forwarded to the data provider.
          experiment: Experiment ID, forwarded to the data provider.
          run: Run name to look up.
          tag: Tag name to look up.

        Returns:
          A tuple `(body, mime_type)` where `body` is the text form of
          the `RunMetadata` proto and `mime_type` is
          `"text/x-protobuf"`.

        Raises:
          errors.NotFoundError: If no matching data exists.
        """
        # Profile graph: could be either of two plugins. (Cf. `info_impl`.)
        plugins = [
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA,
        ]
        raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag)
        run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
        return (str(run_metadata), "text/x-protobuf")  # pbtxt

    @wrappers.Request.application
    def info_route(self, request):
        """Responds with the result of `info_impl` as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        info = self.info_impl(ctx, experiment)
        return http_util.Respond(request, info, "application/json")

    @wrappers.Request.application
    def graph_route(self, request):
        """Given a single run, return the graph definition in protobuf format.

        Query parameters read: `run` (required), `tag`, `conceptual`
        (treated as true only when equal to the string `"true"`),
        `limit_attr_size` (must parse as an integer) and
        `large_attrs_key`. Invalid parameters yield a 400 response.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        run = request.args.get("run")
        tag = request.args.get("tag")
        conceptual_arg = request.args.get("conceptual", False)
        is_conceptual = conceptual_arg == "true"

        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )

        limit_attr_size = request.args.get("limit_attr_size", None)
        if limit_attr_size is not None:
            try:
                limit_attr_size = int(limit_attr_size)
            except ValueError:
                return http_util.Respond(
                    request,
                    "query parameter `limit_attr_size` must be an integer",
                    "text/plain",
                    400,
                )

        large_attrs_key = request.args.get("large_attrs_key", None)

        try:
            result = self.graph_impl(
                ctx,
                run,
                tag,
                is_conceptual,
                experiment,
                limit_attr_size,
                large_attrs_key,
            )
        except ValueError as e:
            # Python 3 exceptions have no `.message` attribute; `str(e)`
            # yields the message without raising `AttributeError`.
            return http_util.Respond(request, str(e), "text/plain", code=400)
        (body, mime_type) = result
        return http_util.Respond(request, body, mime_type)

    @wrappers.Request.application
    def run_metadata_route(self, request):
        """Given a tag and a run, return the session.run() metadata.

        Both the `tag` and `run` query parameters are required; a
        missing one yields a 400 response.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        tag = request.args.get("tag")
        run = request.args.get("run")
        if tag is None:
            return http_util.Respond(
                request, 'query parameter "tag" is required', "text/plain", 400
            )
        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )
        (body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag)
        return http_util.Respond(request, body, mime_type)