diff --git a/main.py b/main.py index 20ce939..2513b12 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import argparse import sys import cv2 import numpy as np @@ -72,7 +73,7 @@ def validate(test_set, anime_faces_set, metric='correlation'): all_entries = len(test_set['values']) correct = 0 for test_image, test_label in zip(test_set['values'], test_set['labels']): - output = get_top_results(compare_with_anime_characters(test_image), metric)[0]['name'] + output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name'] if output == test_label: correct += 1 @@ -82,16 +83,22 @@ def validate(test_set, anime_faces_set, metric='correlation'): if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-v', '--validate_only') + args = parser.parse_args() anime_faces_set = load_data('data/images') - # Uncomment for validation (takes a while) - # test_set = load_data('test_set') - # validate(test_set, anime_faces_set, 'structural-similarity') - # validate(test_set, anime_faces_set, 'euclidean-distance') - # validate(test_set, anime_faces_set, 'chi-square') - # validate(test_set, anime_faces_set, 'correlation') - # validate(test_set, anime_faces_set, 'intersection') - # validate(test_set, anime_faces_set, 'bhattacharyya-distance') + if args.validate_only: + print('Validating') + test_set = load_data('test_set') + validate(test_set, anime_faces_set, 'structural-similarity') + validate(test_set, anime_faces_set, 'euclidean-distance') + validate(test_set, anime_faces_set, 'chi-square') + validate(test_set, anime_faces_set, 'correlation') + validate(test_set, anime_faces_set, 'intersection') + validate(test_set, anime_faces_set, 'bhattacharyya-distance') + exit(0) + source = load_source('UAM-Andre.jpg') source_anime = transfer_to_anime(source) source_face_anime = find_and_crop_face(source_anime)