550 lines
16 KiB
Python
550 lines
16 KiB
Python
#!/usr/bin/env python
|
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""This tool creates an html visualization of a TensorFlow Lite graph.
|
|
|
|
Example usage:
|
|
|
|
python visualize.py foo.tflite foo.html
|
|
"""
|
|
|
|
import html
import json
import os
import re
import sys

import numpy as np
|
|
|
|
# pylint: disable=g-import-not-at-top
|
|
if not os.path.splitext(__file__)[0].endswith(
|
|
os.path.join("tflite_runtime", "visualize")):
|
|
# This file is part of tensorflow package.
|
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
|
else:
|
|
# This file is part of tflite_runtime package.
|
|
from tflite_runtime import schema_py_generated as schema_fb
|
|
|
|
# A CSS description for making the visualizer
|
|
_CSS = """
|
|
<html>
|
|
<head>
|
|
<style>
|
|
body {font-family: sans-serif; background-color: #fa0;}
|
|
table {background-color: #eca;}
|
|
th {background-color: black; color: white;}
|
|
h1 {
|
|
background-color: ffaa00;
|
|
padding:5px;
|
|
color: black;
|
|
}
|
|
|
|
svg {
|
|
margin: 10px;
|
|
border: 2px;
|
|
border-style: solid;
|
|
border-color: black;
|
|
background: white;
|
|
}
|
|
|
|
div {
|
|
border-radius: 5px;
|
|
background-color: #fec;
|
|
padding:5px;
|
|
margin:5px;
|
|
}
|
|
|
|
.tooltip {color: blue;}
|
|
.tooltip .tooltipcontent {
|
|
visibility: hidden;
|
|
color: black;
|
|
background-color: yellow;
|
|
padding: 5px;
|
|
border-radius: 4px;
|
|
position: absolute;
|
|
z-index: 1;
|
|
}
|
|
.tooltip:hover .tooltipcontent {
|
|
visibility: visible;
|
|
}
|
|
|
|
.edges line {
|
|
stroke: #333;
|
|
}
|
|
|
|
text {
|
|
font-weight: bold;
|
|
}
|
|
|
|
.nodes text {
|
|
color: black;
|
|
pointer-events: none;
|
|
font-family: sans-serif;
|
|
font-size: 11px;
|
|
}
|
|
</style>
|
|
|
|
<script src="https://d3js.org/d3.v4.min.js"></script>
|
|
|
|
</head>
|
|
<body>
|
|
"""
|
|
|
|
_D3_HTML_TEMPLATE = """
|
|
<script>
|
|
function buildGraph() {
|
|
// Build graph data
|
|
var graph = %s;
|
|
|
|
var svg = d3.select("#subgraph%d")
|
|
var width = svg.attr("width");
|
|
var height = svg.attr("height");
|
|
// Make the graph scrollable.
|
|
svg = svg.call(d3.zoom().on("zoom", function() {
|
|
svg.attr("transform", d3.event.transform);
|
|
})).append("g");
|
|
|
|
|
|
var color = d3.scaleOrdinal(d3.schemeDark2);
|
|
|
|
var simulation = d3.forceSimulation()
|
|
.force("link", d3.forceLink().id(function(d) {return d.id;}))
|
|
.force("charge", d3.forceManyBody())
|
|
.force("center", d3.forceCenter(0.5 * width, 0.5 * height));
|
|
|
|
var edge = svg.append("g").attr("class", "edges").selectAll("line")
|
|
.data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
|
|
|
|
// Make the node group
|
|
var node = svg.selectAll(".nodes")
|
|
.data(graph.nodes)
|
|
.enter().append("g")
|
|
.attr("x", function(d){return d.x})
|
|
.attr("y", function(d){return d.y})
|
|
.attr("transform", function(d) {
|
|
return "translate( " + d.x + ", " + d.y + ")"
|
|
})
|
|
.attr("class", "nodes")
|
|
.call(d3.drag()
|
|
.on("start", function(d) {
|
|
if(!d3.event.active) simulation.alphaTarget(1.0).restart();
|
|
d.fx = d.x;d.fy = d.y;
|
|
})
|
|
.on("drag", function(d) {
|
|
d.fx = d3.event.x; d.fy = d3.event.y;
|
|
})
|
|
.on("end", function(d) {
|
|
if (!d3.event.active) simulation.alphaTarget(0);
|
|
d.fx = d.fy = null;
|
|
}));
|
|
// Within the group, draw a box for the node position and text
|
|
// on the side.
|
|
|
|
var node_width = 150;
|
|
var node_height = 30;
|
|
|
|
node.append("rect")
|
|
.attr("r", "5px")
|
|
.attr("width", node_width)
|
|
.attr("height", node_height)
|
|
.attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
|
|
.attr("stroke", "#000000")
|
|
.attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
|
|
node.append("text")
|
|
.text(function(d) { return d.name; })
|
|
.attr("x", 5)
|
|
.attr("y", 20)
|
|
.attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
|
|
// Setup force parameters and update position callback
|
|
|
|
|
|
var node = svg.selectAll(".nodes")
|
|
.data(graph.nodes);
|
|
|
|
// Bind the links
|
|
var name_to_g = {}
|
|
node.each(function(data, index, nodes) {
|
|
console.log(data.id)
|
|
name_to_g[data.id] = this;
|
|
});
|
|
|
|
function proc(w, t) {
|
|
return parseInt(w.getAttribute(t));
|
|
}
|
|
edge.attr("d", function(d) {
|
|
function lerp(t, a, b) {
|
|
return (1.0-t) * a + t * b;
|
|
}
|
|
var x1 = proc(name_to_g[d.source],"x") + node_width /2;
|
|
var y1 = proc(name_to_g[d.source],"y") + node_height;
|
|
var x2 = proc(name_to_g[d.target],"x") + node_width /2;
|
|
var y2 = proc(name_to_g[d.target],"y");
|
|
var s = "M " + x1 + " " + y1
|
|
+ " C " + x1 + " " + lerp(.5, y1, y2)
|
|
+ " " + x2 + " " + lerp(.5, y1, y2)
|
|
+ " " + x2 + " " + y2
|
|
return s;
|
|
});
|
|
|
|
}
|
|
buildGraph()
|
|
</script>
|
|
"""
|
|
|
|
|
|
def TensorTypeToName(tensor_type):
  """Converts a numerical enum to a readable tensor type.

  Args:
    tensor_type: a value of the schema's TensorType enum.

  Returns:
    The name of the TensorType member whose value equals `tensor_type`, or
    None if no member matches.
  """
  for name, value in schema_fb.TensorType.__dict__.items():
    # The class __dict__ also contains dunders (__module__, __doc__, ...);
    # __doc__ is None for a class without a docstring, so without this skip a
    # None input would be reported as "__doc__".
    if name.startswith("_"):
      continue
    if value == tensor_type:
      return name
  return None
|
|
|
|
|
|
def BuiltinCodeToName(code):
  """Converts a builtin op code enum to a readable name.

  Args:
    code: a value of the schema's BuiltinOperator enum.

  Returns:
    The name of the BuiltinOperator member whose value equals `code`, or None
    if no member matches.
  """
  for name, value in schema_fb.BuiltinOperator.__dict__.items():
    # Skip class dunders (__module__, __doc__, ...) so they can never be
    # returned as an operator name, e.g. for a None code.
    if name.startswith("_"):
      continue
    if value == code:
      return name
  return None
|
|
|
|
|
|
def NameListToString(name_list):
  """Converts a list of byte values (or an existing str) to a string.

  Flatbuffer strings reach this module as sequences of integer byte values.
  They are decoded as UTF-8 so that non-ASCII names render correctly.

  Args:
    name_list: a str (returned unchanged), None (treated as empty), or an
      iterable of integer codes.

  Returns:
    The decoded string.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  codes = [int(val) for val in name_list]
  try:
    return bytes(codes).decode("utf-8")
  except ValueError:
    # Either a value is outside 0..255 (bytes() raises ValueError) or the
    # bytes are not valid UTF-8 (UnicodeDecodeError subclasses ValueError).
    # Fall back to the historical one-code-point-per-element mapping.
    return "".join(chr(code) for code in codes)
|
|
|
|
|
|
class OpCodeMapper:
  """Callable that turns an operator's opcode index into a display label.

  The label has the form "<op name> (<index>)", with "<UNKNOWN>" used for an
  index that has no entry in the model's operator_codes table.
  """

  def __init__(self, data):
    self.code_to_name = {}
    for position, op_code in enumerate(data["operator_codes"]):
      label = BuiltinCodeToName(op_code["builtin_code"])
      # Custom ops all share the CUSTOM builtin code; their real name lives in
      # the "custom_code" field instead.
      if label == "CUSTOM":
        label = NameListToString(op_code["custom_code"])
      self.code_to_name[position] = label

  def __call__(self, x):
    label = self.code_to_name.get(x, "<UNKNOWN>")
    return "%s (%d)" % (label, x)
|
|
|
|
|
|
class DataSizeMapper:
  """Callable that renders a buffer's data field as a size for the table."""

  def __call__(self, x):
    # A missing data field is shown as a placeholder.
    if x is None:
      return "--"
    # len() counts elements; this assumes a byte-per-element payload.
    return "%d bytes" % len(x)
|
|
|
|
|
|
class TensorMapper:
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    # Subgraph dict; only its "tensors" list is read.
    self.data = subgraph_data

  def __call__(self, x):
    """Renders `x` as HTML with a hover tooltip describing each tensor.

    Args:
      x: list of indices into the subgraph's "tensors" list, or None.

    Returns:
      An HTML snippet; the empty string when `x` is None.
    """
    if x is None:
      return ""

    parts = ["<span class='tooltip'><span class='tooltipcontent'>"]
    for i in x:
      tensor = self.data["tensors"][i]
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      shape_signature = (repr(tensor["shape_signature"])
                         if "shape_signature" in tensor else "[]")
      parts.append("%s %s %s %s %s<br>" % (
          i,
          # Names come from the model file; escape them so they cannot
          # inject markup into the page.
          html.escape(NameListToString(tensor["name"])),
          # %s (rather than "+") tolerates None for an unrecognised type.
          TensorTypeToName(tensor["type"]),
          shape,
          shape_signature))
    parts.append("</span>")
    parts.append(repr(x))
    parts.append("</span>")
    return "".join(parts)
|
|
|
|
|
|
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: index of the subgraph; the generated script draws into the
      svg element with id `subgraph<subgraph_idx>`.
    g: subgraph dict; its "operators" and "tensors" lists are read.
    opcode_mapper: callable mapping an op's "opcode_index" to a display label.

  Returns:
    A `<script>` snippet built from `_D3_HTML_TEMPLATE` with the graph's
    nodes and edges embedded as JSON.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  # Initial placement, as (y, x), of each tensor keyed by tensor index: taken
  # from the first op consuming it, else from the first op producing it.
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"] or []):
    if op["inputs"] is not None:
      for tensor_input_position, tensor_index in enumerate(op["inputs"]):
        # A negative index names no tensor node (presumably TFLite's -1 for an
        # omitted optional input). The d3 script looks up both endpoints of
        # every edge, and a missing node makes that lookup throw, so skip it.
        if tensor_index < 0:
          continue
        if tensor_index not in first:
          first[tensor_index] = ((op_index + 0.5) * pixel_mult,
                                 (tensor_input_position + 1) * width_mult)
        edges.append({
            "source": TensorName(tensor_index),
            "target": OpName(op_index)
        })
    if op["outputs"] is not None:
      for tensor_output_position, tensor_index in enumerate(op["outputs"]):
        if tensor_index < 0:
          continue
        if tensor_index not in second:
          second[tensor_index] = ((op_index + 1.5) * pixel_mult,
                                  (tensor_output_position + 1) * width_mult)
        edges.append({
            "target": TensorName(tensor_index),
            "source": OpName(op_index)
        })

    nodes.append({
        "id": OpName(op_index),
        "name": opcode_mapper(op["opcode_index"]),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g["tensors"] or []):
    initial_y = first.get(tensor_index, second.get(tensor_index, (0, 0)))
    # Tensors are dicts, so read the shape by key; getattr() on a dict always
    # fell back to its default and hid the shape.
    shape = tensor.get("shape")
    if shape is None:
      shape = []
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % (shape, tensor_index),
        "group": 1,
        "x": initial_y[1],
        "y": initial_y[0]
    })

  # Escape "</" so a label containing "</script>" cannot close the enclosing
  # <script> element early; "\/" is a valid escape in JSON and JS strings.
  graph_str = json.dumps({"nodes": nodes, "edges": edges})
  graph_str = graph_str.replace("</", "<\\/")
  return _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
|
|
|
|
|
|
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: list of dicts, one per table row. None is treated as empty.
    keys_to_print: list of (key, display_fn) pairs. `key` is looked up in each
      item (a missing key yields None). `display_fn` maps the value to what
      is shown in the cell, i.e. the cell holds `display_fn(item[key])`; a
      display_fn of None shows the raw value.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for key, _ in keys_to_print:
    parts.append("<th>%s</th>" % (key,))
  parts.append("</tr>\n")

  for idx, item in enumerate(items or []):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for key, mapper in keys_to_print:
      val = item.get(key)
      if mapper is not None:
        val = mapper(val)
      # Wrap in a 1-tuple so a tuple value is rendered, not unpacked by %.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")

  parts.append("</table>\n")
  return "".join(parts)
|
|
|
|
|
|
def CamelCaseToSnakeCase(camel_case_input):
  """Returns `camel_case_input` rewritten from CamelCase to snake_case.

  For example "operatorCodes" -> "operator_codes" and
  "HTTPServer" -> "http_server".
  """
  # Insert "_" before every capitalised word that follows some character.
  with_word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  # Insert "_" between a lowercase letter/digit and a following capital.
  underscored = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
  return underscored.lower()
|
|
|
|
|
|
def FlatbufferToDict(fb, preserve_as_numpy):
  """Recursively converts flatbuffer object-API instances to plain Python.

  Large parts of the flat buffer are deliberately left as numpy arrays rather
  than expanded into Python lists, which keeps conversion of big graphs fast.

  Args:
    fb: a flat buffer structure (e.g. a ModelT) or any value nested in one.
    preserve_as_numpy: when truthy, numpy arrays below this point are kept
      as-is; otherwise they are converted with `tolist()`. Arrays under an
      attribute named "buffers" are always kept.

  Returns:
    Objects become dicts keyed by snake_case attribute names, sequences become
    lists, and scalars, strings and other values are returned unchanged.
  """
  # Plain scalars and strings need no conversion.
  if isinstance(fb, (int, float, str)):
    return fb

  # Object-API instances: recurse over their public, non-callable attributes.
  if hasattr(fb, "__dict__"):
    converted = {}
    for name in dir(fb):
      value = getattr(fb, name)
      if callable(value) or name[0] == "_":
        continue
      keep_numpy = True if name == "buffers" else preserve_as_numpy
      converted[CamelCaseToSnakeCase(name)] = FlatbufferToDict(
          value, keep_numpy)
    return converted

  if isinstance(fb, np.ndarray):
    if preserve_as_numpy:
      return fb
    return fb.tolist()

  # Any other sized container (list, bytes, ...) becomes a list.
  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(item, preserve_as_numpy) for item in fb]

  return fb
