46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
|
import numpy as np
|
||
|
from matplotlib import pyplot as plt, gridspec
|
||
|
|
||
|
|
||
|
def plot_two_images(a: np.ndarray, b: np.ndarray):
    """Show two images next to each other, titled "A" (left) and "B" (right).

    Opens a new 10x10-inch pyplot figure, draws ``a`` and ``b`` into a
    1x2 grid of subplots and calls ``plt.show()``.
    """
    plt.figure(figsize=[10, 10])
    panels = ((a, "A"), (b, "B"))
    # subplot(1, 2, k) addresses the same cell as the three-digit form 12k.
    for slot, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(image)
        plt.title(caption)
    plt.show()
|
||
|
|
||
|
|
||
|
def plot_results(source, source_anime, results, anime_faces_set, method):
|
||
|
cols = len(results)
|
||
|
plt.figure(figsize=[3*cols, 7])
|
||
|
gs = gridspec.GridSpec(2, cols)
|
||
|
|
||
|
plt.subplot(gs[0, cols // 2 - 1])
|
||
|
plt.imshow(source)
|
||
|
plt.title('Your image')
|
||
|
plt.axis('off')
|
||
|
|
||
|
plt.subplot(gs[0, cols // 2])
|
||
|
plt.imshow(source_anime)
|
||
|
plt.title('Your image in Anime style')
|
||
|
plt.axis('off')
|
||
|
|
||
|
plt.figtext(0.5, 0.525, "Predictions", ha="center", va="top", fontsize=16)
|
||
|
|
||
|
for idx, prediction in enumerate(results):
|
||
|
result_img = anime_faces_set['values'][anime_faces_set['labels'].index(prediction['name'])]
|
||
|
plt.subplot(gs[1, idx])
|
||
|
plt.imshow(result_img, interpolation='bicubic')
|
||
|
plt.title(f'{prediction["name"].partition(".")[0]}, score={str(round(prediction["score"], 4))}')
|
||
|
plt.axis('off')
|
||
|
|
||
|
plt.tight_layout()
|
||
|
|
||
|
plt.figtext(0.5, 0.01, f"Metric: {method}", ha="center", va="bottom", fontsize=12)
|
||
|
plt.subplots_adjust(wspace=0, hspace=0.1)
|
||
|
|
||
|
plt.show()
|