Add validation
This commit is contained in:
parent
49e337e5e9
commit
e212795fab
30
main.py
30
main.py
@ -37,10 +37,9 @@ def plot_two_images(a: np.ndarray, b: np.ndarray):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def compare_with_anime_characters(source: np.ndarray, verbose=False) -> list[dict]:
|
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
|
||||||
dataset = load_data('data/images')
|
|
||||||
all_metrics = []
|
all_metrics = []
|
||||||
for anime_image, label in zip(dataset['values'], dataset['labels']):
|
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
|
||||||
current_result = {
|
current_result = {
|
||||||
'name': label,
|
'name': label,
|
||||||
'metrics': {}
|
'metrics': {}
|
||||||
@ -69,9 +68,32 @@ def transfer_to_anime(img: np.ndarray):
|
|||||||
return algo.cartoonize(img).astype(np.uint8)
|
return algo.cartoonize(img).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def validate(test_set, anime_faces_set, metric='correlation'):
|
||||||
|
all_entries = len(test_set['values'])
|
||||||
|
correct = 0
|
||||||
|
for test_image, test_label in zip(test_set['values'], test_set['labels']):
|
||||||
|
output = get_top_results(compare_with_anime_characters(test_image), metric)[0]['name']
|
||||||
|
if output == test_label:
|
||||||
|
correct += 1
|
||||||
|
|
||||||
|
accuracy = correct / all_entries
|
||||||
|
print(f'Accuracy using {metric}: {accuracy * 100}%')
|
||||||
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
anime_faces_set = load_data('data/images')
|
||||||
|
|
||||||
|
# Uncomment for validation (takes a while)
|
||||||
|
# test_set = load_data('test_set')
|
||||||
|
# validate(test_set, anime_faces_set, 'structural-similarity')
|
||||||
|
# validate(test_set, anime_faces_set, 'euclidean-distance')
|
||||||
|
# validate(test_set, anime_faces_set, 'chi-square')
|
||||||
|
# validate(test_set, anime_faces_set, 'correlation')
|
||||||
|
# validate(test_set, anime_faces_set, 'intersection')
|
||||||
|
# validate(test_set, anime_faces_set, 'bhattacharyya-distance')
|
||||||
source = load_source('UAM-Andre.jpg')
|
source = load_source('UAM-Andre.jpg')
|
||||||
source_anime = transfer_to_anime(source)
|
source_anime = transfer_to_anime(source)
|
||||||
source_face_anime = find_and_crop_face(source_anime)
|
source_face_anime = find_and_crop_face(source_anime)
|
||||||
results = compare_with_anime_characters(source_face_anime)
|
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
|
||||||
print(get_top_results(results, count=5))
|
print(get_top_results(results, count=5))
|
||||||
|
Loading…
Reference in New Issue
Block a user