Fixes #8
27
main.py
27
main.py
@ -59,12 +59,12 @@ def transfer_to_anime(img: np.ndarray):
|
|||||||
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
|
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
|
||||||
def validate(test_set, anime_faces_set, metric='correlation'):
|
def validate(test_set, anime_faces_set, metric='correlation', top_n=1):
|
||||||
all_entries = len(test_set['values'])
|
all_entries = len(test_set['values'])
|
||||||
correct = 0
|
correct = 0
|
||||||
for test_image, test_label in zip(test_set['values'], test_set['labels']):
|
for test_image, test_label in zip(test_set['values'], test_set['labels']):
|
||||||
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name']
|
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n)
|
||||||
if output == test_label:
|
if any(map(lambda single_result: single_result['name'] == test_label, output)):
|
||||||
correct += 1
|
correct += 1
|
||||||
|
|
||||||
accuracy = correct / all_entries
|
accuracy = correct / all_entries
|
||||||
@ -72,6 +72,15 @@ def validate(test_set, anime_faces_set, metric='correlation'):
|
|||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1):
|
||||||
|
validate(test_set, anime_faces_set, 'structural-similarity', top_n)
|
||||||
|
validate(test_set, anime_faces_set, 'euclidean-distance', top_n)
|
||||||
|
validate(test_set, anime_faces_set, 'chi-square', top_n)
|
||||||
|
validate(test_set, anime_faces_set, 'correlation', top_n)
|
||||||
|
validate(test_set, anime_faces_set, 'intersection', top_n)
|
||||||
|
validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-v', '--validate_only')
|
parser.add_argument('-v', '--validate_only')
|
||||||
@ -81,12 +90,12 @@ if __name__ == '__main__':
|
|||||||
if args.validate_only:
|
if args.validate_only:
|
||||||
print('Validating')
|
print('Validating')
|
||||||
test_set = load_data('test_set')
|
test_set = load_data('test_set')
|
||||||
validate(test_set, anime_faces_set, 'structural-similarity')
|
print('Top 1 matches results:')
|
||||||
validate(test_set, anime_faces_set, 'euclidean-distance')
|
validate_all(test_set, anime_faces_set, 'structural-similarity', 1)
|
||||||
validate(test_set, anime_faces_set, 'chi-square')
|
print('Top 3 matches results:')
|
||||||
validate(test_set, anime_faces_set, 'correlation')
|
validate_all(test_set, anime_faces_set, 'structural-similarity', 3)
|
||||||
validate(test_set, anime_faces_set, 'intersection')
|
print('Top 5 matches results:')
|
||||||
validate(test_set, anime_faces_set, 'bhattacharyya-distance')
|
validate_all(test_set, anime_faces_set, 'structural-similarity', 5)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
source = load_source('UAM-Andre.jpg')
|
source = load_source('UAM-Andre.jpg')
|
||||||
|
Loading…
Reference in New Issue
Block a user