Merge pull request 'Mierzenie accuracy nie tylko na top 1 outpucie' (#7) from top-k-validation into main

Reviewed-on: #7
This commit is contained in:
Marcin Kostrzewski 2023-02-01 18:09:26 +01:00
commit 7e76f516fd

27
main.py
View File

@ -69,12 +69,12 @@ def transfer_to_anime(img: np.ndarray):
return algo.cartoonize(img).astype(np.uint8) return algo.cartoonize(img).astype(np.uint8)
def validate(test_set, anime_faces_set, metric='correlation'): def validate(test_set, anime_faces_set, metric='correlation', top_n=1):
all_entries = len(test_set['values']) all_entries = len(test_set['values'])
correct = 0 correct = 0
for test_image, test_label in zip(test_set['values'], test_set['labels']): for test_image, test_label in zip(test_set['values'], test_set['labels']):
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name'] output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n)
if output == test_label: if any(map(lambda single_result: single_result['name'] == test_label, output)):
correct += 1 correct += 1
accuracy = correct / all_entries accuracy = correct / all_entries
@ -82,6 +82,15 @@ def validate(test_set, anime_faces_set, metric='correlation'):
return accuracy return accuracy
def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1):
validate(test_set, anime_faces_set, 'structural-similarity', top_n)
validate(test_set, anime_faces_set, 'euclidean-distance', top_n)
validate(test_set, anime_faces_set, 'chi-square', top_n)
validate(test_set, anime_faces_set, 'correlation', top_n)
validate(test_set, anime_faces_set, 'intersection', top_n)
validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only') parser.add_argument('-v', '--validate_only')
@ -91,12 +100,12 @@ if __name__ == '__main__':
if args.validate_only: if args.validate_only:
print('Validating') print('Validating')
test_set = load_data('test_set') test_set = load_data('test_set')
validate(test_set, anime_faces_set, 'structural-similarity') print('Top 1 matches results:')
validate(test_set, anime_faces_set, 'euclidean-distance') validate_all(test_set, anime_faces_set, 'structural-similarity', 1)
validate(test_set, anime_faces_set, 'chi-square') print('Top 3 matches results:')
validate(test_set, anime_faces_set, 'correlation') validate_all(test_set, anime_faces_set, 'structural-similarity', 3)
validate(test_set, anime_faces_set, 'intersection') print('Top 5 matches results:')
validate(test_set, anime_faces_set, 'bhattacharyya-distance') validate_all(test_set, anime_faces_set, 'structural-similarity', 5)
exit(0) exit(0)
source = load_source('UAM-Andre.jpg') source = load_source('UAM-Andre.jpg')