diff --git a/load_test_data.py b/load_test_data.py index b5a635b..300cd25 100644 --- a/load_test_data.py +++ b/load_test_data.py @@ -29,7 +29,6 @@ def load_data(input_dir, newSize=(64,64)): p = image_path / n img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray - img = img / 255 # type: ignore #normalizacja test_img.append(img) labels.append(n) diff --git a/main.py b/main.py index af5ad1b..2513b12 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,11 @@ +import argparse import sys import cv2 import numpy as np import matplotlib.pyplot as plt from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance +from load_test_data import load_data # Allows imports from the style transfer submodule sys.path.append('DCT-Net') @@ -36,16 +38,30 @@ def plot_two_images(a: np.ndarray, b: np.ndarray): plt.show() -def compare_with_anime_characters(data: np.ndarray) -> int: - # Example will be one face from anime dataset - example = load_source('data/images/Aisaka, Taiga.jpg') - # TODO: Use a different face detection method for anime images - example_face = find_and_crop_face(example, 'haarcascades/lbpcascade_animeface.xml') - data_rescaled = cv2.resize(data, example_face.shape[:2]) - plot_two_images(example_face, data_rescaled) - print(histogram_comparison(data_rescaled, example_face)) - print(f'structural-similarity: {structural_similarity_index(data_rescaled, example_face)}') - print(f'euclidean-distance: {euclidean_distance(data_rescaled, example_face)}') +def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]: + all_metrics = [] + for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']): + current_result = { + 'name': label, + 'metrics': {} + } + # TODO: Use a different face detection method for anime images + # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml') + anime_face = anime_image + source_rescaled = cv2.resize(source, anime_face.shape[:2]) + if verbose: + plot_two_images(anime_face, source_rescaled) + current_result['metrics'] = histogram_comparison(source_rescaled, anime_face) + current_result['metrics']['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face) + current_result['metrics']['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face) + all_metrics.append(current_result) + + return all_metrics + + +def get_top_results(all_metrics: list[dict], metric='correlation', count=1): + all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric]) + return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count])) def transfer_to_anime(img: np.ndarray): @@ -53,8 +69,38 @@ def transfer_to_anime(img: np.ndarray): return algo.cartoonize(img).astype(np.uint8) +def validate(test_set, anime_faces_set, metric='correlation'): + all_entries = len(test_set['values']) + correct = 0 + for test_image, test_label in zip(test_set['values'], test_set['labels']): + output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name'] + if output == test_label: + correct += 1 + + accuracy = correct / all_entries + print(f'Accuracy using {metric}: {accuracy * 100}%') + return accuracy + + if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', '--validate_only') + args = parser.parse_args() + anime_faces_set = load_data('data/images') + + if args.validate_only: + print('Validating') + test_set = load_data('test_set') + validate(test_set, anime_faces_set, 'structural-similarity') + validate(test_set, anime_faces_set, 'euclidean-distance') + validate(test_set, anime_faces_set, 'chi-square') + validate(test_set, anime_faces_set, 'correlation') + validate(test_set, anime_faces_set, 'intersection') + validate(test_set, anime_faces_set, 'bhattacharyya-distance') + exit(0) + source = load_source('UAM-Andre.jpg') source_anime = transfer_to_anime(source) source_face_anime = find_and_crop_face(source_anime) - print(compare_with_anime_characters(source_face_anime)) + results = compare_with_anime_characters(source_face_anime, anime_faces_set) + print(get_top_results(results, count=5))