# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard metrics plugin."""


import collections
import imghdr
import json

from werkzeug import wrappers

from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.histogram import metadata as histogram_metadata
from tensorboard.plugins.image import metadata as image_metadata
from tensorboard.plugins.metrics import metadata
from tensorboard.plugins.scalar import metadata as scalar_metadata


# Maps the image type names reported by `imghdr.what` to the HTTP content type
# used when serving that image's bytes.
_IMGHDR_TO_MIMETYPE = {
    "bmp": "image/bmp",
    "gif": "image/gif",
    "jpeg": "image/jpeg",
    "png": "image/png",
    "svg": "image/svg+xml",
}

# Content type used when the detected image type is not a key of
# `_IMGHDR_TO_MIMETYPE` (including when no type could be detected).
_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"

# Plugins whose time series requests are rejected with "Missing run" unless a
# string `run` is supplied.
_SINGLE_RUN_PLUGINS = frozenset(
    [histogram_metadata.PLUGIN_NAME, image_metadata.PLUGIN_NAME]
)

# Plugins whose time series requests are rejected with "Missing sample" unless
# an integer `sample` is supplied.
_SAMPLED_PLUGINS = frozenset([image_metadata.PLUGIN_NAME])


def _get_tag_description_info(mapping):
    """Gets maps from tags to descriptions, and descriptions to runs.

    Time series without a description (empty string or `None`) are skipped, so
    tags that never carry a description do not appear in the result.

    Args:
        mapping: a nested map `d` such that `d[run][tag]` is a time series
          produced by DataProvider's `list_*` methods.

    Returns:
        A tuple containing
            tag_to_descriptions: A map from tag strings to a set of description
                strings.
            description_to_runs: A map from description strings to a set of run
                strings.
    """
    tag_to_descriptions = collections.defaultdict(set)
    description_to_runs = collections.defaultdict(set)
    for run, tag_to_content in mapping.items():
        for tag, metadatum in tag_to_content.items():
            description = metadatum.description
            # Truthiness covers both "" and None; `len(None)` would raise.
            if not description:
                continue
            tag_to_descriptions[tag].add(description)
            description_to_runs[description].add(run)

    return tag_to_descriptions, description_to_runs


def _build_combined_description(descriptions, description_to_runs):
    """Merges several descriptions of one tag into a single markdown string.

    A tag can carry different descriptions in different runs. Each description
    is emitted under a sub-heading naming the runs that use it, and the whole
    text is placed under a "Multiple descriptions" heading.

    Args:
        descriptions: A list of description strings.
        description_to_runs: A map from description strings to a set of run
            strings.

    Returns:
        The combined description string.
    """

    def _with_run_header(text):
        # Sort the run names so the heading is deterministic.
        run_names = sorted(description_to_runs[text])
        noun = "run" if len(run_names) <= 1 else "runs"
        return "## For {}: {}\n{}".format(noun, ", ".join(run_names), text)

    sections = [_with_run_header(text) for text in descriptions]
    return "# Multiple descriptions\n" + "\n".join(sections)


def _get_tag_to_description(mapping):
    """Builds the per-tag description HTML shown to clients.

    Tags with exactly one distinct description use it as-is; otherwise the
    descriptions are merged with `_build_combined_description`. The resulting
    markdown is converted with `plugin_util.markdown_to_safe_html`.

    Args:
        mapping: a nested map `d` such that `d[run][tag]` is a time series
          produced by DataProvider's `list_*` methods.

    Returns:
        A map from tag strings to description HTML strings. E.g.
        {
            "loss": "<h1>Multiple descriptions</h1><h2>For runs: test, train
            </h2><p>...</p>",
            "loss2": "<p>The lossy details</p>",
        }
    """
    descriptions_by_tag, runs_by_description = _get_tag_description_info(
        mapping
    )

    def _render(description_set):
        # Sort so that the combined text does not depend on set ordering.
        ordered = sorted(description_set)
        if len(ordered) != 1:
            markdown = _build_combined_description(ordered, runs_by_description)
        else:
            markdown = ordered[0]
        return plugin_util.markdown_to_safe_html(markdown)

    return {
        tag: _render(description_set)
        for tag, description_set in descriptions_by_tag.items()
    }


