# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for graph plugin."""

from tensorboard.compat.proto import graph_pb2


def _prefixed_op_name(prefix, op_name):
    return "%s/%s" % (prefix, op_name)


def _prefixed_func_name(prefix, func_name):
    """Returns function name prefixed with `prefix`.

    For function libraries, which are often created out of autographed Python
    function, are factored out in the graph vis. They are grouped under a
    function name which often has a shape of
    `__inference_[py_func_name]_[numeric_suffix]`.

    While it does not have some unique information about which graph it is from,
    creating another wrapping structure with graph prefix and "/" is less than
    ideal so we join the prefix and func_name using underscore.

    TODO(stephanwlee): add business logic to strip "__inference_" for more user
    friendlier name
    """
    return "%s_%s" % (prefix, func_name)


# Ops that call a tf.function through their "f" (function) attribute.
_FUNCTION_CALL_OPS = ("PartitionedCall", "StatefulPartitionedCall")


def _prefixed_input_name(prefix, input_name):
    """Returns NodeDef input `input_name` prefixed with `prefix`.

    Control-dependency inputs have the form "^node". The "^" marker must stay
    in front of the prefixed node name ("^prefix/node"). Otherwise the input
    would refer to a nonexistent node named "prefix/^node".
    """
    if input_name.startswith("^"):
        return "^" + _prefixed_op_name(prefix, input_name[1:])
    return _prefixed_op_name(prefix, input_name)


def _remap_function_references(prefix, node, library_func_names):
    """Rewrites, in place, the function names that `node` refers to.

    Keeps `node` consistent with the library functions that
    `_add_with_prepended_names` renames via `_prefixed_func_name`.

    Args:
      prefix: Prefix passed to `_prefixed_func_name`.
      node: NodeDef to update in place.
      library_func_names: Set of (unprefixed) function names defined in the
        library of the graph that `node` comes from.
    """
    # Remap the tf.function name in (Stateful)PartitionedCall. 'f' is short
    # for function. Test membership first: indexing a protobuf map field
    # inserts a default entry when the key is missing.
    if node.op in _FUNCTION_CALL_OPS and "f" in node.attr:
        func = node.attr["f"].func
        func.name = _prefixed_func_name(prefix, func.name)
    # A node can also call a library function directly by using the
    # function's name as its op type.
    if node.op in library_func_names:
        node.op = _prefixed_func_name(prefix, node.op)


def _add_with_prepended_names(prefix, graph_to_add, destination_graph):
    """Copies `graph_to_add` into `destination_graph` under `prefix`.

    Node names and inputs are prefixed with "prefix/" (`_prefixed_op_name`).
    Library function names and gradient entries are prefixed with "prefix_"
    (`_prefixed_func_name`). References to the renamed functions, from graph
    nodes and from nodes inside function bodies, are updated to match.

    Args:
      prefix: Prefix that distinguishes `graph_to_add` from other graphs
        merged into `destination_graph`.
      graph_to_add: GraphDef to copy; it is not modified.
      destination_graph: GraphDef that receives the renamed copies; mutated.
    """
    library = graph_to_add.library
    library_func_names = {func.signature.name for func in library.function}

    for node in graph_to_add.node:
        new_node = destination_graph.node.add()
        new_node.CopyFrom(node)
        new_node.name = _prefixed_op_name(prefix, node.name)
        new_node.input[:] = [
            _prefixed_input_name(prefix, name) for name in node.input
        ]
        _remap_function_references(prefix, new_node, library_func_names)

    for func in library.function:
        new_func = destination_graph.library.function.add()
        new_func.CopyFrom(func)
        new_func.signature.name = _prefixed_func_name(
            prefix, new_func.signature.name
        )
        # Nested calls between library functions must follow the renaming
        # too. Body node names and inputs are function-local, so they are
        # left untouched.
        for body_node in new_func.node_def:
            _remap_function_references(prefix, body_node, library_func_names)

    for gradient in library.gradient:
        new_gradient = destination_graph.library.gradient.add()
        new_gradient.CopyFrom(gradient)
        new_gradient.function_name = _prefixed_func_name(
            prefix,
            new_gradient.function_name,
        )
        new_gradient.gradient_func = _prefixed_func_name(
            prefix,
            new_gradient.gradient_func,
        )


def merge_graph_defs(graph_defs):
    """Merges GraphDefs by adding a unique prefix, `graph_{ind}`, to names.

    All GraphDefs are expected to be TensorBoard's.

    When collecting graphs using the `tf.summary.trace` API, node names are not
    guaranteed to be unique. If non-unique names are not taken into account,
    the graph visualization can show them as a single node. This gives an
    inaccurate depiction of the flow of the graph (e.g., if there are
    A -> B -> C and D -> B -> E, you may see {A, D} -> B -> E). To prevent such
    graphs, we checked for uniqueness while merging, but that resulted in
    https://github.com/tensorflow/tensorboard/issues/1929.

    To remedy these issues, we simply "apply a name scope" to each graph by
    prefixing it with a distinct name (`graph_1`, `graph_2`, ...). This creates
    an unconnected group of graphs.

    If only one graph def is passed, the original graph_def is returned. If no
    graph defs are passed, an empty GraphDef is returned.

    Args:
      graph_defs: Iterable (list, tuple, generator, ...) of TensorBoard
        GraphDefs to merge. It is consumed exactly once.

    Returns:
      TensorBoard GraphDef that merges all graph_defs with unique prefixes.

    Raises:
      ValueError: If the GraphDefs' producer versions mismatch.
    """
    # Materialize once so that generators also support len() and indexing.
    graph_defs = list(graph_defs)
    if not graph_defs:
        return graph_pb2.GraphDef()
    if len(graph_defs) == 1:
        return graph_defs[0]

    dst_graph_def = graph_pb2.GraphDef()

    # The first graph's versions seed the result. The loop below then requires
    # every graph (the first included) to share that producer version.
    if graph_defs[0].versions.producer:
        dst_graph_def.versions.CopyFrom(graph_defs[0].versions)

    for index, graph_def in enumerate(graph_defs, start=1):
        if dst_graph_def.versions.producer != graph_def.versions.producer:
            raise ValueError("Cannot combine GraphDefs of different versions.")

        _add_with_prepended_names(
            "graph_%d" % index,
            graph_def,
            dst_graph_def,
        )

    return dst_graph_def