283 lines
11 KiB
Python
283 lines
11 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Utilities for handling Keras model in graph plugin.
|
||
|
|
||
|
Two canonical types of Keras model are Functional and Sequential.
|
||
|
A model can be serialized as JSON and deserialized to reconstruct a model.
|
||
|
This utility helps with handling serialized Keras models.
|
||
|
|
||
|
The two model types have distinct configuration structures, shown below:
|
||
|
Functional:
|
||
|
config
|
||
|
name: Name of the model. If not specified, it is 'model' with
|
||
|
an optional suffix if there is more than one instance.
|
||
|
input_layers: Keras.layers.Inputs in the model.
|
||
|
output_layers: Layer names that are outputs of the model.
|
||
|
layers: list of layer configurations.
|
||
|
layer: [*]
|
||
|
inbound_nodes: inputs to this layer.
|
||
|
|
||
|
Sequential:
|
||
|
config
|
||
|
name: Name of the model. If not specified, it is 'sequential' with
|
||
|
an optional suffix if there is more than one instance.
|
||
|
layers: list of layer configurations.
|
||
|
layer: [*]
|
||
|
|
||
|
[*]: Note that a model can be a layer.
|
||
|
Please refer to https://github.com/tensorflow/tfjs-layers/blob/master/src/keras_format/model_serialization.ts
|
||
|
for more complete definition.
|
||
|
"""
|
||
|
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
||
|
from tensorboard.compat.tensorflow_stub import dtypes
|
||
|
|
||
|
|
||
|
def _walk_layers(keras_layer):
|
||
|
"""Walks the nested keras layer configuration in preorder.
|
||
|
|
||
|
Args:
|
||
|
keras_layer: Keras configuration from model.to_json.
|
||
|
|
||
|
Yields:
|
||
|
A tuple of (name_scope, layer_config).
|
||
|
name_scope: a string representing a scope name, similar to that of tf.name_scope.
|
||
|
layer_config: a dict representing a Keras layer configuration.
|
||
|
"""
|
||
|
yield ("", keras_layer)
|
||
|
if keras_layer.get("config").get("layers"):
|
||
|
name_scope = keras_layer.get("config").get("name")
|
||
|
for layer in keras_layer.get("config").get("layers"):
|
||
|
for (sub_name_scope, sublayer) in _walk_layers(layer):
|
||
|
sub_name_scope = (
|
||
|
"%s/%s" % (name_scope, sub_name_scope)
|
||
|
if sub_name_scope
|
||
|
else name_scope
|
||
|
)
|
||
|
yield (sub_name_scope, sublayer)
|
||
|
|
||
|
|
||
|
def _scoped_name(name_scope, node_name):
|
||
|
"""Returns scoped name for a node as a string in the form '<scope>/<node
|
||
|
name>'.
|
||
|
|
||
|
Args:
|
||
|
name_scope: a string representing a scope name, similar to that of tf.name_scope.
|
||
|
node_name: a string representing the current node name.
|
||
|
|
||
|
Returns
|
||
|
A string representing a scoped name.
|
||
|
"""
|
||
|
if name_scope:
|
||
|
return "%s/%s" % (name_scope, node_name)
|
||
|
return node_name
|
||
|
|
||
|
|
||
|
def _is_model(layer):
|
||
|
"""Returns True if layer is a model.
|
||
|
|
||
|
Args:
|
||
|
layer: a dict representing a Keras model configuration.
|
||
|
|
||
|
Returns:
|
||
|
bool: True if layer is a model.
|
||
|
"""
|
||
|
return layer.get("config").get("layers") is not None
|
||
|
|
||
|
|
||
|
def _norm_to_list_of_layers(maybe_layers):
|
||
|
"""Normalizes to a list of layers.
|
||
|
|
||
|
Args:
|
||
|
maybe_layers: A list of data[1] or a list of list of data.
|
||
|
|
||
|
Returns:
|
||
|
List of list of data.
|
||
|
|
||
|
[1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can
|
||
|
look like below:
|
||
|
- ['in_layer_name', 0, 0]
|
||
|
- [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]]
|
||
|
The data inside the list seems to describe [name, size, index].
|
||
|
"""
|
||
|
return (
|
||
|
maybe_layers if isinstance(maybe_layers[0], (list,)) else [maybe_layers]
|
||
|
)
|
||
|
|
||
|
|
||
|
def _update_dicts(
    name_scope,
    model_layer,
    input_to_in_layer,
    model_name_to_output,
    prev_node_name,
):
    """Updates input_to_in_layer, model_name_to_output, and prev_node_name
    based on the model_layer.

    The two dicts are mutated in place and also returned.

    Args:
      name_scope: a string representing a scope name, similar to that of
        tf.name_scope.
      model_layer: a dict representing a Keras model configuration.
      input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
      model_name_to_output: a dict mapping Keras Model name to output layer of
        the model.
      prev_node_name: a string representing a previous, in sequential model
        layout, node name.

    Returns:
      A tuple of (input_to_in_layer, model_name_to_output, prev_node_name).
        input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
        model_name_to_output: a dict mapping Keras Model name to output layer
          of the model.
        prev_node_name: a string representing a previous, in sequential model
          layout, node name.

    Raises:
      ValueError: if model_layer's config has no (or an empty) "layers" list.
      AssertionError: if a Functional model whose parent is not Functional
        declares more than one input layer.
    """
    layer_config = model_layer.get("config")
    if not layer_config.get("layers"):
        raise ValueError("layer is not a model.")

    node_name = _scoped_name(name_scope, layer_config.get("name"))
    input_layers = layer_config.get("input_layers")
    output_layers = layer_config.get("output_layers")
    inbound_nodes = model_layer.get("inbound_nodes")

    # Only Functional model configs declare both input and output layers.
    is_functional_model = bool(input_layers and output_layers)
    # When the parent model is Functional, the current (model) layer carries
    # the 'inbound_nodes' property.
    is_parent_functional_model = bool(inbound_nodes)

    if is_parent_functional_model and is_functional_model:
        for input_layer, inbound_node in zip(input_layers, inbound_nodes):
            input_layer_name = _scoped_name(node_name, input_layer)
            inbound_node_name = _scoped_name(name_scope, inbound_node[0])
            input_to_in_layer[input_layer_name] = inbound_node_name
    elif is_parent_functional_model and not is_functional_model:
        # Sequential model can take only one input. Make sure inbound to the
        # model is linked to the first layer in the Sequential model.
        prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0][0])
    elif (
        not is_parent_functional_model
        and prev_node_name
        and is_functional_model
    ):
        # The message must use `input_layers`: the local `input_layer` is not
        # yet bound on this path, so referencing it turned a failed check into
        # an UnboundLocalError instead of the intended AssertionError.
        assert len(input_layers) == 1, (
            "Cannot have multi-input Functional model when parent model "
            "is not Functional. Number of input layers: %d" % len(input_layers)
        )
        input_layer = input_layers[0]
        input_layer_name = _scoped_name(node_name, input_layer)
        input_to_in_layer[input_layer_name] = prev_node_name

    if is_functional_model and output_layers:
        layers = _norm_to_list_of_layers(output_layers)
        layer_names = [_scoped_name(node_name, layer[0]) for layer in layers]
        model_name_to_output[node_name] = layer_names
    else:
        # Without declared outputs (e.g. Sequential), the model's output is
        # taken to be its last layer.
        last_layer = layer_config.get("layers")[-1]
        last_layer_name = last_layer.get("config").get("name")
        output_node = _scoped_name(node_name, last_layer_name)
        model_name_to_output[node_name] = [output_node]
    return (input_to_in_layer, model_name_to_output, prev_node_name)
|
||
|
|
||
|
|
||
|
def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Note that it only supports models that implemented to_json().

    Every non-model layer becomes one NodeDef named by its scoped name.
    Nested models produce no node of their own. Instead they update the
    bookkeeping (input_to_layer, model_name_to_output, prev_node_name) that
    is used to wire their inner layers to the surrounding graph.

    Args:
      keras_layer: A dict from Keras model.to_json().

    Returns:
      A GraphDef representation of the layers in the model.
    """
    input_to_layer = {}
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None

    for name_scope, layer in _walk_layers(keras_layer):
        if _is_model(layer):
            (
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            ) = _update_dicts(
                name_scope,
                layer,
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            )
            continue

        layer_config = layer.get("config")
        node_name = _scoped_name(name_scope, layer_config.get("name"))

        node_def = g.node.add()
        node_def.name = node_name

        if layer.get("class_name") is not None:
            keras_cls_name = layer.get("class_name").encode("ascii")
            node_def.attr["keras_class"].s = keras_cls_name

        dtype_or_policy = layer_config.get("dtype")
        # Skip dtype processing if this is a dict, since it's presumably an
        # instance of tf/keras/mixed_precision/Policy rather than a single
        # dtype.
        # TODO(#5548): parse the policy dict and populate the dtype attr with
        # the variable dtype.
        if dtype_or_policy is not None and not isinstance(
            dtype_or_policy, dict
        ):
            tf_dtype = dtypes.as_dtype(dtype_or_policy)
            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum

        layer_inbound_nodes = layer.get("inbound_nodes")
        if layer_inbound_nodes is not None:
            for maybe_inbound_node in layer_inbound_nodes:
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                for inbound_node in inbound_nodes:
                    # Entries may be [name, size, index] or
                    # [name, size, index, kwargs]; only name and index are
                    # used. Indexing instead of unpacking exactly four values
                    # accepts both shapes.
                    name, index = inbound_node[0], inbound_node[2]
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be output from a model. In that
                    # case, the name of inbound_nodes to a layer is a name of a
                    # model. Remap the name of the model to output layer of the
                    # model. Also, since there can be multiple outputs in a
                    # model, make sure we pick the right output_layer from the
                    # model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name]
                    )
                    # There can be multiple inbound_nodes that reference the
                    # same upstream layer. This causes issues when looking for
                    # a particular index in that layer, since the indices
                    # captured in `inbound_nodes` don't necessarily match the
                    # number of entries in the `inbound_node_names` list. To
                    # avoid IndexErrors, we just use the last element in the
                    # `inbound_node_names` in this situation.
                    # Note that this is a quick hack to avoid IndexErrors in
                    # this situation, and might not be an appropriate solution
                    # to this problem in general.
                    input_name = (
                        inbound_node_names[index]
                        if index < len(inbound_node_names)
                        else inbound_node_names[-1]
                    )
                    node_def.input.append(input_name)
        elif prev_node_name is not None:
            # No explicit inbound nodes: chain to the previous layer, as in a
            # Sequential model.
            node_def.input.append(prev_node_name)

        # Input layers of nested models are linked to whatever feeds the
        # enclosing model (recorded by _update_dicts).
        if node_name in input_to_layer:
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g
|