|
|
|
|
|
|
def CreateDictFromFlatbuffer(buffer_data):
  """Parses a serialized TFLite model into nested Python dicts and lists.

  Args:
    buffer_data: the raw model flatbuffer (e.g. a bytearray of a .tflite file).

  Returns:
    The model as produced by FlatbufferToDict, with numpy arrays converted to
    Python lists except those under "buffers".
  """
  root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  return FlatbufferToDict(schema_fb.ModelT.InitFromObj(root),
                          preserve_as_numpy=False)
|
|
|
|
|
|
def _LoadModelData(tflite_input, input_is_filepath):
  """Loads a TFLite model (from a path or a model buffer) into nested dicts.

  Args:
    tflite_input: path to a .tflite/.bin/.json file, or the model flatbuffer
      itself when `input_is_filepath` is False.
    input_is_filepath: whether `tflite_input` is a path.

  Returns:
    The model as a nested dict.

  Raises:
    RuntimeError: If the path does not exist or has an unsupported extension.
  """
  if not input_is_filepath:
    return CreateDictFromFlatbuffer(tflite_input)
  if not os.path.exists(tflite_input):
    raise RuntimeError("Invalid filename %r" % tflite_input)
  if tflite_input.endswith((".tflite", ".bin")):
    with open(tflite_input, "rb") as file_handle:
      file_data = bytearray(file_handle.read())
    return CreateDictFromFlatbuffer(file_data)
  if tflite_input.endswith(".json"):
    # Context manager so the handle is closed once parsing finishes.
    with open(tflite_input, encoding="utf-8") as file_handle:
      return json.load(file_handle)
  raise RuntimeError("Input file was not .tflite, .bin or .json")


