diff --git a/face_detect.py b/face_detect.py
new file mode 100644
index 0000000..b08d4bd
--- /dev/null
+++ b/face_detect.py
@@ -0,0 +1,15 @@
+import cv2
+import numpy as np
+
+
+def find_face_bbox(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml'):
+    data_gray = cv2.cvtColor(data, cv2.COLOR_RGB2GRAY)
+    face_cascade = cv2.CascadeClassifier(classifier_file)
+    face_coords = face_cascade.detectMultiScale(data_gray, 1.1, 3)
+    return max(face_coords, key=lambda r: r[2] * r[3])  # largest face by area; raises ValueError if nothing is detected
+
+
+def crop_face(data: np.ndarray, bounding_box) -> np.ndarray:
+    x, y, w, h = bounding_box
+    face = data[y:y + h, x:x + w]
+    return face
diff --git a/helpers.py b/helpers.py
new file mode 100644
index 0000000..1b8f485
--- /dev/null
+++ b/helpers.py
@@ -0,0 +1,15 @@
+import os
+import sys
+
+
+def no_stdout(func):
+    """Silence stdout while the wrapped function runs."""
+    def wrapper(*args, **kwargs):
+        old_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+        try:
+            return func(*args, **kwargs)
+        finally:
+            sys.stdout.close()
+            sys.stdout = old_stdout
+    return wrapper
diff --git a/load_test_data.py b/load_test_data.py
index 300cd25..2429b89 100644
--- a/load_test_data.py
+++ b/load_test_data.py
@@ -5,7 +5,11 @@ import cv2 as cv
 from pathlib import Path
 
 
-def load_data(input_dir, newSize=(64,64)):
+def load_source(filename: str) -> np.ndarray:
+    return cv.imread(filename)[..., ::-1]  # BGR -> RGB
+
+
+def load_data(input_dir):
     image_path = Path(input_dir)
     file_names = os.listdir(image_path)
     categories_name = []
@@ -27,8 +31,7 @@ def load_data(input_dir, newSize=(64,64)):
 
     for n in file_names:
         p = image_path / n
-        img = imread(p)  # zwraca ndarry postaci xSize x ySize x colorDepth
-        img = cv.resize(img, newSize, interpolation=cv.INTER_AREA)  # zwraca ndarray
+        img = load_source(str(p))  # returns an ndarray of shape xSize x ySize x colorDepth
         test_img.append(img)
         labels.append(n)
 
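Splitting the old `find_and_crop_face` into `find_face_bbox` and `crop_face` lets main.py detect the face on the original photo and then crop the cartoonized image with the same box. A minimal usage sketch of the new helpers, assuming an RGB photo on disk (the file name is a placeholder, not part of this change):

```python
from face_detect import find_face_bbox, crop_face
from load_test_data import load_source

image = load_source('photo.jpg')   # placeholder path; load_source returns an RGB ndarray
try:
    bbox = find_face_bbox(image)   # (x, y, w, h) of the largest detected face
    face = crop_face(image, bbox)  # the same box can be applied to any same-sized image
except ValueError:
    face = None                    # detectMultiScale found nothing, so max() raised
```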
diff --git a/main.py b/main.py
index 59fcbbb..83a3603 100644
--- a/main.py
+++ b/main.py
@@ -3,8 +3,10 @@ import sys
 import cv2
 import numpy as np
 
-from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
-from load_test_data import load_data
+from face_detect import find_face_bbox, crop_face
+from helpers import no_stdout
+from load_test_data import load_data, load_source
+from metrics import histogram_comparison, structural_similarity_index, euclidean_distance, get_top_results, AccuracyGatherer
 from plots import plot_two_images, plot_results
 
 # Allows imports from the style transfer submodule
@@ -13,21 +15,10 @@ sys.path.append('DCT-Net')
 from source.cartoonize import Cartoonizer
 
 
-def load_source(filename: str) -> np.ndarray:
-    return cv2.imread(filename)[..., ::-1]
+anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
 
 
-def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
-    data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
-    face_cascade = cv2.CascadeClassifier(classifier_file)
-    face = face_cascade.detectMultiScale(data_gray, 1.1, 3)
-    face = max(face, key=len)
-    x, y, w, h = face
-    face = data[y:y + h, x:x + w]
-    return face
-
-
-def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
+def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
     all_metrics = []
     for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
         current_result = {
@@ -37,7 +28,7 @@ def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict,
         # TODO: Use a different face detection method for anime images
         # anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
         anime_face = anime_image
-        source_rescaled = cv2.resize(source, anime_face.shape[:2])
+        source_rescaled = cv2.resize(source_image, (anime_face.shape[1], anime_face.shape[0]))  # cv2.resize expects (width, height)
         if verbose:
             plot_two_images(anime_face, source_rescaled)
         current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
@@ -48,61 +39,59 @@ def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict,
     return all_metrics
 
 
-def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
-    all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
-    return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
-
-
+@no_stdout
 def transfer_to_anime(img: np.ndarray):
-    algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
-    model_out = algo.cartoonize(img).astype(np.uint8)
+    model_out = anime_transfer.cartoonize(img).astype(np.uint8)
     return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
 
 
-def validate(test_set, anime_faces_set, top_n=1):
+def similarity_to_anime(source_image, anime_faces_set):
+    try:
+        source_face_bbox = find_face_bbox(source_image)
+    except ValueError:
+        return None
+    source_anime = transfer_to_anime(source_image)
+    source_face_anime = crop_face(source_anime, source_face_bbox)
+    return compare_with_anime_characters(source_face_anime, anime_faces_set)
+
+
+def validate(test_set, anime_faces_set):
     all_entries = len(test_set['values'])
-    all_metric_names = [
-        'structural-similarity',
-        'euclidean-distance',
-        'chi-square',
-        'correlation',
-        'intersection',
-        'bhattacharyya-distance'
-    ]
-    hits_per_metric = {metric: 0 for metric in all_metric_names}
+    accuracy = AccuracyGatherer(all_entries)
     for test_image, test_label in zip(test_set['values'], test_set['labels']):
-        test_results = compare_with_anime_characters(test_image, anime_faces_set)
-        top_results_all_metrics = {m: get_top_results(test_results, m, top_n) for m in all_metric_names}
-        for metric_name in all_metric_names:
-            top_current_metric_results = top_results_all_metrics[metric_name]
-            if any(map(lambda single_result: single_result['name'] == test_label, top_current_metric_results)):
-                hits_per_metric[metric_name] += 1
+        test_results = similarity_to_anime(test_image, anime_faces_set)
 
-    all_metrics = {metric: hits_per_metric[metric] / all_entries for metric in all_metric_names}
-    print(f'Top {top_n} matches results:')
-    [print(f'\t{key}: {value*100}%') for key, value in all_metrics.items()]
-    return all_metrics
+        if test_results is None:
+            print(f"Cannot find a face for {test_label}")
+            all_entries -= 1
+            continue
+
+        accuracy.for_results(test_results, test_label)
+
+    accuracy.count = all_entries
+    accuracy.print()
 
 
-if __name__ == '__main__':
+def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-v', '--validate_only')
     args = parser.parse_args()
-    anime_faces_set = load_data('data/croped_anime_faces', (256, 256))
+    anime_faces_set = load_data('data/croped_anime_faces')
     if args.validate_only:
         print('Validating')
        test_set = load_data('test_set')
-        validate(test_set, anime_faces_set, 1)
-        validate(test_set, anime_faces_set, 3)
-        validate(test_set, anime_faces_set, 5)
+        validate(test_set, anime_faces_set)
         exit(0)
-    source = load_source('UAM-Andre.jpg')
-    source_anime = transfer_to_anime(source)
-    source_face_anime = find_and_crop_face(source_anime)
-    results = compare_with_anime_characters(source_face_anime, anime_faces_set)
+    source = load_source('test_set/Ayanokouji, Kiyotaka.jpg')
+    results = similarity_to_anime(source, anime_faces_set)
     method = 'structural-similarity'
     top_results = get_top_results(results, count=4, metric=method)
     print(top_results)
-    plot_results(source, source_anime, top_results, anime_faces_set, method)
+    plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)
+
+
+if __name__ == '__main__':
+    main()
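The contract between `similarity_to_anime` and the metrics code, sketched for reference (the character name and scores below are invented for illustration):

```python
results = similarity_to_anime(source, anime_faces_set)
# None when no face is detected in the source photo; otherwise one entry per
# dataset character, each carrying all six metric values, for example:
# [
#     {'name': 'Ayanokouji, Kiyotaka',
#      'metrics': {'correlation': 0.83, 'chi-square': 12.9, 'intersection': 0.56,
#                  'bhattacharyya-distance': 0.31, 'structural-similarity': 0.41,
#                  'euclidean-distance': 104.2}},
#     ...
# ]
```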
diff --git a/comparisons.py b/metrics.py
similarity index 52%
rename from comparisons.py
rename to metrics.py
index 7029786..8d6be1d 100644
--- a/comparisons.py
+++ b/metrics.py
@@ -40,3 +40,43 @@ def euclidean_distance(data_a: np.ndarray, data_b: np.ndarray) -> float:
         result += (histogram_a[i] - histogram_b[i]) ** 2
         i += 1
     return result[0] ** (1 / 2)
+
+
+def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
+    ranked = sorted(all_metrics, reverse=True, key=lambda item: item['metrics'][metric])
+    return [{'name': item['name'], 'score': item['metrics'][metric]} for item in ranked[:count]]
+
+
+class AccuracyGatherer:
+    all_metric_names = [
+        'structural-similarity',
+        'euclidean-distance',
+        'chi-square',
+        'correlation',
+        'intersection',
+        'bhattacharyya-distance'
+    ]
+
+    def __init__(self, count, top_ks=(1, 3, 5)):
+        self.top_ks = top_ks
+        self.hits = {k: {metric: 0 for metric in AccuracyGatherer.all_metric_names} for k in top_ks}
+        self.count = count
+
+    def print(self):
+        for k in self.top_ks:
+            all_metrics = {metric: self.hits[k][metric] / self.count for metric in AccuracyGatherer.all_metric_names}
+            print(f'Top {k} matches results:')
+            for key, value in all_metrics.items():
+                print(f'\t{key}: {value * 100}%')
+
+    def for_results(self, results, test_label):
+        top_results_all_metrics = {
+            k: {m: get_top_results(results, m, k) for m in AccuracyGatherer.all_metric_names} for k in self.top_ks
+        }
+        for metric_name in AccuracyGatherer.all_metric_names:
+            self.add_if_hit(top_results_all_metrics, test_label, metric_name)
+
+    def add_if_hit(self, results, test_label, metric_name):
+        for k in self.top_ks:
+            if any(single_result['name'] == test_label for single_result in results[k][metric_name]):
+                self.hits[k][metric_name] += 1
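For reference, a condensed sketch of how `validate` in main.py drives `AccuracyGatherer` (slightly rearranged from the diff above, same behavior):

```python
gatherer = AccuracyGatherer(len(test_set['values']))  # top_ks defaults to (1, 3, 5)
for image, label in zip(test_set['values'], test_set['labels']):
    results = similarity_to_anime(image, anime_faces_set)
    if results is None:                       # no detectable face: drop it from the denominator
        gatherer.count -= 1
        continue
    gatherer.for_results(results, label)      # tally top-1/3/5 hits for all six metrics
gatherer.print()                              # per-metric accuracy as percentages
```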