Compare against the entire anime faces dataset
This commit is contained in:
parent
b3bfa970c7
commit
49e337e5e9
@ -29,7 +29,6 @@ def load_data(input_dir, newSize=(64,64)):
|
|||||||
p = image_path / n
|
p = image_path / n
|
||||||
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
|
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
|
||||||
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
|
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
|
||||||
img = img / 255 # type: ignore #normalizacja
|
|
||||||
test_img.append(img)
|
test_img.append(img)
|
||||||
labels.append(n)
|
labels.append(n)
|
||||||
|
|
||||||
|
37
main.py
37
main.py
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
|
from comparisons import histogram_comparison, structural_similarity_index, euclidean_distance
|
||||||
|
from load_test_data import load_data
|
||||||
|
|
||||||
# Allows imports from the style transfer submodule
|
# Allows imports from the style transfer submodule
|
||||||
sys.path.append('DCT-Net')
|
sys.path.append('DCT-Net')
|
||||||
@ -36,16 +37,31 @@ def plot_two_images(a: np.ndarray, b: np.ndarray):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def compare_with_anime_characters(data: np.ndarray) -> int:
|
def compare_with_anime_characters(source: np.ndarray, verbose=False) -> list[dict]:
|
||||||
# Example will be one face from anime dataset
|
dataset = load_data('data/images')
|
||||||
example = load_source('data/images/Aisaka, Taiga.jpg')
|
all_metrics = []
|
||||||
|
for anime_image, label in zip(dataset['values'], dataset['labels']):
|
||||||
|
current_result = {
|
||||||
|
'name': label,
|
||||||
|
'metrics': {}
|
||||||
|
}
|
||||||
# TODO: Use a different face detection method for anime images
|
# TODO: Use a different face detection method for anime images
|
||||||
example_face = find_and_crop_face(example, 'haarcascades/lbpcascade_animeface.xml')
|
# anime_face = find_and_crop_face(anime_image, 'haarcascades/lbpcascade_animeface.xml')
|
||||||
data_rescaled = cv2.resize(data, example_face.shape[:2])
|
anime_face = anime_image
|
||||||
plot_two_images(example_face, data_rescaled)
|
source_rescaled = cv2.resize(source, anime_face.shape[:2])
|
||||||
print(histogram_comparison(data_rescaled, example_face))
|
if verbose:
|
||||||
print(f'structural-similarity: {structural_similarity_index(data_rescaled, example_face)}')
|
plot_two_images(anime_face, source_rescaled)
|
||||||
print(f'euclidean-distance: {euclidean_distance(data_rescaled, example_face)}')
|
current_result['metrics'] = histogram_comparison(source_rescaled, anime_face)
|
||||||
|
current_result['metrics']['structural-similarity'] = structural_similarity_index(source_rescaled, anime_face)
|
||||||
|
current_result['metrics']['euclidean-distance'] = euclidean_distance(source_rescaled, anime_face)
|
||||||
|
all_metrics.append(current_result)
|
||||||
|
|
||||||
|
return all_metrics
|
||||||
|
|
||||||
|
|
||||||
|
def get_top_results(all_metrics: list[dict], metric='correlation', count=1):
|
||||||
|
all_metrics.sort(reverse=True, key=lambda item: item['metrics'][metric])
|
||||||
|
return list(map(lambda item: {'name': item['name'], 'score': item['metrics'][metric]}, all_metrics[:count]))
|
||||||
|
|
||||||
|
|
||||||
def transfer_to_anime(img: np.ndarray):
|
def transfer_to_anime(img: np.ndarray):
|
||||||
@ -57,4 +73,5 @@ if __name__ == '__main__':
|
|||||||
source = load_source('UAM-Andre.jpg')
|
source = load_source('UAM-Andre.jpg')
|
||||||
source_anime = transfer_to_anime(source)
|
source_anime = transfer_to_anime(source)
|
||||||
source_face_anime = find_and_crop_face(source_anime)
|
source_face_anime = find_and_crop_face(source_anime)
|
||||||
print(compare_with_anime_characters(source_face_anime))
|
results = compare_with_anime_characters(source_face_anime)
|
||||||
|
print(get_top_results(results, count=5))
|
||||||
|
Loading…
Reference in New Issue
Block a user