# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
||
|
"""The TensorBoard Graphs plugin."""
|
||
|
|
||
|
|
||
|
import json
|
||
|
from werkzeug import wrappers
|
||
|
|
||
|
from tensorboard import errors
|
||
|
from tensorboard import plugin_util
|
||
|
from tensorboard.backend import http_util
|
||
|
from tensorboard.backend import process_graph
|
||
|
from tensorboard.compat.proto import config_pb2
|
||
|
from tensorboard.compat.proto import graph_pb2
|
||
|
from tensorboard.data import provider
|
||
|
from tensorboard.plugins import base_plugin
|
||
|
from tensorboard.plugins.graph import graph_util
|
||
|
from tensorboard.plugins.graph import keras_util
|
||
|
from tensorboard.plugins.graph import metadata
|
||
|
from tensorboard.util import tb_logging
|
||
|
|
||
|
# Module-level logger, obtained from TensorBoard's `tb_logging` helper;
# used to warn about summaries this plugin cannot interpret.
logger = tb_logging.get_logger()
|
||
|
|
||
|
|
||
|
class GraphsPlugin(base_plugin.TBPlugin):
    """Graphs Plugin for TensorBoard.

    Serves three routes: `/info` (which runs/tags have graph or profile
    data), `/graph` (a graph rendered as pbtxt) and `/run_metadata` (a
    `RunMetadata` proto rendered as pbtxt).
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates GraphsPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        self._data_provider = context.data_provider

    def get_plugin_apps(self):
        """Returns the mapping from route path to request handler."""
        return {
            "/graph": self.graph_route,
            "/info": self.info_route,
            "/run_metadata": self.run_metadata_route,
        }

    def is_active(self):
        """The graphs plugin is active iff any run has a graph or metadata."""
        return False  # `list_plugins` as called by TB core suffices

    def data_plugin_names(self):
        """Returns every data-plugin name whose summaries this plugin reads."""
        return (
            metadata.PLUGIN_NAME,
            metadata.PLUGIN_NAME_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            metadata.PLUGIN_NAME_KERAS_MODEL,
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )

    def frontend_metadata(self):
        return base_plugin.FrontendMetadata(
            element_name="tf-graph-dashboard",
            # TODO(@chihuahua): Reconcile this setting with Health Pills.
            disable_reload=True,
        )

    def _versioned_run_tags(self, ctx, experiment, plugin_name):
        """Yields `(run, tag)` for blob sequences whose version is b"1".

        Sequences with any other `plugin_content` are skipped and a warning
        is logged for each.

        Args:
          ctx: Request context forwarded to the data provider.
          experiment: Experiment ID forwarded to the data provider.
          plugin_name: Data-plugin name to list blob sequences for.
        """
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=plugin_name,
        )
        for run_name, tags in mapping.items():
            for tag, tag_data in tags.items():
                # The Summary op is defined in TensorFlow and does not use a
                # stringified proto as a content of plugin data. It contains a
                # single string that denotes a version.
                # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790
                if tag_data.plugin_content != b"1":
                    logger.warning(
                        "Ignoring unrecognizable version of RunMetadata."
                    )
                    continue
                yield (run_name, tag)

    def info_impl(self, ctx, experiment=None):
        """Returns a dict of all runs and their data availabilities.

        Args:
          ctx: Request context forwarded to the data provider.
          experiment: Optional experiment ID forwarded to the data provider.

        Returns:
          A dict mapping each run name to
          `{"run": run, "tags": {tag: tag_item, ...}, "run_graph": bool}`,
          where each `tag_item` is
          `{"tag": tag, "conceptual_graph": bool, "op_graph": bool,
          "profile": bool}`.
        """
        result = {}

        def add_row_item(run, tag=None):
            """Gets or creates the entries for `run` (and `tag`, if given)."""
            run_item = result.setdefault(
                run,
                {
                    "run": run,
                    "tags": {},
                    # A run-wide GraphDef of ops.
                    "run_graph": False,
                },
            )
            tag_item = None
            if tag:
                tag_item = run_item["tags"].setdefault(
                    tag,
                    {
                        "tag": tag,
                        "conceptual_graph": False,
                        # A tagged GraphDef of ops.
                        "op_graph": False,
                        "profile": False,
                    },
                )
            return (run_item, tag_item)

        # Tensors associated with PLUGIN_NAME_RUN_METADATA_WITH_GRAPH carry
        # an op graph only.
        for run_name, tag in self._versioned_run_tags(
            ctx, experiment, metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["op_graph"] = True

        # Tensors associated with plugin name metadata.PLUGIN_NAME_RUN_METADATA
        # contain both op graph and profile information.
        for run_name, tag in self._versioned_run_tags(
            ctx, experiment, metadata.PLUGIN_NAME_RUN_METADATA
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["profile"] = True
            tag_item["op_graph"] = True

        # Tensors associated with plugin name metadata.PLUGIN_NAME_KERAS_MODEL
        # contain serialized Keras model in JSON format.
        for run_name, tag in self._versioned_run_tags(
            ctx, experiment, metadata.PLUGIN_NAME_KERAS_MODEL
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["conceptual_graph"] = True

        # A run-level graph is stored under the reserved RUN_GRAPH_NAME tag.
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        for run_name, tags in mapping.items():
            if metadata.RUN_GRAPH_NAME in tags:
                (run_item, _) = add_row_item(run_name, None)
                run_item["run_graph"] = True

        # Top level `Event.tagged_run_metadata` represents profile data only.
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )
        for run_name, tags in mapping.items():
            for tag in tags:
                (_, tag_item) = add_row_item(run_name, tag)
                tag_item["profile"] = True

        return result

    def _read_blob(self, ctx, experiment, plugin_names, run, tag):
        """Returns the bytes of the first blob found for `run`/`tag`.

        Each plugin name is tried in order. For each one, the first value of
        the first datum of the matching blob sequence is used.

        Raises:
          errors.NotFoundError: If no plugin has a blob for `run`/`tag`.
        """
        for plugin_name in plugin_names:
            blob_sequences = self._data_provider.read_blob_sequences(
                ctx,
                experiment_id=experiment,
                plugin_name=plugin_name,
                run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
                # Only the first datum is used below.
                downsample=1,
            )
            blob_sequence_data = blob_sequences.get(run, {}).get(tag, ())
            try:
                blob_ref = blob_sequence_data[0].values[0]
            except IndexError:
                # No data under this plugin; try the next one.
                continue
            return self._data_provider.read_blob(
                ctx, blob_key=blob_ref.blob_key
            )
        raise errors.NotFoundError()

    def graph_impl(
        self,
        ctx,
        run,
        tag,
        is_conceptual,
        experiment=None,
        limit_attr_size=None,
        large_attrs_key=None,
    ):
        """Result of the form `(body, mime_type)`; may raise `NotFound`.

        Args:
          ctx: Request context forwarded to the data provider.
          run: Run name.
          tag: Tag name, or None to read the run-level graph.
          is_conceptual: If true, build the graph from a serialized Keras
            model config (JSON) stored under `tag`.
          experiment: Optional experiment ID.
          limit_attr_size: Forwarded to `process_graph.prepare_graph_for_ui`.
          large_attrs_key: Forwarded to `process_graph.prepare_graph_for_ui`.

        Returns:
          A `(pbtxt_string, "text/x-protobuf")` tuple.

        Raises:
          errors.NotFoundError: If no matching blob exists.
          ValueError: If the limit parameters are invalid, or if the Keras
            config is not valid JSON.
        """
        if is_conceptual:
            keras_model_config = json.loads(
                self._read_blob(
                    ctx,
                    experiment,
                    [metadata.PLUGIN_NAME_KERAS_MODEL],
                    run,
                    tag,
                )
            )
            graph = keras_util.keras_model_to_graph_def(keras_model_config)

        elif tag is None:
            graph_raw = self._read_blob(
                ctx,
                experiment,
                [metadata.PLUGIN_NAME],
                run,
                metadata.RUN_GRAPH_NAME,
            )
            graph = graph_pb2.GraphDef.FromString(graph_raw)

        else:
            # Op graph: could be either of two plugins. (Cf. `info_impl`.)
            plugins = [
                metadata.PLUGIN_NAME_RUN_METADATA,
                metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            ]
            raw_run_metadata = self._read_blob(
                ctx, experiment, plugins, run, tag
            )
            run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
            # Merge the pre-optimization graph of every function graph.
            graph = graph_util.merge_graph_defs(
                [
                    func_graph.pre_optimization_graph
                    for func_graph in run_metadata.function_graphs
                ]
            )

        # This next line might raise a ValueError if the limit parameters
        # are invalid (size is negative, size present but key absent, etc.).
        process_graph.prepare_graph_for_ui(
            graph, limit_attr_size, large_attrs_key
        )
        return (str(graph), "text/x-protobuf")  # pbtxt

    def run_metadata_impl(self, ctx, experiment, run, tag):
        """Result of the form `(body, mime_type)`; may raise `NotFound`.

        Returns:
          A `(pbtxt_string, "text/x-protobuf")` tuple rendering the
          `RunMetadata` proto stored for `run`/`tag`.
        """
        # Profile graph: could be either of two plugins. (Cf. `info_impl`.)
        plugins = [
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA,
        ]
        raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag)
        run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
        return (str(run_metadata), "text/x-protobuf")  # pbtxt

    @wrappers.Request.application
    def info_route(self, request):
        """Responds with the JSON run/tag availability map from `info_impl`."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        info = self.info_impl(ctx, experiment)
        return http_util.Respond(request, info, "application/json")

    @wrappers.Request.application
    def graph_route(self, request):
        """Given a single run, return the graph definition in protobuf
        format.

        Query parameters: `run` (required), `tag`, `conceptual` ("true"
        selects the Keras conceptual graph), `limit_attr_size` (integer) and
        `large_attrs_key`. Invalid parameters produce a 400 response.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        run = request.args.get("run")
        tag = request.args.get("tag")
        is_conceptual = request.args.get("conceptual") == "true"

        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )

        limit_attr_size = request.args.get("limit_attr_size", None)
        if limit_attr_size is not None:
            try:
                limit_attr_size = int(limit_attr_size)
            except ValueError:
                return http_util.Respond(
                    request,
                    "query parameter `limit_attr_size` must be an integer",
                    "text/plain",
                    400,
                )

        large_attrs_key = request.args.get("large_attrs_key", None)

        try:
            result = self.graph_impl(
                ctx,
                run,
                tag,
                is_conceptual,
                experiment,
                limit_attr_size,
                large_attrs_key,
            )
        except ValueError as e:
            # Python 3 exceptions have no `.message` attribute; `str(e)`
            # yields the message passed to the exception.
            return http_util.Respond(request, str(e), "text/plain", code=400)
        (body, mime_type) = result
        return http_util.Respond(request, body, mime_type)

    @wrappers.Request.application
    def run_metadata_route(self, request):
        """Given a tag and a run, return the session.run() metadata."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        tag = request.args.get("tag")
        run = request.args.get("run")
        if tag is None:
            return http_util.Respond(
                request, 'query parameter "tag" is required', "text/plain", 400
            )
        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )
        (body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag)
        return http_util.Respond(request, body, mime_type)
|