From 3817096c348c596fba810cd0d43bf2cf4c84d263 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Wed, 1 Feb 2023 18:42:07 +0100 Subject: [PATCH] Faster validation --- main.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index 43524dc..59fcbbb 100644 --- a/main.py +++ b/main.py @@ -59,43 +59,43 @@ def transfer_to_anime(img: np.ndarray): return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB) -def validate(test_set, anime_faces_set, metric='correlation', top_n=1): +def validate(test_set, anime_faces_set, top_n=1): all_entries = len(test_set['values']) - correct = 0 + all_metric_names = [ + 'structural-similarity', + 'euclidean-distance', + 'chi-square', + 'correlation', + 'intersection', + 'bhattacharyya-distance' + ] + hits_per_metric = {metric: 0 for metric in all_metric_names} for test_image, test_label in zip(test_set['values'], test_set['labels']): - output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n) - if any(map(lambda single_result: single_result['name'] == test_label, output)): - correct += 1 + test_results = compare_with_anime_characters(test_image, anime_faces_set) + top_results_all_metrics = {m: get_top_results(test_results, m, top_n) for m in all_metric_names} + for metric_name in all_metric_names: + top_current_metric_results = top_results_all_metrics[metric_name] + if any(map(lambda single_result: single_result['name'] == test_label, top_current_metric_results)): + hits_per_metric[metric_name] += 1 - accuracy = correct / all_entries - print(f'Accuracy using {metric}: {accuracy * 100}%') - return accuracy - - -def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1): - validate(test_set, anime_faces_set, 'structural-similarity', top_n) - validate(test_set, anime_faces_set, 'euclidean-distance', top_n) - validate(test_set, anime_faces_set, 'chi-square', top_n) - validate(test_set, anime_faces_set, 'correlation', top_n) - validate(test_set, anime_faces_set, 'intersection', top_n) - validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n) + all_metrics = {metric: hits_per_metric[metric] / all_entries for metric in all_metric_names} + print(f'Top {top_n} matches results:') + [print(f'\t{key}: {value*100}%') for key, value in all_metrics.items()] + return all_metrics if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-v', '--validate_only') args = parser.parse_args() - anime_faces_set = load_data('data/images', (256, 256)) + anime_faces_set = load_data('data/croped_anime_faces', (256, 256)) if args.validate_only: print('Validating') test_set = load_data('test_set') - print('Top 1 matches results:') - validate_all(test_set, anime_faces_set, 'structural-similarity', 1) - print('Top 3 matches results:') - validate_all(test_set, anime_faces_set, 'structural-similarity', 3) - print('Top 5 matches results:') - validate_all(test_set, anime_faces_set, 'structural-similarity', 5) + validate(test_set, anime_faces_set, 1) + validate(test_set, anime_faces_set, 3) + validate(test_set, anime_faces_set, 5) exit(0) source = load_source('UAM-Andre.jpg')