def _get_run_tag_info(mapping):
    """Lists, for every run, the names of its tags in sorted order.

    Args:
        mapping: a nested map `d` such that `d[run][tag]` is a time series
          produced by DataProvider's `list_*` methods.

    Returns:
        A map from run strings to a list of tag strings. E.g.
            {"loss001a": ["actor/loss", "critic/loss"], ...}
    """
    run_to_tags = {}
    for run_name, tag_to_series in mapping.items():
        # Iterating the inner map yields its tag names.
        tag_names = list(tag_to_series)
        tag_names.sort()
        run_to_tags[run_name] = tag_names
    return run_to_tags


def _format_basic_mapping(mapping):
    """Shapes a scalar or histogram listing into the client-facing tag payload.

    Args:
        mapping: a nested map `d` such that `d[run][tag]` is a time series
          produced by DataProvider's `list_*` methods.

    Returns:
        A dict with the following fields:
            runTagInfo: the return type of `_get_run_tag_info`
            tagDescriptions: the return type of `_get_tag_to_description`
    """
    formatted = {}
    formatted["runTagInfo"] = _get_run_tag_info(mapping)
    formatted["tagDescriptions"] = _get_tag_to_description(mapping)
    return formatted


def _format_image_blob_sequence_datum(sorted_datum_list, sample):
    """Formats image metadata from a list of BlobSequenceDatum's for clients.

    This expects that frontend clients need to access images based on the
    run+tag+sample.

    Args:
        sorted_datum_list: a list of DataProvider's `BlobSequenceDatum`, sorted by
            step. This can be produced via DataProvider's `read_blob_sequences`.
        sample: zero-indexed integer for the requested sample.

    Returns:
        A list of `ImageStepDatum` (see http_api.md). Steps that do not contain
        the requested sample are omitted; a negative `sample` yields an empty
        list.
    """
    # A negative sample would either land on the width/height entries or wrap
    # around to the end of `values`, returning a blob that is not the requested
    # image sample. No step can contain such a sample.
    if sample < 0:
        return []

    # For images, ignore the first 2 items of a BlobSequenceDatum's values, which
    # correspond to width, height.
    index = sample + 2
    return [
        {
            "step": datum.step,
            "wallTime": datum.wall_time,
            "imageId": datum.values[index].blob_key,
        }
        for datum in sorted_datum_list
        # Skip steps that logged fewer samples than requested.
        if len(datum.values) > index
    ]


def _get_tag_run_image_info(mapping):
    """Groups per-run image sampling info under each tag name.

    Args:
        mapping: the result of DataProvider's `list_blob_sequences`.

    Returns:
        A nested map from tag strings to run strings to image info, where image
        info is an object of form {"maxSamplesPerStep": num}. For example,
        {
            "reshaped": {
                "test": {"maxSamplesPerStep": 1},
                "train": {"maxSamplesPerStep": 1}
            },
            "convolved": {"test": {"maxSamplesPerStep": 50}},
        }
    """
    info_by_tag = {}
    for run_name, tag_to_series in mapping.items():
        for tag_name, series_metadata in tag_to_series.items():
            # Two entries of each blob sequence hold width and height rather
            # than image samples.
            sample_count = series_metadata.max_length - 2
            run_to_info = info_by_tag.setdefault(tag_name, {})
            run_to_info[run_name] = {"maxSamplesPerStep": sample_count}
    return info_by_tag


def _format_image_mapping(mapping):
    """Shapes an image listing into the client-facing tag payload.

    Args:
        mapping: the result of DataProvider's `list_blob_sequences`.

    Returns:
        A dict with the following fields:
            tagDescriptions: the return type of `_get_tag_to_description`
            tagRunSampledInfo: the return type of `_get_tag_run_image_info`
    """
    formatted = {}
    formatted["tagDescriptions"] = _get_tag_to_description(mapping)
    formatted["tagRunSampledInfo"] = _get_tag_run_image_info(mapping)
    return formatted


