2023-01-31 21:13:30 +01:00
|
|
|
import argparse
|
2023-01-15 12:40:35 +01:00
|
|
|
import sys
|
|
|
|
import cv2
|
2023-02-01 20:25:31 +01:00
|
|
|
import matplotlib.pyplot as plt
|
2023-01-15 12:40:35 +01:00
|
|
|
import numpy as np
|
2023-01-29 22:43:45 +01:00
|
|
|
|
2023-02-01 19:55:12 +01:00
|
|
|
from metrics import histogram_comparison, structural_similarity_index, euclidean_distance, AccuracyGatherer
|
|
|
|
|
|
|
|
from face_detect import find_face_bbox, crop_face
|
|
|
|
from helpers import no_stdout
|
|
|
|
from load_test_data import load_data, load_source
|
|
|
|
from metrics import get_top_results
|
2023-02-01 13:16:46 +01:00
|
|
|
from plots import plot_two_images, plot_results
|
2023-01-15 12:40:35 +01:00
|
|
|
|
2023-01-29 21:23:11 +01:00
|
|
|
# Allows imports from the style transfer submodule: the DCT-Net `source`
# package only becomes importable once 'DCT-Net' is on sys.path, so this
# append must run before the import below.
sys.path.append('DCT-Net')

from source.cartoonize import Cartoonizer

