implement load_data
This commit is contained in:
parent
15cd61d2d2
commit
dac16b130b
46
load_test_data.py
Normal file
46
load_test_data.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from skimage.io import imread
|
||||||
|
import cv2 as cv
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def load_data(input_dir, newSize=(64,64)):
|
||||||
|
image_path = Path(input_dir)
|
||||||
|
file_names = os.listdir(image_path)
|
||||||
|
categories_name = []
|
||||||
|
categories_count=[]
|
||||||
|
count = 0
|
||||||
|
n = file_names[0]
|
||||||
|
for name in file_names:
|
||||||
|
if name != n:
|
||||||
|
categories_count.append(count)
|
||||||
|
n = name
|
||||||
|
count = 1
|
||||||
|
else:
|
||||||
|
count += 1
|
||||||
|
if not name in categories_name:
|
||||||
|
categories_name.append(name)
|
||||||
|
categories_count.append(count)
|
||||||
|
test_img = []
|
||||||
|
labels=[]
|
||||||
|
|
||||||
|
for n in file_names:
|
||||||
|
p = image_path / n
|
||||||
|
img = imread(p) # zwraca ndarry postaci xSize x ySize x colorDepth
|
||||||
|
img = cv.resize(img, newSize, interpolation=cv.INTER_AREA) # zwraca ndarray
|
||||||
|
img = img / 255 # type: ignore #normalizacja
|
||||||
|
test_img.append(img)
|
||||||
|
labels.append(n)
|
||||||
|
|
||||||
|
X={}
|
||||||
|
X["values"] = np.array(test_img)
|
||||||
|
X["categories_name"] = categories_name
|
||||||
|
X["categories_count"] = categories_count
|
||||||
|
X["labels"]=labels
|
||||||
|
return X
|
||||||
|
|
||||||
|
data = load_data('test_set')
|
||||||
|
print(data['categories_name'])
|
||||||
|
print(data['categories_count'])
|
||||||
|
print(data['labels'])
|
||||||
|
print(list(data["values"]))
|
Loading…
Reference in New Issue
Block a user