Validation using argument

This commit is contained in:
Marcin Kostrzewski 2023-01-31 21:13:30 +01:00
parent e212795fab
commit e6f4ea8361

25
main.py
View File

@ -1,3 +1,4 @@
import argparse
import sys import sys
import cv2 import cv2
import numpy as np import numpy as np
@ -72,7 +73,7 @@ def validate(test_set, anime_faces_set, metric='correlation'):
all_entries = len(test_set['values']) all_entries = len(test_set['values'])
correct = 0 correct = 0
for test_image, test_label in zip(test_set['values'], test_set['labels']): for test_image, test_label in zip(test_set['values'], test_set['labels']):
output = get_top_results(compare_with_anime_characters(test_image), metric)[0]['name'] output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name']
if output == test_label: if output == test_label:
correct += 1 correct += 1
@ -82,16 +83,22 @@ def validate(test_set, anime_faces_set, metric='correlation'):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only')
args = parser.parse_args()
anime_faces_set = load_data('data/images') anime_faces_set = load_data('data/images')
# Uncomment for validation (takes a while) if args.validate_only:
# test_set = load_data('test_set') print('Validating')
# validate(test_set, anime_faces_set, 'structural-similarity') test_set = load_data('test_set')
# validate(test_set, anime_faces_set, 'euclidean-distance') validate(test_set, anime_faces_set, 'structural-similarity')
# validate(test_set, anime_faces_set, 'chi-square') validate(test_set, anime_faces_set, 'euclidean-distance')
# validate(test_set, anime_faces_set, 'correlation') validate(test_set, anime_faces_set, 'chi-square')
# validate(test_set, anime_faces_set, 'intersection') validate(test_set, anime_faces_set, 'correlation')
# validate(test_set, anime_faces_set, 'bhattacharyya-distance') validate(test_set, anime_faces_set, 'intersection')
validate(test_set, anime_faces_set, 'bhattacharyya-distance')
exit(0)
source = load_source('UAM-Andre.jpg') source = load_source('UAM-Andre.jpg')
source_anime = transfer_to_anime(source) source_anime = transfer_to_anime(source)
source_face_anime = find_and_crop_face(source_anime) source_face_anime = find_and_crop_face(source_anime)