# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for handling Keras model in graph plugin.

Two canonical types of Keras model are Functional and Sequential.
A model can be serialized as JSON and deserialized to reconstruct a model.
This utility helps with handling the serialized form of a Keras model.

They have distinct structures to the configurations in shapes below:
Functional:
  config
    name: Name of the model. If not specified, it is 'model' with
          an optional suffix if there is more than one instance.
    input_layers: Keras.layers.Inputs in the model.
    output_layers: Layer names that are outputs of the model.
    layers: list of layer configurations.
      layer: [*]
        inbound_nodes: inputs to this layer.

Sequential:
  config
    name: Name of the model. If not specified, it is 'sequential' with
          an optional suffix if there is more than one instance.
    layers: list of layer configurations.
      layer: [*]

[*]: Note that a model can be a layer.
Please refer to https://github.com/tensorflow/tfjs-layers/blob/master/src/keras_format/model_serialization.ts
for more complete definition.
"""
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.tensorflow_stub import dtypes


def _walk_layers(keras_layer):
    """Walks the nested keras layer configuration in preorder.

    Args:
      keras_layer: Keras configuration from model.to_json.

    Yields:
      A tuple of (name_scope, layer_config).
      name_scope: a string representing a scope name, similar to that of tf.name_scope.
      layer_config: a dict representing a Keras layer configuration.
    """
    yield ("", keras_layer)
    if keras_layer.get("config").get("layers"):
        name_scope = keras_layer.get("config").get("name")
        for layer in keras_layer.get("config").get("layers"):
            for (sub_name_scope, sublayer) in _walk_layers(layer):
                sub_name_scope = (
                    "%s/%s" % (name_scope, sub_name_scope)
                    if sub_name_scope
                    else name_scope
                )
                yield (sub_name_scope, sublayer)


def _scoped_name(name_scope, node_name):
    """Returns scoped name for a node as a string in the form '<scope>/<node
    name>'.

    Args:
      name_scope: a string representing a scope name, similar to that of tf.name_scope.
      node_name: a string representing the current node name.

    Returns
      A string representing a scoped name.
    """
    if name_scope:
        return "%s/%s" % (name_scope, node_name)
    return node_name


def _is_model(layer):
    """Returns True if layer is a model.

    Args:
      layer: a dict representing a Keras model configuration.

    Returns:
      bool: True if layer is a model.
    """
    return layer.get("config").get("layers") is not None


def _norm_to_list_of_layers(maybe_layers):
    """Normalizes to a list of layers.

    Args:
      maybe_layers: A list of data[1] or a list of list of data.

    Returns:
      List of list of data.

    [1]: A Functional model has fields 'inbound_nodes' and 'output_layers' which can
    look like below:
    - ['in_layer_name', 0, 0]
    - [['in_layer_is_model', 1, 0], ['in_layer_is_model', 1, 1]]
    The data inside the list seems to describe [name, size, index].
    """
    return (
        maybe_layers if isinstance(maybe_layers[0], (list,)) else [maybe_layers]
    )


def _update_dicts(
    name_scope,
    model_layer,
    input_to_in_layer,
    model_name_to_output,
    prev_node_name,
):
    """Updates input_to_in_layer, model_name_to_output, and prev_node_name
    based on the model_layer.

    Args:
      name_scope: a string representing a scope name, similar to that of tf.name_scope.
      model_layer: a dict representing a Keras model configuration.
      input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
      model_name_to_output: a dict mapping Keras Model name to output layer of the model.
      prev_node_name: a string representing a previous, in sequential model layout,
                      node name.

    Returns:
      A tuple of (input_to_in_layer, model_name_to_output, prev_node_name).
      input_to_in_layer: a dict mapping Keras.layers.Input to inbound layer.
      model_name_to_output: a dict mapping Keras Model name to output layer of the model.
      prev_node_name: a string representing a previous, in sequential model layout,
                      node name.

    Raises:
      ValueError: if model_layer is not a model (its config has no 'layers').
    """
    layer_config = model_layer.get("config")
    if not layer_config.get("layers"):
        raise ValueError("layer is not a model.")

    node_name = _scoped_name(name_scope, layer_config.get("name"))
    input_layers = layer_config.get("input_layers")
    output_layers = layer_config.get("output_layers")
    inbound_nodes = model_layer.get("inbound_nodes")

    # Functional models declare explicit input/output layers; Sequential
    # models do not.
    is_functional_model = bool(input_layers and output_layers)
    # In case of [1] and the parent model is functional, current layer
    # will have the 'inbound_nodes' property.
    is_parent_functional_model = bool(inbound_nodes)

    if is_parent_functional_model and is_functional_model:
        # Link each Input layer of this nested model to its inbound node in
        # the parent model.
        for (input_layer, inbound_node) in zip(input_layers, inbound_nodes):
            input_layer_name = _scoped_name(node_name, input_layer)
            inbound_node_name = _scoped_name(name_scope, inbound_node[0])
            input_to_in_layer[input_layer_name] = inbound_node_name
    elif is_parent_functional_model and not is_functional_model:
        # Sequential model can take only one input. Make sure inbound to the
        # model is linked to the first layer in the Sequential model.
        prev_node_name = _scoped_name(name_scope, inbound_nodes[0][0][0])
    elif (
        not is_parent_functional_model
        and prev_node_name
        and is_functional_model
    ):
        # BUGFIX: the message previously formatted len(input_layer), a name
        # that is not defined until after the assert — a failing assertion
        # raised NameError instead of the intended AssertionError message.
        assert len(input_layers) == 1, (
            "Cannot have multi-input Functional model when parent model "
            "is not Functional. Number of input layers: %d" % len(input_layers)
        )
        input_layer = input_layers[0]
        input_layer_name = _scoped_name(node_name, input_layer)
        input_to_in_layer[input_layer_name] = prev_node_name

    if is_functional_model and output_layers:
        # Record the model's declared outputs (scoped) so edges pointing at
        # the model can later be remapped to its real output layers.
        layers = _norm_to_list_of_layers(output_layers)
        layer_names = [_scoped_name(node_name, layer[0]) for layer in layers]
        model_name_to_output[node_name] = layer_names
    else:
        # Sequential model: its output is simply the last layer in order.
        last_layer = layer_config.get("layers")[-1]
        last_layer_name = last_layer.get("config").get("name")
        output_node = _scoped_name(node_name, last_layer_name)
        model_name_to_output[node_name] = [output_node]
    return (input_to_in_layer, model_name_to_output, prev_node_name)


def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Walks every (possibly nested) layer of the model, emitting one NodeDef
    per non-model layer and remapping edges that point at nested models to
    those models' output layers.

    Note that it only supports models that implemented to_json().

    Args:
      keras_layer: A dict from Keras model.to_json().

    Returns:
      A GraphDef representation of the layers in the model.
    """
    # Maps a nested model's Input layer (scoped name) to the node feeding it.
    input_to_layer = {}
    # Maps a nested model's scoped name to the list of its output layer names.
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None

    for (name_scope, layer) in _walk_layers(keras_layer):
        if _is_model(layer):
            # Nested models contribute no node of their own; they only update
            # the bookkeeping dicts used to wire their inner layers.
            (
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            ) = _update_dicts(
                name_scope,
                layer,
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            )
            continue

        layer_config = layer.get("config")
        node_name = _scoped_name(name_scope, layer_config.get("name"))

        node_def = g.node.add()
        node_def.name = node_name

        if layer.get("class_name") is not None:
            keras_cls_name = layer.get("class_name").encode("ascii")
            node_def.attr["keras_class"].s = keras_cls_name

        dtype_or_policy = layer_config.get("dtype")
        # Skip dtype processing if this is a dict, since it's presumably an instance of
        # tf/keras/mixed_precision/Policy rather than a single dtype.
        # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
        if dtype_or_policy is not None and not isinstance(
            dtype_or_policy, dict
        ):
            tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
        if layer.get("inbound_nodes") is not None:
            for maybe_inbound_node in layer.get("inbound_nodes"):
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                # Each entry is unpacked as [name, size, index, _]; only the
                # layer name and index are used. NOTE(review): the fourth
                # element is presumably the call kwargs dict — not consumed.
                for [name, size, index, _] in inbound_nodes:
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be output from a model. In that case, the name
                    # of inbound_nodes to a layer is a name of a model. Remap the name of the
                    # model to output layer of the model. Also, since there can be multiple
                    # outputs in a model, make sure we pick the right output_layer from the model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name]
                    )
                    # There can be multiple inbound_nodes that reference the
                    # same upstream layer. This causes issues when looking for
                    # a particular index in that layer, since the indices
                    # captured in `inbound_nodes` doesn't necessarily match the
                    # number of entries in the `inbound_node_names` list. To
                    # avoid IndexErrors, we just use the last element in the
                    # `inbound_node_names` in this situation.
                    # Note that this is a quick hack to avoid IndexErrors in
                    # this situation, and might not be an appropriate solution
                    # to this problem in general.
                    input_name = (
                        inbound_node_names[index]
                        if index < len(inbound_node_names)
                        else inbound_node_names[-1]
                    )
                    node_def.input.append(input_name)
        elif prev_node_name is not None:
            # Sequential layout: implicitly chain to the previous node.
            node_def.input.append(prev_node_name)

        if node_name in input_to_layer:
            # This node is an Input of a nested model; wire it to the node
            # recorded as feeding that model.
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g