wko_anime-face-similarity/main.py

107 lines
3.7 KiB
Python
Raw Normal View History

2023-01-31 21:13:30 +01:00
import argparse
2023-01-15 12:40:35 +01:00
import sys
import cv2
import numpy as np
2023-01-29 22:43:45 +01:00
import matplotlib.pyplot as plt
2023-01-29 22:57:29 +01:00
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
from load_test_data import load_data
2023-01-15 12:40:35 +01:00
2023-01-29 21:23:11 +01:00
# Allows imports from the style transfer submodule
sys.path.append('DCT-Net')
2023-01-15 12:40:35 +01:00
from source.cartoonize import Cartoonizer
2023-01-29 21:14:30 +01:00
2023-01-15 12:40:35 +01:00
def load_source(filename: str) -> np.ndarray:
    """Read an image from disk and return it with its channel axis reversed.

    ``cv2.imread`` decodes colour images in BGR order, so reversing the last
    axis yields RGB. The result is a reversed-stride view of the decoded array.

    :param filename: path of the image file to read.
    :return: the decoded image with the last axis reversed (BGR -> RGB).
    :raises OSError: if OpenCV cannot read or decode ``filename``.
    """
    image = cv2.imread(filename)
    # cv2.imread signals failure (missing file, bad permissions, unsupported
    # format) by returning None rather than raising; without this check the
    # slice below would fail with an unhelpful TypeError.
    if image is None:
        raise OSError(f'Could not read image file: {filename}')
    return image[..., ::-1]
2023-01-15 12:40:35 +01:00
2023-01-29 22:43:45 +01:00
def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
    """Detect faces with an OpenCV cascade classifier and crop out the largest one.

    :param data: colour image; it is converted with ``COLOR_BGR2GRAY``, so the
        grayscale weighting assumes BGR channel order.
    :param classifier_file: path to the cascade XML file to load.
    :return: the slice of ``data`` covering the detected face with the largest
        area (a view into ``data``, not a copy).
    :raises OSError: if the cascade file could not be loaded.
    :raises ValueError: if no face is detected.
    """
    data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
    face_cascade = cv2.CascadeClassifier(classifier_file)
    # A missing or invalid XML does not necessarily raise here; the classifier
    # is simply empty and detectMultiScale would later fail with a cv2.error.
    if face_cascade.empty():
        raise OSError(f'Could not load cascade classifier: {classifier_file}')
    # Positional args are scaleFactor=1.1 and minNeighbors=3.
    faces = face_cascade.detectMultiScale(data_gray, 1.1, 3)
    if len(faces) == 0:
        raise ValueError('No face detected in the image')
    # Each detection is (x, y, w, h); keep the one with the largest area.
    x, y, w, h = max(faces, key=lambda rect: rect[2] * rect[3])
    return data[y:y + h, x:x + w]
2023-01-15 12:40:35 +01:00
2023-01-29 22:43:45 +01:00
def plot_two_images(a: np.ndarray, b: np.ndarray):
    """Display two images side by side in a single matplotlib figure.

    ``a`` goes in the left panel titled "A", ``b`` in the right panel titled
    "B"; the figure is then shown with ``plt.show()``.
    """
    panels = ((a, "A"), (b, "B"))
    plt.figure(figsize=[10, 10])
    for position, (image, caption) in enumerate(panels, start=1):
        # subplot(1, 2, n) is the explicit form of subplot(12n).
        plt.subplot(1, len(panels), position)
        plt.imshow(image)
        plt.title(caption)
    plt.show()
