Compare commits

..

No commits in common. "main" and "full-dataset-comparison" have entirely different histories.

13 changed files with 60 additions and 88913 deletions

1
.gitignore vendored
View File

@ -1,6 +1,5 @@
data
.idea
.yoloface
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -1,40 +0,0 @@
# wko_anime-face-similarity
Projekt przygotowany na zajęcia z widzenia komputerowego.
Rozpoznaje twarz na zdjęciu wejściowym i dokonując transferu stylu do anime, porównuje zdjęcie ze zbiorem postaci
z anime i wskazuje podobieństwa według wybranych metryk.
## Instalacja
1. Pobranie submodułów:
```shell
$ git submodule update --init
```
2. Instalacja zależności:
* Windows/Linux
```shell
$ pip install -r requirements.txt
```
* MacOS
```shell
$ pip install -r requirements-osx.txt
```
3. Konfiguracja DCT-Netu (anime style transfer)
```shell
$ cd DCT-Net && python download.py
```
4. Pobranie datasetu twarzy postaci z anime (MyAnimeList)
```shell
$ python scrape_data.py
```
## Uruchomienie
Na tę chwilę zdjęcie poddawane porównaniu to `UAM-Andre.jpg`
```shell
$ python main.py
```
### Walidacja
Do walidacji metryk na postawie testowego datasetu z cosplayerami (`test_set`) uruchamiamy
```shell
$ python --validate_only 1
```

View File

@ -40,42 +40,3 @@ def euclidean_distance(data_a: np.ndarray, data_b: np.ndarray) -> float:
result += (histogram_a[i] - histogram_b[i]) ** 2
i += 1
return result[0] ** (1 / 2)
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
class AccuracyGatherer:
all_metric_names = [
'structural-similarity',
'euclidean-distance',
'chi-square',
'correlation',
'intersection',
'bhattacharyya-distance'
]
def __init__(self, count, top_ks=(1, 3, 5)):
self.top_ks = top_ks
self.hits = {k: {metric: 0 for metric in AccuracyGatherer.all_metric_names} for k in top_ks}
self.count = count
def print(self):
for k in self.top_ks:
all_metrics = {metric: self.hits[k][metric] / self.count for metric in AccuracyGatherer.all_metric_names}
print(f'Top {k} matches results:')
[print(f'\t{key}: {value * 100}%') for key, value in all_metrics.items()]
def for_results(self, results, test_label):
top_results_all_metrics = {
k: {m: get_top_results(results, m, k) for m in AccuracyGatherer.all_metric_names} for k in self.top_ks
}
for metric_name in AccuracyGatherer.all_metric_names:
self.add_if_hit(top_results_all_metrics, test_label, metric_name)
def add_if_hit(self, results, test_label, metric_name):
for k in self.top_ks:
if any(map(lambda single_result: single_result['name'] == test_label, results[k][metric_name])):
self.hits[k][metric_name] += 1

View File

@ -1,54 +0,0 @@
import cv2
import numpy as np
from yoloface import face_analysis
face_detector = face_analysis()
def equalize_image(data: np.ndarray):
data_hsv = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
data_hsv[:, :, 2] = cv2.equalizeHist(data_hsv[:, :, 2])
return cv2.cvtColor(data_hsv, cv2.COLOR_HSV2RGB)
def find_face_bbox_yolo(data: np.ndarray):
_, box, conf = face_detector.face_detection(frame_arr=data, frame_status=True, model='full')
if len(box) < 1:
return None, None
return box, conf
def find_face_bbox(data: np.ndarray):
classifier_files = [
'haarcascades/haarcascade_frontalface_default.xml',
'haarcascades/haarcascade_frontalface_alt.xml',
'haarcascades/haarcascade_frontalface_alt2.xml',
'haarcascades/haarcascade_profileface.xml',
'haarcascades/haarcascade_glasses.xml',
'lbpcascade_animeface.xml',
]
data_equalized = equalize_image(data)
data_gray = cv2.cvtColor(data_equalized, cv2.COLOR_RGB2GRAY)
face_coords, conf = find_face_bbox_yolo(cv2.cvtColor(data_equalized, cv2.COLOR_RGB2BGR))
if face_coords is not None:
return face_coords[0]
for classifier in classifier_files:
face_cascade = cv2.CascadeClassifier(classifier)
face_coords = face_cascade.detectMultiScale(data_gray, 1.1, 3)
if face_coords is not None:
break
return max(face_coords, key=lambda v: v[2]*v[3])
def crop_face(data: np.ndarray, bounding_box) -> np.ndarray:
x, y, w, h = bounding_box
# Extending the boxes
factor = 0.4
x, y = round(x - factor * w), round(y - factor * h)
w, h = round(w + factor * w * 2), round(h + factor * h * 2)
y = max(y, 0)
x = max(x, 0)
face = data[y:y + h, x:x + w]
return face

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +0,0 @@
import os
import sys
def no_stdout(func):
def wrapper(*args, **kwargs):
old_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
ret = func(*args, **kwargs)
sys.stdout = old_stdout
return ret
return wrapper

