Histogram comparison method

This commit is contained in:
Marcin Kostrzewski 2023-01-29 22:43:45 +01:00
parent 8e4c805163
commit 15142d4c16
5 changed files with 6749 additions and 10 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
data data
.idea
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

25
comparisons.py Normal file
View File

@ -0,0 +1,25 @@
import cv2
import numpy as np
def histogram_comparison(data_a: np.ndarray, data_b: np.ndarray) -> dict:
hsv_a = cv2.cvtColor(data_a, cv2.COLOR_BGR2HSV)
hsv_b = cv2.cvtColor(data_b, cv2.COLOR_BGR2HSV)
histSize = [50, 60]
hue_ranges = [0, 180]
sat_ranges = [0, 256]
channels = [0, 1]
ranges = hue_ranges + sat_ranges
hist_a = cv2.calcHist([hsv_a], channels, None, histSize, ranges, accumulate=False)
cv2.normalize(hist_a, hist_a, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
hist_b = cv2.calcHist([hsv_b], channels, None, histSize, ranges, accumulate=False)
cv2.normalize(hist_b, hist_b, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
return {
'correlation': cv2.compareHist(hist_a, hist_b, 0),
'chi-square': cv2.compareHist(hist_a, hist_b, 1),
'intersection': cv2.compareHist(hist_a, hist_b, 2),
'bhattacharyya-distance': cv2.compareHist(hist_a, hist_b, 3),
}

File diff suppressed because it is too large Load Diff

37
main.py
View File

@ -1,6 +1,9 @@
import sys import sys
import cv2 import cv2
import numpy as np import numpy as np
import matplotlib.pyplot as plt
from comparisons import histogram_comparison
# Allows imports from the style transfer submodule # Allows imports from the style transfer submodule
sys.path.append('DCT-Net') sys.path.append('DCT-Net')
@ -12,28 +15,44 @@ def load_source(filename: str) -> np.ndarray:
return cv2.imread(filename)[..., ::-1] return cv2.imread(filename)[..., ::-1]
def find_and_crop_face(data: np.ndarray) -> np.ndarray: def find_and_crop_face(data: np.ndarray, classifier_file='haarcascades/haarcascade_frontalface_default.xml') -> np.ndarray:
data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) data_gray = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
face_cascade = cv2.CascadeClassifier('haarcascades/haarcascade_frontalface_default.xml') face_cascade = cv2.CascadeClassifier(classifier_file)
face = face_cascade.detectMultiScale(data_gray, 1.3, 4) face = face_cascade.detectMultiScale(data_gray, 1.1, 3)
face = max(face, key=len) face = max(face, key=len)
x, y, w, h = face x, y, w, h = face
face = data[y:y + h, x:x + w] face = data[y:y + h, x:x + w]
return face return face
def plot_two_images(a: np.ndarray, b: np.ndarray):
plt.figure(figsize=[10, 10])
plt.subplot(121)
plt.imshow(a)
plt.title("A")
plt.subplot(122)
plt.imshow(b)
plt.title("B")
plt.show()
def compare_with_anime_characters(data: np.ndarray) -> int: def compare_with_anime_characters(data: np.ndarray) -> int:
# TODO # Example will be one face from anime dataset
return 1 example = load_source('data/images/Aisaka, Taiga.jpg')
# TODO: Use a different face detection method for anime images
example_face = find_and_crop_face(example, 'haarcascades/lbpcascade_animeface.xml')
data_rescaled = cv2.resize(data, example_face.shape[:2])
plot_two_images(example_face, data_rescaled)
print(histogram_comparison(data_rescaled, example_face))
def transfer_to_anime(img: np.ndarray): def transfer_to_anime(img: np.ndarray):
algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models') algo = Cartoonizer(dataroot='DCT-Net/damo/cv_unet_person-image-cartoon_compound-models')
return algo.cartoonize(img) return algo.cartoonize(img).astype(np.uint8)
if __name__ == '__main__': if __name__ == '__main__':
source = load_source('input.png') source = load_source('UAM-Andre.jpg')
source_face = find_and_crop_face(source) source_anime = transfer_to_anime(source)
source_face_anime = transfer_to_anime(source) source_face_anime = find_and_crop_face(source_anime)
print(compare_with_anime_characters(source_face_anime)) print(compare_with_anime_characters(source_face_anime))

View File

@ -6,4 +6,5 @@ requests==2.28.2
beautifulsoup4==4.11.1 beautifulsoup4==4.11.1
lxml==4.9.2 lxml==4.9.2
opencv-python==4.7.0.68 opencv-python==4.7.0.68
torch==1.13.1 torch==1.13.1
matplotlib==3.6.3