Merge pull request 'Porównanie z całym datasetem z twarzami oraz walidacja' (#4) from full-dataset-comparison into main

Reviewed-on: #4
This commit is contained in:
Mateusz Tylka 2023-02-01 11:03:06 +01:00
commit e63892f806
2 changed files with 57 additions and 12 deletions

View File

@ -29,7 +29,6 @@ def load_data(input_dir, newSize=(64,64)):
p = image_path / n p = image_path / n
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
img = img / 255 # type: ignore #normalizacja
test_img.append(img) test_img.append(img)
labels.append(n) labels.append(n)

66
main.py
View File

@ -1,9 +1,11 @@
import argparse
import sys import sys
import cv2 import cv2
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
from load_test_data import load_data
# Allows imports from the style transfer submodule # Allows imports from the style transfer submodule
sys.path.append('DCT-Net') sys.path.append('DCT-Net')
@ -36,16 +38,30 @@ def plot_two_images(a: np.ndarray, b: np.ndarray):
plt.show() plt.show()
def compare_with_anime_characters(data: np.ndarray) -> int: def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
# Example will be one face from anime dataset all_metrics = []
example = load_source('data/images/Aisaka, Taiga.jpg') for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
current_result = {
'name': label,
'metrics': {}
}
# TODO: Use a different face detection method for anime images # TODO: Use a different face detection method for anime images
example_face = find_and_crop_face(example, 'haarcascades/lbpcascade_animeface.xml') # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
data_rescaled = cv2.resize(data, example_face.shape[:2]) anime_face = anime_image
plot_two_images(example_face, data_rescaled) source_rescaled = cv2.resize(source, anime_face.shape[:2])
print(histogram_comparison(data_rescaled, example_face)) if verbose:
print(f'structural-similarity: {structural_similarity_index(data_rescaled, example_face)}') plot_two_images(anime_face, source_rescaled)
print(f'euclidean-distance: {euclidean_distance(data_rescaled, example_face)}') current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
current_result['metrics']['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
current_result['metrics']['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)
all_metrics.append(current_result)
return all_metrics
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
def transfer_to_anime(img: np.ndarray): def transfer_to_anime(img: np.ndarray):
@ -53,8 +69,38 @@ def transfer_to_anime(img: np.ndarray):
return algo.cartoonize(img).astype(np.uint8) return algo.cartoonize(img).astype(np.uint8)
def validate(test_set, anime_faces_set, metric='correlation'):
all_entries = len(test_set['values'])
correct = 0
for test_image, test_label in zip(test_set['values'], test_set['labels']):
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name']
if output == test_label:
correct += 1
accuracy = correct / all_entries
print(f'Accuracy using {metric}: {accuracy * 100}%')
return accuracy
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only')
args = parser.parse_args()
anime_faces_set = load_data('data/images')
if args.validate_only:
print('Validating')
test_set = load_data('test_set')
validate(test_set, anime_faces_set, 'structural-similarity')
validate(test_set, anime_faces_set, 'euclidean-distance')
validate(test_set, anime_faces_set, 'chi-square')
validate(test_set, anime_faces_set, 'correlation')
validate(test_set, anime_faces_set, 'intersection')
validate(test_set, anime_faces_set, 'bhattacharyya-distance')
exit(0)
source = load_source('UAM-Andre.jpg') source = load_source('UAM-Andre.jpg')
source_anime = transfer_to_anime(source) source_anime = transfer_to_anime(source)
source_face_anime = find_and_crop_face(source_anime) source_face_anime = find_and_crop_face(source_anime)
print(compare_with_anime_characters(source_face_anime)) results = compare_with_anime_characters(source_face_anime, anime_faces_set)
print(get_top_results(results, count=5))