From 9c4d70a21b8c37057e03793fac7f2ac33afc5462 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Wed, 1 Feb 2023 13:47:51 +0100 Subject: [PATCH] Add validation for top-k results --- main.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 2513b12..48200f0 100644 --- a/main.py +++ b/main.py @@ -69,19 +69,28 @@ def transfer_to_anime(img: np.ndarray): return algo.cartoonize(img).astype(np.uint8) -def validate(test_set, anime_faces_set, metric='correlation'): +def validate(test_set, anime_faces_set, metric='correlation', top_n=1): all_entries = len(test_set['values']) correct = 0 for test_image, test_label in zip(test_set['values'], test_set['labels']): - output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name'] - if output == test_label: - correct += 1 + output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n) + if any(map(lambda single_result: single_result['name'] == test_label, output)): + correct += 1 accuracy = correct / all_entries print(f'Accuracy using {metric}: {accuracy * 100}%') return accuracy +def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1): + validate(test_set, anime_faces_set, 'structural-similarity', top_n) + validate(test_set, anime_faces_set, 'euclidean-distance', top_n) + validate(test_set, anime_faces_set, 'chi-square', top_n) + validate(test_set, anime_faces_set, 'correlation', top_n) + validate(test_set, anime_faces_set, 'intersection', top_n) + validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n) + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-v', '--validate_only') @@ -91,12 +100,12 @@ if __name__ == '__main__': if args.validate_only: print('Validating') test_set = load_data('test_set') - validate(test_set, anime_faces_set, 'structural-similarity') - validate(test_set, anime_faces_set, 'euclidean-distance') - validate(test_set, anime_faces_set, 'chi-square') - validate(test_set, anime_faces_set, 'correlation') - validate(test_set, anime_faces_set, 'intersection') - validate(test_set, anime_faces_set, 'bhattacharyya-distance') + print('Top 1 matches results:') + validate_all(test_set, anime_faces_set, 'structural-similarity', 1) + print('Top 3 matches results:') + validate_all(test_set, anime_faces_set, 'structural-similarity', 3) + print('Top 5 matches results:') + validate_all(test_set, anime_faces_set, 'structural-similarity', 5) exit(0) source = load_source('UAM-Andre.jpg')