2023-01-31 21:08:01 +01:00
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
    """Score ``source`` against every image of an anime-face dataset.

    :param source: image to compare; it is resized to each dataset image's
        height and width before the metrics are computed.
    :param anime_faces_dataset: mapping with parallel ``'values'`` (images) and
        ``'labels'`` (names) sequences, iterated in lockstep.
    :param verbose: if true, plot each dataset image next to the resized source
        before scoring it.
    :return: one ``{'name': label, 'metrics': {...}}`` dict per dataset entry, in
        dataset order. ``metrics`` is the dict returned by
        ``histogram_comparison`` extended with ``'structural-similarity'`` and
        ``'euclidean-distance'``.
    """
    all_metrics = []
    for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
        # TODO: Use a different face detection method for anime images
        # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
        anime_face = anime_image
        # ndarray.shape is (height, width, ...) but cv2.resize expects dsize as
        # (width, height); swap so both images end up with identical shapes
        # even when the dataset image is not square.
        height, width = anime_face.shape[:2]
        source_rescaled = cv2.resize(source, (width, height))
        if verbose:
            plot_two_images(anime_face, source_rescaled)
        metrics = histogram_comparison(source_rescaled, anime_face)
        metrics['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
        metrics['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)
        all_metrics.append({'name': label, 'metrics': metrics})
    return all_metrics
def get_top_results(all_metrics: list[dict], metric='correlation', count=1, descending=None):
    """Return the ``count`` best-ranked entries for ``metric``.

    :param all_metrics: entries shaped like
        ``{'name': ..., 'metrics': {metric: score, ...}}`` (the format built by
        ``compare_with_anime_characters``). The list itself is not modified.
    :param metric: key inside each entry's ``'metrics'`` dict to rank by.
    :param count: maximum number of results to return.
    :param descending: ``True`` ranks higher scores first, ``False`` ranks lower
        scores first. ``None`` (default) ranks lower-first for
        ``'euclidean-distance'``, ``'chi-square'`` and
        ``'bhattacharyya-distance'``, and higher-first for every other metric.
    :return: list of ``{'name': ..., 'score': ...}`` dicts, best match first.
    """
    if descending is None:
        # For distance-like metrics a smaller value means a closer match, so a
        # highest-first sort would report the least similar characters.
        # NOTE(review): assumes the comparisons module returns raw distances for
        # these keys (not similarities) — confirm against comparisons.py.
        descending = metric not in ('euclidean-distance', 'chi-square', 'bhattacharyya-distance')
    # sorted() rather than list.sort() so the caller's list keeps its order.
    ranked = sorted(all_metrics, key=lambda item: item['metrics'][metric], reverse=descending)
    return [{'name': item['name'], 'score': item['metrics'][metric]} for item in ranked[:count]]
2023-01-15 12:40:35 +01:00
2023-01-29 15:17:39 +01:00
def transfer_to_anime(img: np.ndarray, dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models'):
    """Stylise ``img`` with the DCT-Net ``Cartoonizer`` and return it as ``uint8``.

    :param img: image handed straight to ``Cartoonizer.cartoonize``.
        NOTE(review): the channel order DCT-Net expects is not visible here — confirm.
    :param dataroot: model directory passed to ``Cartoonizer``; defaults to the
        path inside the ``DCT-Net`` submodule.
    :return: the cartoonized image clipped to [0, 255] and cast to ``uint8``.
    """
    # A new Cartoonizer is built on every call (presumably loading the models
    # from ``dataroot``); consider reusing one instance if called in a loop.
    algo = Cartoonizer(dataroot=dataroot)
    result = algo.cartoonize(img)
    # Out-of-range floats have no well-defined uint8 value (NumPy may warn and
    # results are platform-dependent); clip first so they saturate instead.
    return np.clip(result, 0, 255).astype(np.uint8)
2023-01-15 12:40:35 +01:00
2023-01-31 21:08:01 +01:00
def validate(test_set, anime_faces_set, metric='correlation'):
    """Measure top-1 matching accuracy of ``test_set`` against ``anime_faces_set``.

    For each test image, the best match under ``metric`` (as ranked by
    ``get_top_results``) counts as correct when its name equals the test
    label. The accuracy is printed as a percentage and returned as a fraction.

    :param test_set: mapping with parallel ``'values'`` (images) and
        ``'labels'`` sequences.
    :param anime_faces_set: reference dataset in the same format.
    :param metric: metric key used to rank candidates.
    :return: fraction of correctly matched test images, in [0, 1].
    :raises ValueError: if either dataset contains no images.
    """
    all_entries = len(test_set['values'])
    # Fail with a clear message instead of a ZeroDivisionError (empty test set)
    # or an IndexError on ``[0]`` below (empty reference set).
    if all_entries == 0:
        raise ValueError('test_set contains no images to validate')
    if len(anime_faces_set['values']) == 0:
        raise ValueError('anime_faces_set contains no reference images')
    correct = 0
    for test_image, test_label in zip(test_set['values'], test_set['labels']):
        ranked = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)
        if ranked[0]['name'] == test_label:
            correct += 1
    accuracy = correct / all_entries
    print(f'Accuracy using {metric}: {accuracy * 100}%')
    return accuracy
2023-01-15 12:40:35 +01:00
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # nargs='?' + const=True lets "-v" work as a plain flag while still
    # accepting the old "-v <value>" form (any non-empty value is truthy).
    parser.add_argument(
        '-v', '--validate_only',
        nargs='?', const=True, default=None,
        help="only run accuracy validation on 'test_set' for every metric, then exit",
    )
    parser.add_argument(
        '-s', '--source', default='UAM-Andre.jpg',
        help='photo to match against the anime faces (default: %(default)s)',
    )
    args = parser.parse_args()

    anime_faces_set = load_data('data/images')

    if args.validate_only:
        print('Validating')
        test_set = load_data('test_set')
        for metric in ('structural-similarity', 'euclidean-distance', 'chi-square',
                       'correlation', 'intersection', 'bhattacharyya-distance'):
            validate(test_set, anime_faces_set, metric)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under "python -S").
        sys.exit(0)

    source = load_source(args.source)
    source_anime = transfer_to_anime(source)
    source_face_anime = find_and_crop_face(source_anime)

    results = compare_with_anime_characters(source_face_anime, anime_faces_set)
    print(get_top_results(results, count=5))