import numpy as np
from matplotlib import pyplot as plt, gridspec


def plot_two_images(
    a: np.ndarray,
    b: np.ndarray,
    *,
    title_a: str = "A",
    title_b: str = "B",
) -> None:
    """Display two images side by side in a new 10x10-inch figure.

    Args:
        a: Image drawn in the left panel via ``plt.imshow``.
        b: Image drawn in the right panel via ``plt.imshow``.
        title_a: Title of the left panel (defaults to ``"A"``).
        title_b: Title of the right panel (defaults to ``"B"``).

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=[10, 10])
    # One row, two columns: left panel for ``a``, right panel for ``b``.
    for position, (image, title) in enumerate(((a, title_a), (b, title_b)), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
    plt.show()


def plot_results(source, source_anime, results, anime_faces_set, method):
    """Show the query image, its anime-style version and the matched faces.

    Layout: the top row holds ``source`` and ``source_anime`` side by side,
    centred horizontally. The bottom row holds one panel per entry of
    ``results``, each titled with the match name (text before the first
    ``"."``) and its score rounded to 4 decimals. The metric name ``method``
    is printed as a footer. The figure is shown with ``plt.show()``.

    Args:
        source: Image drawn in the "Your image" panel via ``plt.imshow``.
        source_anime: Image drawn in the "Your image in Anime style" panel.
        results: Sequence of prediction mappings, each with a ``"name"`` key
            (looked up in ``anime_faces_set["labels"]``) and a numeric
            ``"score"`` key. Presumably ordered best-first by the caller.
        anime_faces_set: Mapping with ``"labels"`` (a list-like supporting
            ``.index``) and ``"values"`` (images indexed in the same order
            as ``"labels"``).
        method: Name of the similarity metric, shown in the footer.

    Raises:
        ValueError: when a prediction's name is missing from
            ``anime_faces_set["labels"]`` (raised by ``list.index`` if
            labels is a list).
    """
    cols = len(results)
    # Always reserve at least two panel slots. Previously, 0 results made
    # GridSpec(2, 0) raise, and 1 result made ``cols // 2 - 1 == -1`` wrap
    # around, drawing both top images into the same cell.
    slots = max(cols, 2)
    plt.figure(figsize=[3 * slots, 7])
    # Each panel spans two half-width grid columns. This lets the top pair be
    # centred exactly even when the number of predictions is odd.
    gs = gridspec.GridSpec(2, 2 * slots)

    plt.subplot(gs[0, slots - 2:slots])
    plt.imshow(source)
    plt.title('Your image')
    plt.axis('off')

    plt.subplot(gs[0, slots:slots + 2])
    plt.imshow(source_anime)
    plt.title('Your image in Anime style')
    plt.axis('off')

    plt.figtext(0.5, 0.525, "Predictions", ha="center", va="top", fontsize=16)

    # Non-zero only when there are fewer predictions than slots (cols == 1);
    # shifts the prediction row so it is centred as well.
    offset = slots - cols
    for idx, prediction in enumerate(results):
        result_img = anime_faces_set['values'][anime_faces_set['labels'].index(prediction['name'])]
        start = offset + 2 * idx
        plt.subplot(gs[1, start:start + 2])
        plt.imshow(result_img, interpolation='bicubic')
        plt.title(f'{prediction["name"].partition(".")[0]}, score={str(round(prediction["score"], 4))}')
        plt.axis('off')

    plt.tight_layout()

    # Added after tight_layout so the footer does not influence its spacing;
    # subplots_adjust then overrides the inter-panel gaps only.
    plt.figtext(0.5, 0.01, f"Metric: {method}", ha="center", va="bottom", fontsize=12)
    plt.subplots_adjust(wspace=0, hspace=0.1)

    plt.show()