60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
import sys
|
|
import cv2
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from comparisons import histogram_comparison, structural_similarity_index
|
|
|
|
# Allows imports from the style transfer submodule
|
|
sys.path.append('DCT-Net')
|
|
|
|
from source.cartoonize import Cartoonizer
|
|
|
|
|
|
def load_source(filename: str) -> np.ndarray:
    """Read an image file and return it with its colour channels reversed.

    ``cv2.imread`` decodes to BGR channel order; the result is converted to
    RGB (the same reordering as ``image[..., ::-1]``).

    Args:
        filename: Path of the image to read.

    Returns:
        The decoded image in RGB channel order.

    Raises:
        OSError: If OpenCV could not read/decode the file (``cv2.imread``
            signals this by returning ``None`` instead of raising).
    """
    image = cv2.imread(filename)
    if image is None:
        raise OSError(f'could not read image file {filename!r}')
    # Same values as image[..., ::-1], but returned as a contiguous array;
    # negative-stride views are rejected by some OpenCV functions.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
    """Detect faces in ``data`` with a cascade classifier and crop the largest one.

    Args:
        data: Image to search. A 2-D array is treated as already grayscale;
            otherwise it is converted with ``cv2.COLOR_BGR2GRAY``.
        classifier_file: Path to the cascade classifier XML file.

    Returns:
        The slice of ``data`` covering the bounding box of the detection with
        the largest area.

    Raises:
        OSError: If the cascade classifier could not be loaded.
        ValueError: If no face is detected.
    """
    # NOTE(review): load_source() returns channel-reversed (RGB) images, yet the
    # conversion below assumes BGR, so the R/B gray weights are swapped for
    # such inputs — confirm the intended channel order of callers.
    if data.ndim == 2:
        data_gray = data
    else:
        data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)

    face_cascade = cv2.CascadeClassifier(classifier_file)
    # An unloadable file yields an empty classifier rather than an exception.
    if face_cascade.empty():
        raise OSError(f'could not load cascade classifier from {classifier_file!r}')

    faces = face_cascade.detectMultiScale(data_gray, 1.1, 3)
    # The result may be a numpy array or an empty tuple, so test its length
    # (truthiness of a multi-element ndarray is ambiguous).
    if len(faces) == 0:
        raise ValueError(f'no face detected using {classifier_file!r}')

    # Each detection is an (x, y, w, h) row; pick the one with the largest area.
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    return data[y:y + h, x:x + w]
|
|
|
|
|
|
def plot_two_images(a: np.ndarray, b: np.ndarray):
    """Draw ``a`` and ``b`` next to each other in a single figure, then show it.

    The left panel is titled "A" and the right panel "B".
    """
    # (subplot position, image, title) for each panel, left to right.
    panels = (
        (121, a, "A"),
        (122, b, "B"),
    )
    plt.figure(figsize=[10, 10])
    for position, image, label in panels:
        plt.subplot(position)
        plt.imshow(image)
        plt.title(label)
    plt.show()
|
|
|
|
|
|
def compare_with_anime_characters(
    data: np.ndarray,
    example_path: str = 'data/images/Aisaka, Taiga.jpg',
) -> None:
    """Compare a face crop against the face of a reference anime character.

    The reference image is loaded with ``load_source``, its face is cropped
    with the anime LBP cascade, and ``data`` is resized to exactly the
    reference crop's size. Both images are plotted side by side, then the
    histogram comparison and the structural similarity index are printed.

    Args:
        data: Face image to compare.
        example_path: Path of the reference anime character image.

    Returns:
        None — results are printed, not returned.
    """
    # Example will be one face from anime dataset
    example = load_source(example_path)

    # TODO: Use a different face detection method for anime images
    example_face = find_and_crop_face(example, 'haarcascades/lbpcascade_animeface.xml')

    # cv2.resize expects dsize as (width, height) while ndarray.shape is
    # (height, width, ...); swap so data_rescaled matches example_face.
    height, width = example_face.shape[:2]
    data_rescaled = cv2.resize(data, (width, height))

    plot_two_images(example_face, data_rescaled)

    print(histogram_comparison(data_rescaled, example_face))

    print(f'structural-similarity: {structural_similarity_index(data_rescaled, example_face)}')
|
|
|
|
|
|
def transfer_to_anime(
    img: np.ndarray,
    dataroot: str = 'DCT-Net/damo/cv_unet_person-image-cartoon_compound-models',
) -> np.ndarray:
    """Stylize ``img`` as a cartoon with the DCT-Net ``Cartoonizer``.

    Args:
        img: Input image. This script passes the output of ``load_source``;
            the channel order Cartoonizer expects is assumed to match —
            TODO confirm against DCT-Net.
        dataroot: Directory containing the DCT-Net model files.

    Returns:
        The stylized image as ``uint8``.
    """
    algo = Cartoonizer(dataroot=dataroot)
    result = algo.cartoonize(img)
    # Saturate before casting: converting out-of-range values to uint8 wraps
    # around instead of clamping.
    return np.clip(result, 0, 255).astype(np.uint8)
|
|
|
|
|
|
if __name__ == '__main__':
    # An optional first command-line argument selects the input image;
    # otherwise 'UAM-Andre.jpg' is used.
    source_path = sys.argv[1] if len(sys.argv) > 1 else 'UAM-Andre.jpg'

    source = load_source(source_path)

    source_anime = transfer_to_anime(source)

    source_face_anime = find_and_crop_face(source_anime)

    # compare_with_anime_characters prints its results and has no return
    # statement, so printing its return value would only print "None".
    compare_with_anime_characters(source_face_anime)
|