Fixes #8

Merged
s444409 merged 9 commits from flipping into pretty-results 2023-02-01 22:14:44 +01:00
Showing only changes of commit 3817096c34 - Show all commits

48
main.py
View File

@ -59,43 +59,43 @@ def transfer_to_anime(img: np.ndarray):
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
def validate(test_set, anime_faces_set, metric='correlation', top_n=1):
def validate(test_set, anime_faces_set, top_n=1):
all_entries = len(test_set['values'])
correct = 0
all_metric_names = [
'structural-similarity',
'euclidean-distance',
'chi-square',
'correlation',
'intersection',
'bhattacharyya-distance'
]
hits_per_metric = {metric: 0 for metric in all_metric_names}
for test_image, test_label in zip(test_set['values'], test_set['labels']):
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n)
if any(map(lambda single_result: single_result['name'] == test_label, output)):
correct += 1
test_results = compare_with_anime_characters(test_image, anime_faces_set)
top_results_all_metrics = {m: get_top_results(test_results, m, top_n) for m in all_metric_names}
for metric_name in all_metric_names:
top_current_metric_results = top_results_all_metrics[metric_name]
if any(map(lambda single_result: single_result['name'] == test_label, top_current_metric_results)):
hits_per_metric[metric_name] += 1
accuracy = correct / all_entries
print(f'Accuracy using {metric}: {accuracy * 100}%')
return accuracy
def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1):
validate(test_set, anime_faces_set, 'structural-similarity', top_n)
validate(test_set, anime_faces_set, 'euclidean-distance', top_n)
validate(test_set, anime_faces_set, 'chi-square', top_n)
validate(test_set, anime_faces_set, 'correlation', top_n)
validate(test_set, anime_faces_set, 'intersection', top_n)
validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n)
all_metrics = {metric: hits_per_metric[metric] / all_entries for metric in all_metric_names}
print(f'Top {top_n} matches results:')
[print(f'\t{key}: {value*100}%') for key, value in all_metrics.items()]
return all_metrics
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only')
args = parser.parse_args()
anime_faces_set = load_data('data/images', (256, 256))
anime_faces_set = load_data('data/croped_anime_faces', (256, 256))
if args.validate_only:
print('Validating')
test_set = load_data('test_set')
print('Top 1 matches results:')
validate_all(test_set, anime_faces_set, 'structural-similarity', 1)
print('Top 3 matches results:')
validate_all(test_set, anime_faces_set, 'structural-similarity', 3)
print('Top 5 matches results:')
validate_all(test_set, anime_faces_set, 'structural-similarity', 5)
validate(test_set, anime_faces_set, 1)
validate(test_set, anime_faces_set, 3)
validate(test_set, anime_faces_set, 5)
exit(0)
source = load_source('UAM-Andre.jpg')