import argparse
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np

from metrics import (
    AccuracyGatherer,
    euclidean_distance,
    get_top_results,
    histogram_comparison,
    structural_similarity_index,
)
from face_detect import find_face_bbox, crop_face
from helpers import no_stdout
from load_test_data import load_data, load_source
from plots import plot_two_images, plot_results

# Allows imports from the style transfer submodule
sys.path.append('DCT-Net')
from source.cartoonize import Cartoonizer  # noqa: E402  (needs the sys.path tweak above)

# Style-transfer model, constructed once at import time and shared by
# transfer_to_anime(). The path is relative to the current working directory.
anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')


def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
    """Score a source face image against every image in an anime face dataset.

    Args:
        source_image: Face image to compare. It is resized to each dataset
            image's size before the metrics are computed.
        anime_faces_dataset: Mapping with parallel sequences under ``'values'``
            (images) and ``'labels'`` (names).
        verbose: If true, plot each anime face next to the rescaled source.

    Returns:
        One dict per dataset entry, ``{'name': label, 'metrics': {...}}``. The
        ``metrics`` mapping holds whatever ``histogram_comparison`` returns, plus
        ``'structural-similarity'`` and ``'euclidean-distance'``.
    """
    all_metrics = []
    for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
        # TODO: Use a different face detection method for anime images
        # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
        anime_face = anime_image

        # cv2.resize expects dsize as (width, height), while ndarray.shape is
        # (height, width, ...). Passing shape[:2] directly would transpose the
        # target size for non-square images.
        height, width = anime_face.shape[:2]
        source_rescaled = cv2.resize(source_image, (width, height))
        if verbose:
            plot_two_images(anime_face, source_rescaled)

        metrics = histogram_comparison(source_rescaled, anime_face)
        metrics['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
        metrics['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)
        all_metrics.append({'name': label, 'metrics': metrics})
    return all_metrics


@no_stdout
def transfer_to_anime(img: np.ndarray):
    """Cartoonize *img* with the DCT-Net model and return a uint8 RGB image.

    The model output is cast to ``uint8`` and then converted from BGR to RGB
    channel order. NOTE(review): the cast wraps any values outside [0, 255];
    this assumes the model already returns values in that range (TODO confirm).
    The ``no_stdout`` decorator presumably silences the model's console output.
    """
    model_out = anime_transfer.cartoonize(img).astype(np.uint8)
    return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)


def _plot_face_detection_debug(source_image, source_face_bbox, source_anime, source_face_anime):
    """Show the detected bbox, the cartoonized image and the cropped face side by side."""
    source_image_with_box = source_image.copy()
    x, y, w, h = source_face_bbox
    cv2.rectangle(source_image_with_box, (x, y), (x + w, y + h), (255, 0, 0), 2)

    plt.figure(figsize=[12, 4])
    plt.subplot(131)
    plt.imshow(source_image_with_box)
    plt.subplot(132)
    plt.imshow(source_anime)
    plt.subplot(133)
    plt.imshow(source_face_anime)
    plt.show()


def similarity_to_anime(source_image, anime_faces_set, debug=False):
    """Cartoonize *source_image* and compare its face with every anime face.

    Args:
        source_image: Photo containing a face.
        anime_faces_set: Dataset passed through to ``compare_with_anime_characters``.
        debug: If true, plot intermediate images and each comparison.

    Returns:
        The list from ``compare_with_anime_characters``, or ``None`` when
        ``find_face_bbox`` raises ``ValueError`` (no face found).
    """
    try:
        source_face_bbox = find_face_bbox(source_image)
    except ValueError:
        return None

    source_anime = transfer_to_anime(source_image)
    # The bbox was detected on the original photo and is reused on the
    # cartoonized image. This assumes cartoonize keeps the input resolution
    # (TODO confirm).
    source_face_anime = crop_face(source_anime, source_face_bbox)

    if debug:
        _plot_face_detection_debug(source_image, source_face_bbox, source_anime, source_face_anime)

    return compare_with_anime_characters(source_face_anime, anime_faces_set, verbose=debug)


def validate(test_set, anime_faces_set):
    """Measure matching accuracy over a labelled test set and print it.

    Images where no face is detected are reported and excluded from the
    count given to the ``AccuracyGatherer``.
    """
    all_entries = len(test_set['values'])
    accuracy = AccuracyGatherer(all_entries)
    for test_image, test_label in zip(test_set['values'], test_set['labels']):
        test_results = similarity_to_anime(test_image, anime_faces_set)
        if test_results is None:
            print(f"cannot find face for {test_label}")
            all_entries -= 1
            continue
        accuracy.for_results(test_results, test_label)
    accuracy.count = all_entries
    accuracy.print()


def main():
    """Command-line entry point.

    With ``-v``/``--validate_only``, run ``validate`` over ``test_set`` and stop.
    Otherwise, process one sample image, print its top matches by the
    ``'correlation'`` metric and plot them.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' + const=True lets ``-v`` work as a bare flag. It still accepts
    # the older ``-v <value>`` form, which was the only form the original
    # definition allowed.
    parser.add_argument('-v', '--validate_only', nargs='?', const=True, default=None,
                        help='only run validation on the test set')
    args = parser.parse_args()

    anime_faces_set = load_data('data/croped_anime_faces')

    if args.validate_only:
        print('Validating')
        test_set = load_data('test_set')
        validate(test_set, anime_faces_set)
        return

    source = load_source('test_set/Ayanokouji, Kiyotaka.jpg')
    results = similarity_to_anime(source, anime_faces_set)
    if results is None:
        # similarity_to_anime returns None when no face is detected;
        # get_top_results cannot handle that.
        sys.exit('cannot find face in source image')

    method = 'correlation'
    top_results = get_top_results(results, count=4, metric=method)
    print(top_results)
    plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)


if __name__ == '__main__':
    main()