Compare commits
No commits in common. "main" and "top-k-validation" have entirely different histories.
main
...
top-k-vali
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,6 +1,5 @@
|
||||
data
|
||||
.idea
|
||||
.yoloface
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
@ -40,42 +40,3 @@ def euclidean_distance(data_a: np.ndarray, data_b: np.ndarray) -> float:
|
||||
result += (histogram_a[i] - histogram_b[i]) ** 2
|
||||
i += 1
|
||||
return result[0] ** (1 / 2)
|
||||
|
||||
|
||||
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
|
||||
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
|
||||
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
|
||||
|
||||
|
||||
class AccuracyGatherer:
|
||||
all_metric_names = [
|
||||
'structural-similarity',
|
||||
'euclidean-distance',
|
||||
'chi-square',
|
||||
'correlation',
|
||||
'intersection',
|
||||
'bhattacharyya-distance'
|
||||
]
|
||||
|
||||
def __init__(self, count, top_ks=(1, 3, 5)):
|
||||
self.top_ks = top_ks
|
||||
self.hits = {k: {metric: 0 for metric in AccuracyGatherer.all_metric_names} for k in top_ks}
|
||||
self.count = count
|
||||
|
||||
def print(self):
|
||||
for k in self.top_ks:
|
||||
all_metrics = {metric: self.hits[k][metric] / self.count for metric in AccuracyGatherer.all_metric_names}
|
||||
print(f'Top {k} matches results:')
|
||||
[print(f'\t{key}: {value * 100}%') for key, value in all_metrics.items()]
|
||||
|
||||
def for_results(self, results, test_label):
|
||||
top_results_all_metrics = {
|
||||
k: {m: get_top_results(results, m, k) for m in AccuracyGatherer.all_metric_names} for k in self.top_ks
|
||||
}
|
||||
for metric_name in AccuracyGatherer.all_metric_names:
|
||||
self.add_if_hit(top_results_all_metrics, test_label, metric_name)
|
||||
|
||||
def add_if_hit(self, results, test_label, metric_name):
|
||||
for k in self.top_ks:
|
||||
if any(map(lambda single_result: single_result['name'] == test_label, results[k][metric_name])):
|
||||
self.hits[k][metric_name] += 1
|
@ -1,54 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from yoloface import face_analysis
|
||||
|
||||
face_detector = face_analysis()
|
||||
|
||||
|
||||
def equalize_image(data: np.ndarray):
|
||||
data_hsv = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
|
||||
data_hsv[:, :, 2] = cv2.equalizeHist(data_hsv[:, :, 2])
|
||||
return cv2.cvtColor(data_hsv, cv2.COLOR_HSV2RGB)
|
||||
|
||||
|
||||
def find_face_bbox_yolo(data: np.ndarray):
|
||||
_, box, conf = face_detector.face_detection(frame_arr=data, frame_status=True, model='full')
|
||||
if len(box) < 1:
|
||||
return None, None
|
||||
return box, conf
|
||||
|
||||
|
||||
def find_face_bbox(data: np.ndarray):
|
||||
classifier_files = [
|
||||
'haarcascades/haarcascade_frontalface_default.xml',
|
||||
'haarcascades/haarcascade_frontalface_alt.xml',
|
||||
'haarcascades/haarcascade_frontalface_alt2.xml',
|
||||
'haarcascades/haarcascade_profileface.xml',
|
||||
'haarcascades/haarcascade_glasses.xml',
|
||||
'lbpcascade_animeface.xml',
|
||||
]
|
||||
data_equalized = equalize_image(data)
|
||||
data_gray = cv2.cvtColor(data_equalized, cv2.COLOR_RGB2GRAY)
|
||||
face_coords, conf = find_face_bbox_yolo(cv2.cvtColor(data_equalized, cv2.COLOR_RGB2BGR))
|
||||
if face_coords is not None:
|
||||
return face_coords[0]
|
||||
|
||||
for classifier in classifier_files:
|
||||
face_cascade = cv2.CascadeClassifier(classifier)
|
||||
face_coords = face_cascade.detectMultiScale(data_gray, 1.1, 3)
|
||||
if face_coords is not None:
|
||||
break
|
||||
return max(face_coords, key=lambda v: v[2]*v[3])
|
||||
|
||||
|
||||
def crop_face(data: np.ndarray, bounding_box) -> np.ndarray:
|
||||
x, y, w, h = bounding_box
|
||||
# Extending the boxes
|
||||
factor = 0.4
|
||||
x, y = round(x - factor * w), round(y - factor * h)
|
||||
w, h = round(w + factor * w * 2), round(h + factor * h * 2)
|
||||
y = max(y, 0)
|
||||
x = max(x, 0)
|
||||
|
||||
face = data[y:y + h, x:x + w]
|
||||
return face
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
12
helpers.py
12
helpers.py
@ -1,12 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def no_stdout(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
ret = func(*args, **kwargs)
|
||||
sys.stdout = old_stdout
|
||||
return ret
|
||||
return wrapper
|
@ -5,11 +5,7 @@ import cv2 as cv
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_source(filename: str) -> np.ndarray:
|
||||
return cv.imread(filename)[..., ::-1]
|
||||
|
||||
|
||||
def load_data(input_dir):
|
||||
def load_data(input_dir, newSize=(64,64)):
|
||||
image_path = Path(input_dir)
|
||||
file_names = os.listdir(image_path)
|
||||
categories_name = []
|
||||
@ -31,7 +27,8 @@ def load_data(input_dir):
|
||||
|
||||
for n in file_names:
|
||||
p = image_path / n
|
||||
img = load_source(str(p)) # zwraca ndarry postaci xSize x ySize x colorDepth
|
||||
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
|
||||
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
|
||||
test_img.append(img)
|
||||
labels.append(n)
|
||||
|
||||
|
127
main.py
127
main.py
@ -1,16 +1,11 @@
|
||||
import argparse
|
||||
import sys
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from metrics import histogram_comparison, structural_similarity_index, euclidean_distance, AccuracyGatherer
|
||||
|
||||
from face_detect import find_face_bbox, crop_face
|
||||
from helpers import no_stdout
|
||||
from load_test_data import load_data, load_source
|
||||
from metrics import get_top_results
|
||||
from plots import plot_two_images, plot_results
|
||||
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
|
||||
from load_test_data import load_data
|
||||
|
||||
# Allows imports from the style transfer submodule
|
||||
sys.path.append('DCT-Net')
|
||||
@ -18,10 +13,32 @@ sys.path.append('DCT-Net')
|
||||
from source.cartoonize import Cartoonizer
|
||||
|
||||
|
||||
anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
|
||||
def load_source(filename: str) -> np.ndarray:
|
||||
return cv2.imread(filename)[..., ::-1]
|
||||
|
||||
|
||||
def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
|
||||
def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
|
||||
data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
|
||||
face_cascade = cv2.CascadeClassifier(classifier_file)
|
||||
face = face_cascade.detectMultiScale(data_gray, 1.1, 3)
|
||||
face = max(face, key=len)
|
||||
x, y, w, h = face
|
||||
face = data[y:y + h, x:x + w]
|
||||
return face
|
||||
|
||||
|
||||
def plot_two_images(a: np.ndarray, b: np.ndarray):
|
||||
plt.figure(figsize=[10, 10])
|
||||
plt.subplot(121)
|
||||
plt.imshow(a)
|
||||
plt.title("A")
|
||||
plt.subplot(122)
|
||||
plt.imshow(b)
|
||||
plt.title("B")
|
||||
plt.show()
|
||||
|
||||
|
||||
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
|
||||
all_metrics = []
|
||||
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
|
||||
current_result = {
|
||||
@ -31,7 +48,7 @@ def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset:
|
||||
# TODO: Use a different face detection method for anime images
|
||||
# anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
|
||||
anime_face = anime_image
|
||||
source_rescaled = cv2.resize(source_image, anime_face.shape[:2])
|
||||
source_rescaled = cv2.resize(source, anime_face.shape[:2])
|
||||
if verbose:
|
||||
plot_two_images(anime_face, source_rescaled)
|
||||
current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
|
||||
@ -42,73 +59,57 @@ def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset:
|
||||
return all_metrics
|
||||
|
||||
|
||||
@no_stdout
|
||||
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
|
||||
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
|
||||
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
|
||||
|
||||
|
||||
def transfer_to_anime(img: np.ndarray):
|
||||
model_out = anime_transfer.cartoonize(img).astype(np.uint8)
|
||||
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
|
||||
algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
|
||||
return algo.cartoonize(img).astype(np.uint8)
|
||||
|
||||
|
||||
def similarity_to_anime(source_image, anime_faces_set, debug=False):
|
||||
try:
|
||||
source_face_bbox = find_face_bbox(source_image)
|
||||
except ValueError:
|
||||
return None
|
||||
source_anime = transfer_to_anime(source_image)
|
||||
source_face_anime = crop_face(source_anime, source_face_bbox)
|
||||
|
||||
if debug:
|
||||
source_image_with_box = source_image.copy()
|
||||
x, y, w, h = source_face_bbox
|
||||
cv2.rectangle(source_image_with_box, (x, y), (x + w, y + h), (255, 0, 0), 2)
|
||||
plt.figure(figsize=[12, 4])
|
||||
plt.subplot(131)
|
||||
plt.imshow(source_image_with_box)
|
||||
plt.subplot(132)
|
||||
plt.imshow(source_anime)
|
||||
plt.subplot(133)
|
||||
plt.imshow(source_face_anime)
|
||||
plt.show()
|
||||
|
||||
return compare_with_anime_characters(source_face_anime, anime_faces_set, verbose=debug)
|
||||
|
||||
|
||||
def validate(test_set, anime_faces_set):
|
||||
def validate(test_set, anime_faces_set, metric='correlation', top_n=1):
|
||||
all_entries = len(test_set['values'])
|
||||
accuracy = AccuracyGatherer(all_entries)
|
||||
correct = 0
|
||||
for test_image, test_label in zip(test_set['values'], test_set['labels']):
|
||||
test_results = similarity_to_anime(test_image, anime_faces_set)
|
||||
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric, top_n)
|
||||
if any(map(lambda single_result: single_result['name'] == test_label, output)):
|
||||
correct += 1
|
||||
|
||||
if test_results is None:
|
||||
print(f"cannot find face for {test_label}")
|
||||
all_entries -= 1
|
||||
continue
|
||||
|
||||
accuracy.for_results(test_results, test_label)
|
||||
|
||||
accuracy.count = all_entries
|
||||
accuracy.print()
|
||||
accuracy = correct / all_entries
|
||||
print(f'Accuracy using {metric}: {accuracy * 100}%')
|
||||
return accuracy
|
||||
|
||||
|
||||
def validate_all(test_set, anime_faces_set, metric='correlation', top_n=1):
|
||||
validate(test_set, anime_faces_set, 'structural-similarity', top_n)
|
||||
validate(test_set, anime_faces_set, 'euclidean-distance', top_n)
|
||||
validate(test_set, anime_faces_set, 'chi-square', top_n)
|
||||
validate(test_set, anime_faces_set, 'correlation', top_n)
|
||||
validate(test_set, anime_faces_set, 'intersection', top_n)
|
||||
validate(test_set, anime_faces_set, 'bhattacharyya-distance', top_n)
|
||||
|
||||
def main():
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-v', '--validate_only')
|
||||
args = parser.parse_args()
|
||||
anime_faces_set = load_data('data/croped_anime_faces')
|
||||
anime_faces_set = load_data('data/images')
|
||||
|
||||
if args.validate_only:
|
||||
print('Validating')
|
||||
test_set = load_data('test_set')
|
||||
validate(test_set, anime_faces_set)
|
||||
print('Top 1 matches results:')
|
||||
validate_all(test_set, anime_faces_set, 'structural-similarity', 1)
|
||||
print('Top 3 matches results:')
|
||||
validate_all(test_set, anime_faces_set, 'structural-similarity', 3)
|
||||
print('Top 5 matches results:')
|
||||
validate_all(test_set, anime_faces_set, 'structural-similarity', 5)
|
||||
exit(0)
|
||||
|
||||
source = load_source('test_set/Ayanokouji, Kiyotaka.jpg')
|
||||
results = similarity_to_anime(source, anime_faces_set)
|
||||
method = 'correlation'
|
||||
top_results = get_top_results(results, count=4, metric=method)
|
||||
print(top_results)
|
||||
plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
source = load_source('UAM-Andre.jpg')
|
||||
source_anime = transfer_to_anime(source)
|
||||
source_face_anime = find_and_crop_face(source_anime)
|
||||
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
|
||||
print(get_top_results(results, count=5))
|
||||
|
45
plots.py
45
plots.py
@ -1,45 +0,0 @@
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt, gridspec
|
||||
|
||||
|
||||
def plot_two_images(a: np.ndarray, b: np.ndarray):
|
||||
plt.figure(figsize=[10, 10])
|
||||
plt.subplot(121)
|
||||
plt.imshow(a)
|
||||
plt.title("A")
|
||||
plt.subplot(122)
|
||||
plt.imshow(b)
|
||||
plt.title("B")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_results(source, source_anime, results, anime_faces_set, method):
|
||||
cols = len(results)
|
||||
plt.figure(figsize=[3*cols, 7])
|
||||
gs = gridspec.GridSpec(2, cols)
|
||||
|
||||
plt.subplot(gs[0, cols // 2 - 1])
|
||||
plt.imshow(source)
|
||||
plt.title('Your image')
|
||||
plt.axis('off')
|
||||
|
||||
plt.subplot(gs[0, cols // 2])
|
||||
plt.imshow(source_anime)
|
||||
plt.title('Your image in Anime style')
|
||||
plt.axis('off')
|
||||
|
||||
plt.figtext(0.5, 0.525, "Predictions", ha="center", va="top", fontsize=16)
|
||||
|
||||
for idx, prediction in enumerate(results):
|
||||
result_img = anime_faces_set['values'][anime_faces_set['labels'].index(prediction['name'])]
|
||||
plt.subplot(gs[1, idx])
|
||||
plt.imshow(result_img, interpolation='bicubic')
|
||||
plt.title(f'{prediction["name"].partition(".")[0]}, score={str(round(prediction["score"], 4))}')
|
||||
plt.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
plt.figtext(0.5, 0.01, f"Metric: {method}", ha="center", va="bottom", fontsize=12)
|
||||
plt.subplots_adjust(wspace=0, hspace=0.1)
|
||||
|
||||
plt.show()
|
@ -8,6 +8,4 @@ lxml==4.9.2
|
||||
opencv-python==4.7.0.68
|
||||
torch==1.13.1
|
||||
matplotlib==3.6.3
|
||||
scikit-image==0.19.3
|
||||
yoloface==0.0.4
|
||||
ipython==8.9.0
|
||||
scikit-image==0.19.3
|
Loading…
Reference in New Issue
Block a user