Plot results #6
20
main.py
20
main.py
@ -2,10 +2,10 @@ import argparse
|
|||||||
import sys
|
import sys
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
|
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
|
||||||
from load_test_data import load_data
|
from load_test_data import load_data
|
||||||
|
from plots import plot_two_images, plot_results
|
||||||
|
|
||||||
# Allows imports from the style transfer submodule
|
# Allows imports from the style transfer submodule
|
||||||
sys.path.append('DCT-Net')
|
sys.path.append('DCT-Net')
|
||||||
@ -27,17 +27,6 @@ def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcasca
|
|||||||
return face
|
return face
|
||||||
|
|
||||||
|
|
||||||
def plot_two_images(a: np.ndarray, b: np.ndarray):
|
|
||||||
plt.figure(figsize=[10, 10])
|
|
||||||
plt.subplot(121)
|
|
||||||
plt.imshow(a)
|
|
||||||
plt.title("A")
|
|
||||||
plt.subplot(122)
|
|
||||||
plt.imshow(b)
|
|
||||||
plt.title("B")
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
|
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
|
||||||
all_metrics = []
|
all_metrics = []
|
||||||
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
|
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
|
||||||
@ -86,7 +75,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-v', '--validate_only')
|
parser.add_argument('-v', '--validate_only')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
anime_faces_set = load_data('data/images')
|
anime_faces_set = load_data('data/images', (256, 256))
|
||||||
|
|
||||||
if args.validate_only:
|
if args.validate_only:
|
||||||
print('Validating')
|
print('Validating')
|
||||||
@ -103,4 +92,7 @@ if __name__ == '__main__':
|
|||||||
source_anime = transfer_to_anime(source)
|
source_anime = transfer_to_anime(source)
|
||||||
source_face_anime = find_and_crop_face(source_anime)
|
source_face_anime = find_and_crop_face(source_anime)
|
||||||
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
|
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
|
||||||
print(get_top_results(results, count=5))
|
method = 'structural-similarity'
|
||||||
|
top_results = get_top_results(results, count=4, metric=method)
|
||||||
|
print(top_results)
|
||||||
|
plot_results(source, source_anime, top_results, anime_faces_set, method)
|
||||||
|
45
plots.py
Normal file
45
plots.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt, gridspec
|
||||||
|
|
||||||
|
|
||||||
|
def plot_two_images(a: np.ndarray, b: np.ndarray):
|
||||||
|
plt.figure(figsize=[10, 10])
|
||||||
|
plt.subplot(121)
|
||||||
|
plt.imshow(a)
|
||||||
|
plt.title("A")
|
||||||
|
plt.subplot(122)
|
||||||
|
plt.imshow(b)
|
||||||
|
plt.title("B")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results(source, source_anime, results, anime_faces_set, method):
|
||||||
|
cols = len(results)
|
||||||
|
plt.figure(figsize=[3*cols, 7])
|
||||||
|
gs = gridspec.GridSpec(2, cols)
|
||||||
|
|
||||||
|
plt.subplot(gs[0, cols // 2 - 1])
|
||||||
|
plt.imshow(source)
|
||||||
|
plt.title('Your image')
|
||||||
|
plt.axis('off')
|
||||||
|
|
||||||
|
plt.subplot(gs[0, cols // 2])
|
||||||
|
plt.imshow(source_anime)
|
||||||
|
plt.title('Your image in Anime style')
|
||||||
|
plt.axis('off')
|
||||||
|
|
||||||
|
plt.figtext(0.5, 0.525, "Predictions", ha="center", va="top", fontsize=16)
|
||||||
|
|
||||||
|
for idx, prediction in enumerate(results):
|
||||||
|
result_img = anime_faces_set['values'][anime_faces_set['labels'].index(prediction['name'])]
|
||||||
|
plt.subplot(gs[1, idx])
|
||||||
|
plt.imshow(result_img, interpolation='bicubic')
|
||||||
|
plt.title(f'{prediction["name"].partition(".")[0]}, score={str(round(prediction["score"], 4))}')
|
||||||
|
plt.axis('off')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
plt.figtext(0.5, 0.01, f"Metric: {method}", ha="center", va="bottom", fontsize=12)
|
||||||
|
plt.subplots_adjust(wspace=0, hspace=0.1)
|
||||||
|
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user