create test ds

mszmyd 2024-05-06 00:07:09 +02:00
parent c70553ec7c
commit 7849d9ad51
2 changed files with 46 additions and 14 deletions

Makefile
@@ -1,7 +1,10 @@
-.PHONY: download-dataset sobel-dataset
+.PHONY: download-dataset resize-dataset sobel-dataset
 
 download-dataset:
 	python3 ./file_manager/data_manager.py --download
 
+resize-dataset:
+	python3 ./file_manager/data_manager.py --resize --shape 64 64 --source "original_dataset"
+
 sobel-dataset:
-	python3 ./file_manager/data_manager.py --sobel
+	python3 ./file_manager/data_manager.py --sobel --source "resized_dataset"
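For reference, the intended order of the targets (a usage sketch; assumes make is run from the repository root where this Makefile lives):

    make download-dataset   # fetch archive.zip and unpack it into data/original_dataset
    make resize-dataset     # write 64x64 copies into data/resized_dataset
    make sobel-dataset      # Sobel-filter the resized copies via --source "resized_dataset"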

file_manager/data_manager.py
@@ -1,23 +1,28 @@
-import glob
-import shutil
-import cv2
-from zipfile import ZipFile
-import os
-import wget
 import argparse
+import glob
+import os
+import shutil
+from pathlib import Path
+from zipfile import ZipFile
+
+import cv2
+import wget
 
 main_path = Path("data/")
 path_to_train_and_valid = main_path / "%s/**/*.*"
+path_to_test_dataset = main_path / "test"
 original_dataset_name = "original_dataset"
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--download", action="store_true",
                     help="Download the data")
+parser.add_argument("--resize", action="store_true",
+                    help="Resize the dataset")
+parser.add_argument("--shape", type=int, nargs="+", default=(64, 64),
+                    help="Shape of the resized images. Applied only for resize option. Default: (64, 64)")
 parser.add_argument("--sobel", action="store_true",
                     help="Apply Sobel filter to the dataset")
+parser.add_argument("--source", type=str, default="original_dataset",
+                    help="Name of the source dataset. Applied for all arguments except download. Default: original_dataset")
 
 args = parser.parse_args()
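Spelled out as direct invocations, the new arguments combine like this (a sketch of typical calls; the 128 128 shape is only an illustrative override of the (64, 64) default):

    python3 ./file_manager/data_manager.py --download
    python3 ./file_manager/data_manager.py --resize --shape 128 128 --source "original_dataset"
    python3 ./file_manager/data_manager.py --sobel --source "resized_dataset"

Since --shape uses nargs="+", argparse accepts any number of integers, but exactly two (width, height) should be given: the tuple is passed straight to cv2.resize.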
@@ -45,18 +50,41 @@ class DataManager:
         shutil.rmtree(
             full_path_to_extract / "new plant diseases dataset(augmented)"
         )
+        shutil.rmtree(full_path_to_extract / "test")
+        self.get_test_ds_from_validation()
 
     def write_image(self, image, path):
         os.makedirs(path.rsplit('/', 1)[0], exist_ok=True)
         cv2.imwrite(path, image)
 
-    def resize_dataset(self, source_dataset_name, width, height):
+    def get_test_ds_from_validation(self, files_per_category: int = 2):
+        path_to_extract = main_path / original_dataset_name
+        valid_ds = glob.glob(str(path_to_extract / "valid/*/*"))
+        category_dirs = set([category_dir.split("/")[-2]
+                             for category_dir in valid_ds])
+        category_lists = {category: [] for category in category_dirs}
+        for file_path in valid_ds:
+            category = file_path.split("/")[-2]
+            category_lists[category].append(file_path)
+
+        test_dir = path_to_extract / "test"
+        if not os.path.exists(test_dir):
+            os.makedirs(test_dir, exist_ok=True)
+            for category, files in category_lists.items():
+                os.makedirs(test_dir / category, exist_ok=True)
+                files.sort()
+                for file in files[:files_per_category]:
+                    shutil.move(file, test_dir / category)
+
+    def resize_dataset(self, source_dataset_name, shape):
         dataset_name = "resized_dataset"
         if not os.path.exists(main_path / dataset_name):
             for file in glob.glob(str(path_to_train_and_valid) % source_dataset_name, recursive=True):
                 path_to_file = file.replace("\\", "/")
                 image = cv2.imread(path_to_file)
-                image = cv2.resize(image, (width, height))
+                image = cv2.resize(image, shape)
                 new_path = path_to_file.replace(
                     source_dataset_name, dataset_name)
                 self.write_image(image, new_path)
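The signature change means call sites now pass one shape tuple instead of separate width and height arguments. A minimal sketch (assuming DataManager takes no constructor arguments, as the __main__ block below suggests):

    from file_manager.data_manager import DataManager

    dm = DataManager()
    # old call style: dm.resize_dataset("original_dataset", 64, 64)
    dm.resize_dataset("original_dataset", (64, 64))  # shape as a single (width, height) tuple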
@@ -78,6 +106,7 @@ if __name__ == "__main__":
     if args.download:
         data_manager.download_data()
         data_manager.unzip_data("archive.zip", original_dataset_name)
-        data_manager.resize_dataset(original_dataset_name, 64, 64)
+    if args.resize:
+        data_manager.resize_dataset(args.source, tuple(args.shape))
     if args.sobel:
-        data_manager.sobelx("resized_dataset")
+        data_manager.sobelx(args.source)
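Net effect of the new test split: the lexicographically first two files of every valid/<category> directory (files.sort(), then files[:files_per_category]) are moved, giving a small deterministic test set:

    data/original_dataset/valid/<category>/   # loses its first 2 files
    data/original_dataset/test/<category>/    # gains those 2 files

Here <category> stands for whatever class directories exist under valid/; and because the move is guarded by the os.path.exists(test_dir) check (as reconstructed above), a rerun will not move files a second time.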