def _GenerateSubgraphHtml(subgraph_idx, g, opcode_mapper):
  """Renders one subgraph: its I/O table, tensors, ops and d3 graph."""
  tensor_mapper = TensorMapper(g)
  op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                        ("builtin_options", None),
                        ("opcode_index", opcode_mapper)]
  tensor_keys_to_display = [("name", NameListToString),
                            ("type", TensorTypeToName), ("shape", None),
                            ("shape_signature", None), ("buffer", None),
                            ("quantization", None)]

  parts = ["<div class='subgraph'>", "<h2>Subgraph %d</h2>\n" % subgraph_idx]

  # Inputs and outputs.
  parts.append("<h3>Inputs/Outputs</h3>\n")
  parts.append(GenerateTableHtml([{
      "inputs": g["inputs"],
      "outputs": g["outputs"]
  }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                                 display_index=False))

  # Print the tensors.
  parts.append("<h3>Tensors</h3>\n")
  parts.append(GenerateTableHtml(g["tensors"], tensor_keys_to_display))

  # Print the ops.
  if g["operators"]:
    parts.append("<h3>Ops</h3>\n")
    parts.append(GenerateTableHtml(g["operators"], op_keys_to_display))

  # Visual graph.
  parts.append("<svg id='subgraph%d' width='1600' height='900'></svg>\n" %
               (subgraph_idx,))
  parts.append(GenerateGraph(subgraph_idx, g, opcode_mapper))
  parts.append("</div>")
  return "".join(parts)


def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input is not valid.
  """
  data = _LoadModelData(tflite_input, input_is_filepath)

  parts = [_CSS, "<h1>TensorFlow Lite Model</h1>"]

  data["filename"] = tflite_input if input_is_filepath else (
      "Null (used model object)")  # Avoid special case

  # The flatbuffer description arrives as a list of byte values, so decode it
  # with NameListToString instead of printing the raw list.
  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", NameListToString)]
  parts.append("<table>\n")
  for key, mapping in toplevel_stuff:
    value = data.get(key)
    if mapping is not None:
      value = mapping(value)
    # Escape: the filename and description are user/model supplied text.
    parts.append("<tr><th>%s</th><td>%s</td></tr>\n" %
                 (key, html.escape(str(value))))
  parts.append("</table>\n")

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]

  # Update builtin code fields: keep the larger of builtin_code and
  # deprecated_builtin_code. Presumably (TFLite schema convention — confirm
  # against schema.fbs) older models fill only the deprecated field and newer
  # ones store large codes only in builtin_code.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d.get("builtin_code", 0),
                            d.get("deprecated_builtin_code", 0))

  # The op-code table is model-wide, so one mapper serves every subgraph.
  opcode_mapper = OpCodeMapper(data)
  for subgraph_idx, g in enumerate(data["subgraphs"]):
    parts.append(_GenerateSubgraphHtml(subgraph_idx, g, opcode_mapper))

  # Buffers have no data, but maybe in the future they will
  parts.append("<h2>Buffers</h2>\n")
  parts.append(GenerateTableHtml(data["buffers"], buffer_keys_to_display))

  # Operator codes
  parts.append("<h2>Operator Codes</h2>\n")
  parts.append(GenerateTableHtml(data["operator_codes"],
                                 operator_keys_to_display))

  parts.append("</body></html>\n")
  return "".join(parts)
|
|
|
|
|
|
def main(argv):
  """Command-line entry point: `visualize.py <input tflite> <output html>`.

  Args:
    argv: the argument vector, program name first (as in sys.argv).

  Returns:
    0 when the HTML file was written, 1 when arguments were missing.
  """
  if len(argv) < 3:
    prog = argv[0] if argv else "visualize.py"
    # Report misuse on stderr with a failing status instead of silently
    # "succeeding".
    print("Usage: %s <input tflite> <output html>" % (prog,), file=sys.stderr)
    return 1
  tflite_input, html_output = argv[1], argv[2]
  html_text = create_html(tflite_input)
  # Tensor names may be non-ASCII; write UTF-8 explicitly rather than relying
  # on the platform's default encoding.
  with open(html_output, "w", encoding="utf-8") as output_file:
    output_file.write(html_text)
  return 0
|
|
|
|
|
|
if __name__ == "__main__":
  # Propagate main()'s return value as the process exit status (a None
  # return still exits 0).
  sys.exit(main(sys.argv))
|