class MetricsPlugin(base_plugin.TBPlugin):
    """Metrics Plugin for TensorBoard.

    Serves tag listings, time series and raw image bytes for the scalar,
    histogram and image data plugins through the routes returned by
    `get_plugin_apps`.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates MetricsPlugin.

        Args:
            context: A base_plugin.TBContext instance. MetricsLoader checks that
                it contains a valid `data_provider`.
        """
        self._data_provider = context.data_provider

        # For histograms, use a round number + 1 since sampling includes both start
        # and end steps, so N+1 samples corresponds to dividing the step sequence
        # into N intervals.
        sampling_hints = context.sampling_hints or {}
        self._plugin_downsampling = {
            "scalars": sampling_hints.get(scalar_metadata.PLUGIN_NAME, 1000),
            "histograms": sampling_hints.get(
                histogram_metadata.PLUGIN_NAME, 51
            ),
            "images": sampling_hints.get(image_metadata.PLUGIN_NAME, 10),
        }
        self._scalar_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="scalar time series",
            latest_known_version=0,
        )
        self._histogram_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="histogram time series",
            latest_known_version=0,
        )
        self._image_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="image time series",
            latest_known_version=0,
        )

    def frontend_metadata(self):
        """Returns the frontend metadata for the "Time Series" tab."""
        return base_plugin.FrontendMetadata(
            is_ng_component=True, tab_name="Time Series"
        )

    def get_plugin_apps(self):
        """Returns a map from route paths to their request handlers."""
        return {
            "/tags": self._serve_tags,
            "/timeSeries": self._serve_time_series,
            "/imageData": self._serve_image_data,
        }

    def data_plugin_names(self):
        """Returns the names of the data plugins this plugin reads from."""
        return (
            scalar_metadata.PLUGIN_NAME,
            histogram_metadata.PLUGIN_NAME,
            image_metadata.PLUGIN_NAME,
        )

    def is_active(self):
        return False  # 'data_plugin_names' suffices.

    @wrappers.Request.application
    def _serve_tags(self, request):
        """Serves the tag metadata of the request's experiment as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        index = self._tags_impl(ctx, experiment=experiment)
        return http_util.Respond(request, index, "application/json")

    def _tags_impl(self, ctx, experiment=None):
        """Returns tag metadata for a given experiment's logged metrics.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: optional string ID of the request's experiment.

        Returns:
            A nested dict 'd' with keys in ("scalars", "histograms", "images")
                and values being the return type of _format_*mapping.
        """
        scalar_mapping = self._data_provider.list_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
        )
        # Treat a missing listing as empty, as for histograms and images below.
        if scalar_mapping is None:
            scalar_mapping = {}
        scalar_mapping = self._filter_by_version(
            scalar_mapping,
            scalar_metadata.parse_plugin_metadata,
            self._scalar_version_checker,
        )

        histogram_mapping = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
        )
        if histogram_mapping is None:
            histogram_mapping = {}
        histogram_mapping = self._filter_by_version(
            histogram_mapping,
            histogram_metadata.parse_plugin_metadata,
            self._histogram_version_checker,
        )

        image_mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
        )
        if image_mapping is None:
            image_mapping = {}
        image_mapping = self._filter_by_version(
            image_mapping,
            image_metadata.parse_plugin_metadata,
            self._image_version_checker,
        )

        return {
            "scalars": _format_basic_mapping(scalar_mapping),
            "histograms": _format_basic_mapping(histogram_mapping),
            "images": _format_image_mapping(image_mapping),
        }

    def _filter_by_version(self, mapping, parse_metadata, version_checker):
        """Filter `DataProvider.list_*` output by summary metadata version.

        Args:
            mapping: a nested map `d` such that `d[run][tag]` is a time series
                produced by DataProvider's `list_*` methods.
            parse_metadata: callable that parses a time series'
                `plugin_content` into an object with a `version` attribute.
            version_checker: object whose `ok(version, run, tag)` method
                decides whether a time series is kept.

        Returns:
            A nested map with the same runs as `mapping` (a run stays present
            even when all of its tags are dropped) containing only the time
            series accepted by `version_checker`.
        """
        result = {run: {} for run in mapping}
        for run, tag_to_content in mapping.items():
            for tag, metadatum in tag_to_content.items():
                md = parse_metadata(metadatum.plugin_content)
                if not version_checker.ok(md.version, run, tag):
                    continue
                result[run][tag] = metadatum
        return result

    @wrappers.Request.application
    def _serve_time_series(self, request):
        """Serves time series for the JSON-encoded 'requests' field.

        The field is read from the form body for POST requests and from the
        query string otherwise.

        Raises:
            errors.InvalidArgumentError: if 'requests' is missing, is not valid
                JSON, or does not decode to a list of objects.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        if request.method == "POST":
            series_requests_string = request.form.get("requests")
        else:
            series_requests_string = request.args.get("requests")
        if not series_requests_string:
            raise errors.InvalidArgumentError("Missing 'requests' field")
        try:
            series_requests = json.loads(series_requests_string)
        except ValueError as e:
            raise errors.InvalidArgumentError(
                "Unable to parse 'requests' as JSON"
            ) from e
        # Each entry is read with `dict.get` downstream; reject other shapes
        # here so that malformed client input is reported as a client error
        # rather than failing with an AttributeError/TypeError.
        if not isinstance(series_requests, list) or not all(
            isinstance(series_request, dict)
            for series_request in series_requests
        ):
            raise errors.InvalidArgumentError(
                "'requests' must be a JSON list of objects"
            )

        response = self._time_series_impl(ctx, experiment, series_requests)
        return http_util.Respond(request, response, "application/json")

    def _time_series_impl(self, ctx, experiment, series_requests):
        """Constructs a list of responses from a list of series requests.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: string ID of the request's experiment.
            series_requests: a list of `TimeSeriesRequest` dicts (see http_api.md).

        Returns:
            A list of `TimeSeriesResponse` dicts (see http_api.md).
        """
        return [
            self._get_time_series(ctx, experiment, request)
            for request in series_requests
        ]

    def _create_base_response(self, series_request):
        """Builds the response fields that echo back the request.

        Args:
            series_request: a `TimeSeriesRequest` (see http_api.md).

        Returns:
            A dict with "plugin" and "tag", plus "run" when the request's run is
            a string and "sample" when the request's sample is an integer.
        """
        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")
        sample = series_request.get("sample")
        response = {"plugin": plugin, "tag": tag}
        if isinstance(run, str):
            response["run"] = run
        if isinstance(sample, int):
            response["sample"] = sample

        return response

    def _get_invalid_request_error(self, series_request):
        """Validates a series request.

        Args:
            series_request: a `TimeSeriesRequest` (see http_api.md).

        Returns:
            A string describing the first problem found, or None when the
            request is valid.
        """
        tag = series_request.get("tag")
        plugin = series_request.get("plugin")
        run = series_request.get("run")
        sample = series_request.get("sample")

        if not isinstance(tag, str):
            return "Missing tag"

        if plugin not in (
            scalar_metadata.PLUGIN_NAME,
            histogram_metadata.PLUGIN_NAME,
            image_metadata.PLUGIN_NAME,
        ):
            return "Invalid plugin"

        if plugin in _SINGLE_RUN_PLUGINS and not isinstance(run, str):
            return "Missing run"

        if plugin in _SAMPLED_PLUGINS and not isinstance(sample, int):
            return "Missing sample"

        return None

    def _get_time_series(self, ctx, experiment, series_request):
        """Returns time series data for a given tag, plugin.

        Invalid requests are not raised; they produce a response carrying an
        "error" field instead of "runToSeries".

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: string ID of the request's experiment.
            series_request: a `TimeSeriesRequest` (see http_api.md).

        Returns:
            A `TimeSeriesResponse` dict (see http_api.md).
        """
        response = self._create_base_response(series_request)
        request_error = self._get_invalid_request_error(series_request)
        if request_error:
            response["error"] = request_error
            return response

        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")

        # Without a run, the read is not restricted to specific runs.
        runs = [run] if run else None
        run_to_series = None
        if plugin == scalar_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_scalar_series(
                ctx, experiment, tag, runs
            )
        elif plugin == histogram_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_histogram_series(
                ctx, experiment, tag, runs
            )
        elif plugin == image_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_image_series(
                ctx, experiment, tag, series_request.get("sample"), runs
            )

        response["runToSeries"] = run_to_series
        return response

    def _get_run_to_scalar_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-scalar-series dict for client consumption.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: a string experiment id.
            tag: string of the requested tag.
            runs: optional list of run names as strings.

        Returns:
            A map from string run names to `ScalarStepDatum` (see http_api.md).
        """
        mapping = self._data_provider.read_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["scalars"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            run_to_series[result_run] = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "value": datum.value,
                }
                for datum in tag_data[tag]
            ]

        return run_to_series

    def _format_histogram_datum_bins(self, datum):
        """Formats a histogram datum's bins for client consumption.

        Args:
            datum: a DataProvider's TensorDatum.

        Returns:
            A list of `HistogramBin`s (see http_api.md).
        """
        # Each row of the tensor is read as (min, max, count).
        return [
            {"min": x[0], "max": x[1], "count": x[2]}
            for x in datum.numpy.tolist()
        ]

    def _get_run_to_histogram_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-histogram-series dict for client consumption.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: a string experiment id.
            tag: string of the requested tag.
            runs: optional list of run names as strings.

        Returns:
            A map from string run names to `HistogramStepDatum` (see http_api.md).
        """
        mapping = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["histograms"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            run_to_series[result_run] = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "bins": self._format_histogram_datum_bins(datum),
                }
                for datum in tag_data[tag]
            ]

        return run_to_series

    def _get_run_to_image_series(self, ctx, experiment, tag, sample, runs):
        """Builds a run-to-image-series dict for client consumption.

        Runs for which no step contains the requested sample are omitted.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            experiment: a string experiment id.
            tag: string of the requested tag.
            sample: zero-indexed integer for the requested sample.
            runs: optional list of run names as strings.

        Returns:
            A `RunToSeries` dict (see http_api.md).
        """
        mapping = self._data_provider.read_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["images"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            series = _format_image_blob_sequence_datum(tag_data[tag], sample)
            if series:
                run_to_series[result_run] = series

        return run_to_series

    @wrappers.Request.application
    def _serve_image_data(self, request):
        """Serves an individual image.

        Raises:
            errors.InvalidArgumentError: if the 'imageId' query parameter is
                missing or empty.
        """
        ctx = plugin_util.context(request.environ)
        # Use `.get` so that an absent parameter reaches the check below rather
        # than raising a generic key error on lookup.
        blob_key = request.args.get("imageId")
        if not blob_key:
            raise errors.InvalidArgumentError("Missing 'imageId' field")

        (data, content_type) = self._image_data_impl(ctx, blob_key)
        return http_util.Respond(request, data, content_type)

    def _image_data_impl(self, ctx, blob_key):
        """Gets the image data for a blob key.

        Args:
            ctx: A `tensorboard.context.RequestContext` value.
            blob_key: a string identifier for a DataProvider blob.

        Returns:
            A tuple containing:
              data: a raw bytestring of the requested image's contents.
              content_type: a string HTTP content type.
        """
        data = self._data_provider.read_blob(ctx, blob_key=blob_key)
        # NOTE(review): `imghdr` is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594); this detection needs a replacement before then.
        image_type = imghdr.what(None, data)
        content_type = _IMGHDR_TO_MIMETYPE.get(
            image_type, _DEFAULT_IMAGE_MIMETYPE
        )
        return (data, content_type)