# Shared cartoonizer instance, constructed once at module import time and used
# by transfer_to_anime. `dataroot` presumably points at the pretrained model
# directory inside the submodule -- verify against DCT-Net's Cartoonizer.
anime_transfer = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
|
2023-01-15 12:40:35 +01:00
|
|
|
|
|
|
|
|
2023-02-01 19:55:12 +01:00
|
|
|
def compare_with_anime_characters(source_image: np.ndarray, anime_faces_dataset: dict, verbose=False) -> list[dict]:
    """Score ``source_image`` against every image of an anime-face dataset.

    Args:
        source_image: image to compare; it is resized to each dataset image's
            height and width before scoring.
        anime_faces_dataset: mapping with parallel ``'values'`` (images) and
            ``'labels'`` sequences.
        verbose: if true, pass each dataset image and the resized source image
            to ``plot_two_images``.

    Returns:
        One dict per dataset entry, in dataset order, shaped
        ``{'name': label, 'metrics': {...}}``. ``metrics`` is the mapping
        returned by ``histogram_comparison`` (presumably a dict of scores),
        extended with ``'structural-similarity'`` and ``'euclidean-distance'``.
    """
    all_metrics = []
    for anime_image, label in zip(anime_faces_dataset['values'], anime_faces_dataset['labels']):
        # TODO: Use a different face detection method for anime images (e.g. the
        # haarcascades/lbpcascade_animeface.xml cascade); for now each dataset
        # image is compared as a whole.
        anime_face = anime_image

        # cv2.resize expects dsize as (width, height), whereas ndarray.shape is
        # (height, width, ...). Passing shape[:2] directly swapped the axes for
        # non-square images, so the two images no longer had matching shapes.
        height, width = anime_face.shape[:2]
        source_rescaled = cv2.resize(source_image, (width, height))

        if verbose:
            plot_two_images(anime_face, source_rescaled)

        metrics = histogram_comparison(source_rescaled, anime_face)
        metrics['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
        metrics['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)

        all_metrics.append({'name': label, 'metrics': metrics})

    return all_metrics
|
|
|
|
|
|
|
|
|
2023-02-01 19:55:12 +01:00
|
|
|
@no_stdout
def transfer_to_anime(img: np.ndarray):
    """Render ``img`` in anime style with the shared DCT-Net cartoonizer.

    The call is wrapped by ``no_stdout`` -- presumably to silence the model's
    console output (confirm in ``helpers``).

    Returns:
        The cartoonized image cast to ``uint8``, with its channel order
        converted from BGR to RGB.
    """
    cartoon = anime_transfer.cartoonize(img)
    # NOTE(review): astype wraps rather than clips values outside 0..255;
    # consider np.clip first if the model can overshoot that range.
    cartoon_bgr = cartoon.astype(np.uint8)
    return cv2.cvtColor(cartoon_bgr, cv2.COLOR_BGR2RGB)
|
2023-01-15 12:40:35 +01:00
|
|
|
|
|
|
|
|
2023-02-01 20:25:31 +01:00
|
|
|
def _show_debug_panels(source_image, face_bbox, anime_image, anime_face):
    """Plot the input with its face box, its anime version, and the cropped anime face side by side."""
    x, y, w, h = face_bbox
    boxed = source_image.copy()
    cv2.rectangle(boxed, (x, y), (x + w, y + h), (255, 0, 0), 2)

    plt.figure(figsize=[12, 4])
    for position, panel in zip((131, 132, 133), (boxed, anime_image, anime_face)):
        plt.subplot(position)
        plt.imshow(panel)
    plt.show()


def similarity_to_anime(source_image, anime_faces_set, debug=True):
    """Match the face in a photo against an anime-face dataset.

    The face box is located with ``find_face_bbox`` on the original photo. The
    whole photo is converted with ``transfer_to_anime``, the same box is cropped
    out of the converted image, and that crop is scored with
    ``compare_with_anime_characters``.

    Args:
        source_image: photo to match (presumably RGB, since debug mode shows it
            with ``plt.imshow`` -- confirm with the loader).
        anime_faces_set: dataset forwarded to ``compare_with_anime_characters``.
        debug: if true, show a three-panel matplotlib figure (input with face
            box, anime version, cropped anime face) before scoring.

    Returns:
        The per-character results of ``compare_with_anime_characters``, or
        ``None`` when ``find_face_bbox`` raises ``ValueError``.
    """
    try:
        face_bbox = find_face_bbox(source_image)
    except ValueError:
        return None

    anime_image = transfer_to_anime(source_image)
    anime_face = crop_face(anime_image, face_bbox)

    if debug:
        _show_debug_panels(source_image, face_bbox, anime_image, anime_face)

    return compare_with_anime_characters(anime_face, anime_faces_set)
|
|
|
|
|
|
|
|
|
|
|
|
def validate(test_set, anime_faces_set):
    """Run the matcher over a labelled test set and print the gathered accuracy.

    Every test image is matched with ``similarity_to_anime`` and the results,
    together with the image's label, are fed to an ``AccuracyGatherer``. Images
    for which ``similarity_to_anime`` returns ``None`` (no face found) are
    reported and left out of the final count.

    Args:
        test_set: mapping with parallel ``'values'`` (images) and ``'labels'``
            sequences.
        anime_faces_set: dataset forwarded to ``similarity_to_anime``.
    """
    all_entries = len(test_set['values'])
    accuracy = AccuracyGatherer(all_entries)

    for test_image, test_label in zip(test_set['values'], test_set['labels']):
        # debug=False: the debug view ends in plt.show(), which would open (and
        # typically block on) a figure window for every single test image.
        test_results = similarity_to_anime(test_image, anime_faces_set, debug=False)
        if test_results is None:
            print(f"cannot find face for {test_label}")
            all_entries -= 1
            continue
        accuracy.for_results(test_results, test_label)

    # Only images with a detected face are counted; presumably AccuracyGatherer
    # uses `count` as the accuracy denominator -- confirm in metrics.
    accuracy.count = all_entries
    accuracy.print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    By default, matches the photo given by ``--source`` against the cropped
    anime-face dataset, prints the top 4 results by structural similarity and
    plots them. With ``-v``/``--validate_only``, it instead runs ``validate``
    over ``test_set`` and returns.
    """
    parser = argparse.ArgumentParser(
        description='Find the anime characters most similar to the person in a photo.')
    # nargs='?' with const=True lets a bare `-v` work; without it argparse
    # demanded a value. `-v <anything>` is still accepted for compatibility.
    parser.add_argument('-v', '--validate_only', nargs='?', const=True, default=False,
                        help='validate on test_set instead of matching a single photo')
    parser.add_argument('--source', default='test_set/Ayanokouji, Kiyotaka.jpg',
                        help='photo to match (default: %(default)s)')
    args = parser.parse_args()

    anime_faces_set = load_data('data/croped_anime_faces')

    if args.validate_only:
        print('Validating')
        test_set = load_data('test_set')
        validate(test_set, anime_faces_set)
        # Returning ends the script with status 0, like the former exit(0),
        # without relying on the site-module `exit` builtin.
        return

    source = load_source(args.source)
    results = similarity_to_anime(source, anime_faces_set)
    if results is None:
        # similarity_to_anime returns None when no face is detected; previously
        # this crashed inside get_top_results with an unrelated TypeError.
        sys.exit(f"cannot find face in {args.source}")

    method = 'structural-similarity'
    top_results = get_top_results(results, count=4, metric=method)
    print(top_results)
    plot_results(source, transfer_to_anime(source), top_results, anime_faces_set, method)


if __name__ == '__main__':
    main()
|