import argparse
import sys

import cv2
import numpy as np
import matplotlib.pyplot as plt

from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
from load_test_data import load_data

# Allows imports from the style transfer submodule
sys.path.append('DCT-Net')
from source.cartoonize import Cartoonizer  # noqa: E402  (must follow the sys.path tweak)


def load_source(filename: str) -> np.ndarray:
    """Read an image from disk and return it with its channel axis reversed.

    ``cv2.imread`` decodes colour images in B, G, R order, so reversing the
    last axis yields R, G, B.

    Args:
        filename: Path of the image file to read.

    Returns:
        The decoded image with the last axis reversed.

    Raises:
        FileNotFoundError: If OpenCV could not read ``filename``.
    """
    image = cv2.imread(filename)
    if image is None:
        # imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {filename}')
    return image[..., ::-1]


def find_and_crop_face(data: np.ndarray,
                       classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
    """Detect faces with a cascade classifier and crop the largest one.

    Args:
        data: Colour image to search.
            NOTE(review): converted with ``COLOR_BGR2GRAY``, so this assumes
            BGR channel order - confirm against what callers pass in.
        classifier_file: Path of the cascade XML file to load.

    Returns:
        The sub-array of ``data`` covered by the largest detected face box.

    Raises:
        ValueError: If no face is detected.
    """
    data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
    face_cascade = cv2.CascadeClassifier(classifier_file)
    faces = face_cascade.detectMultiScale(data_gray, 1.1, 3)

    # ``faces`` may be an empty tuple or an ndarray, so test its length
    # rather than its truthiness.
    if len(faces) == 0:
        raise ValueError('No face detected in the image')

    # Every detection is an (x, y, w, h) box; keep the one with the largest
    # area (comparing by ``len`` would always pick the first box).
    x, y, w, h = max(faces, key=lambda box: int(box[2]) * int(box[3]))
    return data[y:y + h, x:x + w]


def plot_two_images(a: np.ndarray, b: np.ndarray):
    """Show images ``a`` and ``b`` side by side, titled "A" and "B"."""
    plt.figure(figsize=[10, 10])
    for position, image, title in ((121, a, "A"), (122, b, "B")):
        plt.subplot(position)
        plt.imshow(image)
        plt.title(title)
    plt.show()


def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
    """Compute similarity metrics between ``source`` and every dataset image.

    Args:
        source: Image to compare against the dataset.
        anime_faces_dataset: Mapping with parallel ``'values'`` (images) and
            ``'labels'`` (names) sequences.
        verbose: When true, plot each dataset image next to the rescaled source.

    Returns:
        One ``{'name': label, 'metrics': {...}}`` dict per dataset image. The
        metrics dict is whatever ``histogram_comparison`` returns, extended
        with ``'structural-similarity'`` and ``'euclidean-distance'``.
    """
    all_metrics = []
    for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
        # TODO: Use a different face detection method for anime images
        # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
        anime_face = anime_image

        # ndarray.shape is (rows, cols, ...) but cv2.resize expects
        # dsize=(width, height), so the two values must be swapped.
        height, width = anime_face.shape[:2]
        source_rescaled = cv2.resize(source, (width, height))

        if verbose:
            plot_two_images(anime_face, source_rescaled)

        metrics = histogram_comparison(source_rescaled, anime_face)
        metrics['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
        metrics['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)

        all_metrics.append({'name': label, 'metrics': metrics})

    return all_metrics


def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
    """Return the ``count`` entries with the highest value of ``metric``.

    The input list is left untouched; ranking is done on a sorted copy.

    Args:
        all_metrics: Entries shaped like the output of
            ``compare_with_anime_characters``.
        metric: Key inside each entry's ``'metrics'`` dict to rank by.
        count: Maximum number of entries to return.

    Returns:
        A list of ``{'name': ..., 'score': ...}`` dicts, highest score first.
    """
    ranked = sorted(all_metrics, key=lambda item: item['metrics'][metric], reverse=True)
    return [{'name': item['name'], 'score': item['metrics'][metric]} for item in ranked[:count]]


def transfer_to_anime(img: np.ndarray):
    """Run the DCT-Net ``Cartoonizer`` on ``img`` and return the result as uint8."""
    algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
    return algo.cartoonize(img).astype(np.uint8)


def validate(test_set, anime_faces_set, metric='correlation'):
    """Measure how often the top match for a test image carries its label.

    Args:
        test_set: Mapping with parallel ``'values'`` and ``'labels'`` sequences.
        anime_faces_set: Reference dataset passed to
            ``compare_with_anime_characters``.
        metric: Metric used to rank the candidates.

    Returns:
        The fraction of test images whose best match has the expected label
        (0.0 for an empty test set).
    """
    # NOTE(review): get_top_results ranks highest-first. For distance-style
    # metrics ('euclidean-distance', 'chi-square', 'bhattacharyya-distance')
    # a lower value presumably means a better match - confirm against the
    # `comparisons` module and rank ascending for those if so.
    all_entries = len(test_set['values'])
    correct = 0
    for test_image, test_label in zip(test_set['values'], test_set['labels']):
        candidates = compare_with_anime_characters(test_image, anime_faces_set)
        if get_top_results(candidates, metric)[0]['name'] == test_label:
            correct += 1

    # Guard against an empty test set instead of dividing by zero.
    accuracy = correct / all_entries if all_entries else 0.0
    print(f'Accuracy using {metric}: {accuracy * 100}%')
    return accuracy


def main():
    """Command-line entry point: validate the metrics or match a single photo."""
    parser = argparse.ArgumentParser()
    # nargs='?' keeps the old "-v VALUE" form working while also accepting a
    # bare "-v" flag.
    parser.add_argument('-v', '--validate_only', nargs='?', const=True, default=False,
                        help='only run validation on the test set, then exit')
    args = parser.parse_args()

    anime_faces_set = load_data('data/images')

    if args.validate_only:
        print('Validating')
        test_set = load_data('test_set')
        for metric in ('structural-similarity', 'euclidean-distance', 'chi-square',
                       'correlation', 'intersection', 'bhattacharyya-distance'):
            validate(test_set, anime_faces_set, metric)
        return

    source = load_source('UAM-Andre.jpg')
    source_anime = transfer_to_anime(source)
    source_face_anime = find_and_crop_face(source_anime)

    results = compare_with_anime_characters(source_face_anime, anime_faces_set)
    print(get_top_results(results, count=5))


if __name__ == '__main__':
    main()