From e212795fab7f72f52cc0dc17eb082e83b04c01b8 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Tue, 31 Jan 2023 21:08:01 +0100 Subject: [PATCH] Add validation --- main.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index e53c97b..20ce939 100644 --- a/main.py +++ b/main.py @@ -37,10 +37,9 @@ def plot_two_images(a: np.ndarray, b: np.ndarray): plt.show() -def compare_with_anime_characters(source: np.ndarray, verbose=False) -> list[dict]: - dataset = load_data('data/images') +def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]: all_metrics = [] - for anime_image, label in zip(dataset['values'], dataset['labels']): + for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']): current_result = { 'name': label, 'metrics': {} @@ -69,9 +68,32 @@ def transfer_to_anime(img: np.ndarray): return algo.cartoonize(img).astype(np.uint8) +def validate(test_set, anime_faces_set, metric='correlation'): + all_entries = len(test_set['values']) + correct = 0 + for test_image, test_label in zip(test_set['values'], test_set['labels']): + output = get_top_results(compare_with_anime_characters(test_image), metric)[0]['name'] + if output == test_label: + correct += 1 + + accuracy = correct / all_entries + print(f'Accuracy using {metric}: {accuracy * 100}%') + return accuracy + + if __name__ == '__main__': + anime_faces_set = load_data('data/images') + + # Uncomment for validation (takes a while) + # test_set = load_data('test_set') + # validate(test_set, anime_faces_set, 'structural-similarity') + # validate(test_set, anime_faces_set, 'euclidean-distance') + # validate(test_set, anime_faces_set, 'chi-square') + # validate(test_set, anime_faces_set, 'correlation') + # validate(test_set, anime_faces_set, 'intersection') + # validate(test_set, anime_faces_set, 'bhattacharyya-distance') source = load_source('UAM-Andre.jpg') source_anime = transfer_to_anime(source) source_face_anime = find_and_crop_face(source_anime) - results = compare_with_anime_characters(source_face_anime) + results = compare_with_anime_characters(source_face_anime, anime_faces_set) print(get_top_results(results, count=5))