View File

@ -5,11 +5,7 @@ import cv2 as cv
from pathlib import Path
def load_source(filename: str) -> np.ndarray:
return cv.imread(filename)[..., ::-1]
def load_data(input_dir):
def load_data(input_dir, newSize=(64,64)):
image_path = Path(input_dir)
file_names = os.listdir(image_path)
categories_name = []
@ -31,7 +27,8 @@ def load_data(input_dir):
for n in file_names:
p = image_path / n
img = load_source(str(p)) # zwraca ndarry postaci xSize x ySize x colorDepth
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
test_img.append(img)
labels.append(n)

120
main.py
View File

@ -1,16 +1,11 @@
import argparse
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from metrics import histogram_comparison, structural_similarity_index, euclidean_distance, AccuracyGatherer
from face_detect import find_face_bbox, crop_face
from helpers import no_stdout
from load_test_data import load_data, load_source
from metrics import get_top_results
from plots import plot_two_images, plot_results
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
from load_test_data import load_data
# Allows imports from the style transfer submodule
sys.path.append('DCT-Net')
@ -18,10 +13,32 @@ sys.path.append('DCT-Net')
from source.cartoonize import Cartoonizer
anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
def load_source(filename: str) -> np.ndarray:
return cv2.imread(filename)[..., ::-1]
def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
face_cascade = cv2.CascadeClassifier(classifier_file)
face = face_cascade.detectMultiScale(data_gray, 1.1, 3)
face = max(face, key=len)
x, y, w, h = face
face = data[y:y + h, x:x + w]
return face
def plot_two_images(a: np.ndarray, b: np.ndarray):
plt.figure(figsize=[10, 10])
plt.subplot(121)
plt.imshow(a)
plt.title("A")
plt.subplot(122)
plt.imshow(b)
plt.title("B")
plt.show()
def compare_with_anime_characters(source: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
all_metrics = []
for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
current_result = {
@ -31,7 +48,7 @@ def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset:
# TODO: Use a different face detection method for anime images
# anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
anime_face = anime_image
source_rescaled = cv2.resize(source_image, anime_face.shape[:2])
source_rescaled = cv2.resize(source, anime_face.shape[:2])
if verbose:
plot_two_images(anime_face, source_rescaled)
current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
@ -42,73 +59,48 @@ def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset:
return all_metrics
@no_stdout
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
def transfer_to_anime(img: np.ndarray):
model_out = anime_transfer.cartoonize(img).astype(np.uint8)
return cv2.cvtColor(model_out, cv2.COLOR_BGR2RGB)
algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
return algo.cartoonize(img).astype(np.uint8)
def similarity_to_anime(source_image, anime_faces_set, debug=False):
try:
source_face_bbox = find_face_bbox(source_image)
except ValueError:
return None
source_anime = transfer_to_anime(source_image)
source_face_anime = crop_face(source_anime, source_face_bbox)
if debug:
source_image_with_box = source_image.copy()
x, y, w, h = source_face_bbox
cv2.rectangle(source_image_with_box, (x, y), (x + w, y + h), (255, 0, 0), 2)
plt.figure(figsize=[12, 4])
plt.subplot(131)
plt.imshow(source_image_with_box)
plt.subplot(132)
plt.imshow(source_anime)
plt.subplot(133)
plt.imshow(source_face_anime)
plt.show()
return compare_with_anime_characters(source_face_anime, anime_faces_set, verbose=debug)
def validate(test_set, anime_faces_set):
def validate(test_set, anime_faces_set, metric='correlation'):
all_entries = len(test_set['values'])
accuracy = AccuracyGatherer(all_entries)
correct = 0
for test_image, test_label in zip(test_set['values'], test_set['labels']):
test_results = similarity_to_anime(test_image, anime_faces_set)
output = get_top_results(compare_with_anime_characters(test_image, anime_faces_set), metric)[0]['name']
if output == test_label:
correct += 1
if test_results is None:
print(f"cannot find face for {test_label}")
all_entries -= 1
continue
accuracy.for_results(test_results, test_label)
accuracy.count = all_entries
accuracy.print()
accuracy = correct / all_entries
print(f'Accuracy using {metric}: {accuracy * 100}%')
return accuracy
def main():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v', '--validate_only')
args = parser.parse_args()
anime_faces_set = load_data('data/croped_anime_faces')
anime_faces_set = load_data('data/images')
if args.validate_only:
print('Validating')
test_set = load_data('test_set')
validate(test_set, anime_faces_set)
validate(test_set, anime_faces_set, 'structural-similarity')
validate(test_set, anime_faces_set, 'euclidean-distance')
validate(test_set, anime_faces_set, 'chi-square')
validate(test_set, anime_faces_set, 'correlation')
validate(test_set, anime_faces_set, 'intersection')
validate(test_set, anime_faces_set, 'bhattacharyya-distance')
exit(0)
source = load_source('test_set/Ayanokouji, Kiyotaka.jpg')
results = similarity_to_anime(source, anime_faces_set)
method = 'correlation'
top_results = get_top_results(results, count=4, metric=method)
print(top_results)
plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)
if __name__ == '__main__':
main()
source = load_source('UAM-Andre.jpg')
source_anime = transfer_to_anime(source)
source_face_anime = find_and_crop_face(source_anime)
results = compare_with_anime_characters(source_face_anime, anime_faces_set)
print(get_top_results(results, count=5))

View File

@ -1,45 +0,0 @@
import numpy as np
from matplotlib import pyplot as plt, gridspec
def plot_two_images(a: np.ndarray, b: np.ndarray):
plt.figure(figsize=[10, 10])
plt.subplot(121)
plt.imshow(a)
plt.title("A")
plt.subplot(122)
plt.imshow(b)
plt.title("B")
plt.show()
def plot_results(source, source_anime, results, anime_faces_set, method):
cols = len(results)
plt.figure(figsize=[3*cols, 7])
gs = gridspec.GridSpec(2, cols)
plt.subplot(gs[0, cols // 2 - 1])
plt.imshow(source)
plt.title('Your image')
plt.axis('off')
plt.subplot(gs[0, cols // 2])
plt.imshow(source_anime)
plt.title('Your image in Anime style')
plt.axis('off')
plt.figtext(0.5, 0.525, "Predictions", ha="center", va="top", fontsize=16)
for idx, prediction in enumerate(results):
result_img = anime_faces_set['values'][anime_faces_set['labels'].index(prediction['name'])]
plt.subplot(gs[1, idx])
plt.imshow(result_img, interpolation='bicubic')
plt.title(f'{prediction["name"].partition(".")[0]}, score={str(round(prediction["score"], 4))}')
plt.axis('off')
plt.tight_layout()
plt.figtext(0.5, 0.01, f"Metric: {method}", ha="center", va="bottom", fontsize=12)
plt.subplots_adjust(wspace=0, hspace=0.1)
plt.show()

View File

@ -1,11 +0,0 @@
tensorflow-macos==2.11.0
easydict==1.10
numpy==1.23.1
modelscope==1.1.3
requests==2.28.2
beautifulsoup4==4.11.1
lxml==4.9.2
opencv-python==4.7.0.68
torch==1.13.1
matplotlib==3.6.3
scikit-image==0.19.3

View File

@ -8,6 +8,4 @@ lxml==4.9.2
opencv-python==4.7.0.68
torch==1.13.1
matplotlib==3.6.3
scikit-image==0.19.3
yoloface==0.0.4
ipython==8.9.0
scikit-image==0.19.3