Refactor and bug fixes

This commit is contained in:
Marcin Kostrzewski 2023-02-01 19:55:12 +01:00
parent 3817096c34
commit 7e9b63e43e
5 changed files with 116 additions and 56 deletions

15
face_detect.py Normal file
View File

@ -0,0 +1,15 @@
import cv2
import numpy as np
def find_face_bbox(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml'):
data_gray = cv2.cvtColor(data, cv2.COLOR_RGB2GRAY)
face_cascade = cv2.CascadeClassifier(classifier_file)
face_coords = face_cascade.detectMultiScale(data_gray, 1.1, 3)
return max(face_coords, key=len)
def crop_face(data: np.ndarray, bounding_box) -> np.ndarray:
x, y, w, h = bounding_box
face = data[y:y + h, x:x + w]
return face

12
helpers.py Normal file
View File

@ -0,0 +1,12 @@
import os
import sys
def no_stdout(func):
def wrapper(*args, **kwargs):
old_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
ret = func(*args, **kwargs)
sys.stdout = old_stdout
return ret
return wrapper

View File

@ -5,7 +5,11 @@ import cv2 as cv
from pathlib import Path
def load_data(input_dir, newSize=(64,64)):
def load_source(filename: str) -> np.ndarray:
return cv.imread(filename)[..., ::-1]
def load_data(input_dir):
image_path = Path(input_dir)
file_names = os.listdir(image_path)
categories_name = []
@ -27,8 +31,7 @@ def load_data(input_dir, newSize=(64,64)):
for n in file_names:
p = image_path / n
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
img = load_source(str(p)) # zwraca ndarry postaci xSize x ySize x colorDepth
test_img.append(img)
labels.append(n)

97
main.py
View File

@ -3,8 +3,12 @@ import sys
import cv2
import numpy as np
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
from load_test_data import load_data
from metrics import histogram_comparison, structural_similarity_index, euclidean_distance, AccuracyGatherer
from face_detect import find_face_bbox, crop_face
from helpers import no_stdout
from load_test_data import load_data, load_source
from metrics import get_top_results
from plots import plot_two_images, plot_results
# Allows imports from the style transfer submodule
@ -13,21 +17,10 @@ sys.path.append('DCT-Net')
from source.cartoonize import Cartoonizer
def load_source(filename: str) -> np.ndarray:
return cv2.imread(filename)[..., ::-1]
anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
face_cascade = cv2.CascadeClassifier(classifier_file)
face = face_cascade.detectMultiScale(data_gray, 1.1, 3)
face = max(face, key=len)
x, y, w, h = face
face = data[y:y + h, x:x + w]
return face
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
all_metrics = []
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
current_result = {
@ -37,7 +30,7 @@ def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict,
# TODO: Use a different face detection method for anime images
# anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
anime_face = anime_image
source_rescaled = cv2.resize(source, anime_face.shape[:2])
source_rescaled = cv2.resize(source_image, anime_face.shape[:2])
if verbose:
plot_two_images(anime_face, source_rescaled)
current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
@ -48,61 +41,59 @@ def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict,
return all_metrics
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
@no_stdout
def transfer_to_anime(img: np.ndarray):
algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
model_out = algo.cartoonize(img).astype(np.uint8)
model_out = anime_transfer.cartoonize(img).astype(np.uint8)
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
def validate(test_set, anime_faces_set, top_n=1):
def similarity_to_anime(source_image, anime_faces_set):
try:
source_face_bbox = find_face_bbox(source_image)
except ValueError:
return None
source_anime = transfer_to_anime(source_image)
source_face_anime = crop_face(source_anime, source_face_bbox)
return compare_with_anime_characters(source_face_anime, anime_faces_set)
def validate(test_set, anime_faces_set):
all_entries = len(test_set['values'])
all_metric_names = [
'structural-similarity',
'euclidean-distance',
'chi-square',
'correlation',
'intersection',
'bhattacharyya-distance'
]
hits_per_metric = {metric: 0 for metric in all_metric_names}
accuracy = AccuracyGatherer(all_entries)
for test_image, test_label in zip(test_set['values'], test_set['labels']):
test_results = compare_with_anime_characters(test_image, anime_faces_set)
top_results_all_metrics = {m: get_top_results(test_results, m, top_n) for m in all_metric_names}
for metric_name in all_metric_names:
top_current_metric_results = top_results_all_metrics[metric_name]
if any(map(lambda single_result: single_result['name'] == test_label, top_current_metric_results)):
hits_per_metric[metric_name] += 1
test_results = similarity_to_anime(test_image, anime_faces_set)
all_metrics = {metric: hits_per_metric[metric] / all_entries for metric in all_metric_names}
print(f'Top {top_n} matches results:')
[print(f'\t{key}: {value*100}%') for key, value in all_metrics.items()]
return all_metrics
if test_results is None:
print(f"cannot find face for {test_label}")
all_entries -= 1
continue
accuracy.for_results(test_results, test_label)
accuracy.count = all_entries
accuracy.print()
if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only')
args = parser.parse_args()
anime_faces_set = load_data('data/croped_anime_faces', (256, 256))
anime_faces_set = load_data('data/croped_anime_faces')
if args.validate_only:
print('Validating')
test_set = load_data('test_set')
validate(test_set, anime_faces_set, 1)
validate(test_set, anime_faces_set, 3)
validate(test_set, anime_faces_set, 5)
validate(test_set, anime_faces_set)
exit(0)
source = load_source('UAM-Andre.jpg')
source_anime = transfer_to_anime(source)
source_face_anime = find_and_crop_face(source_anime)
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
source = load_source('test_set/Ayanokouji, Kiyotaka.jpg')
results = similarity_to_anime(source, anime_faces_set)
method = 'structural-similarity'
top_results = get_top_results(results, count=4, metric=method)
print(top_results)
plot_results(source, source_anime, top_results, anime_faces_set, method)
plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)
if __name__ == '__main__':
main()

View File

@ -40,3 +40,42 @@ def euclidean_distance(data_a: np.ndarray, data_b: np.ndarray) -> float:
result += (histogram_a[i] - histogram_b[i]) ** 2
i += 1
return result[0] ** (1 / 2)
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
class AccuracyGatherer:
all_metric_names = [
'structural-similarity',
'euclidean-distance',
'chi-square',
'correlation',
'intersection',
'bhattacharyya-distance'
]
def __init__(self, count, top_ks=(1, 3, 5)):
self.top_ks = top_ks
self.hits = {k: {metric: 0 for metric in AccuracyGatherer.all_metric_names} for k in top_ks}
self.count = count
def print(self):
for k in self.top_ks:
all_metrics = {metric: self.hits[k][metric] / self.count for metric in AccuracyGatherer.all_metric_names}
print(f'Top {k} matches results:')
[print(f'\t{key}: {value * 100}%') for key, value in all_metrics.items()]
def for_results(self, results, test_label):
top_results_all_metrics = {
k: {m: get_top_results(results, m, k) for m in AccuracyGatherer.all_metric_names} for k in self.top_ks
}
for metric_name in AccuracyGatherer.all_metric_names:
self.add_if_hit(top_results_all_metrics, test_label, metric_name)
def add_if_hit(self, results, test_label, metric_name):
for k in self.top_ks:
if any(map(lambda single_result: single_result['name'] == test_label, results[k][metric_name])):
self.hits[k][metric_name] += 1