From 7849d9ad51af34db91cf53c7468cccab68a93cca Mon Sep 17 00:00:00 2001 From: mszmyd Date: Mon, 6 May 2024 00:07:09 +0200 Subject: [PATCH] create test ds --- Makefile | 7 +++-- file_manager/data_manager.py | 53 ++++++++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index 5734438..c500c69 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,10 @@ -.PHONY: download-dataset sobel-dataset +.PHONY: download-dataset resize-dataset sobel-dataset download-dataset: python3 ./file_manager/data_manager.py --download +resize-dataset: + python3 ./file_manager/data_manager.py --resize --shape 64 64 --source "original_dataset" + sobel-dataset: - python3 ./file_manager/data_manager.py --sobel \ No newline at end of file + python3 ./file_manager/data_manager.py --sobel --source "resized_dataset" \ No newline at end of file diff --git a/file_manager/data_manager.py b/file_manager/data_manager.py index 9136161..c26f9fb 100644 --- a/file_manager/data_manager.py +++ b/file_manager/data_manager.py @@ -1,23 +1,28 @@ -import glob -import shutil -import cv2 -from zipfile import ZipFile -import os -import wget import argparse +import glob +import os +import shutil from pathlib import Path +from zipfile import ZipFile + +import cv2 +import wget main_path = Path("data/") path_to_train_and_valid = main_path / "%s/**/*.*" -path_to_test_dataset = main_path / "test" original_dataset_name = "original_dataset" parser = argparse.ArgumentParser() parser.add_argument("--download", action="store_true", help="Download the data") +parser.add_argument("--resize", action="store_true", + help="Resize the dataset") +parser.add_argument("--shape", type=int, nargs="+", default=(64, 64), + help="Shape of the resized images. Applied only for resize option. Default: (64, 64)") parser.add_argument("--sobel", action="store_true", help="Apply Sobel filter to the dataset") - +parser.add_argument("--source", type=str, default="original_dataset", + help="Name of the source dataset. Applied for all arguments except download. Default: original_dataset") args = parser.parse_args() @@ -45,18 +50,41 @@ class DataManager: shutil.rmtree( full_path_to_extract / "new plant diseases dataset(augmented)" ) + shutil.rmtree(full_path_to_extract / "test") + self.get_test_ds_from_validation() def write_image(self, image, path): os.makedirs(path.rsplit('/', 1)[0], exist_ok=True) cv2.imwrite(path, image) - def resize_dataset(self, source_dataset_name, width, height): + def get_test_ds_from_validation(self, files_per_category: int = 2): + path_to_extract = main_path / original_dataset_name + valid_ds = glob.glob(str(path_to_extract / "valid/*/*")) + + category_dirs = set([category_dir.split("/")[-2] + for category_dir in valid_ds]) + category_lists = {category: [] for category in category_dirs} + for file_path in valid_ds: + category = file_path.split("/")[-2] + category_lists[category].append(file_path) + + test_dir = path_to_extract / "test" + if not os.path.exists(test_dir): + os.makedirs(test_dir, exist_ok=True) + + for category, files in category_lists.items(): + os.makedirs(test_dir / category, exist_ok=True) + files.sort() + for file in files[:files_per_category]: + shutil.move(file, test_dir / category) + + def resize_dataset(self, source_dataset_name, shape): dataset_name = "resized_dataset" if not os.path.exists(main_path / dataset_name): for file in glob.glob(str(path_to_train_and_valid) % source_dataset_name, recursive=True): path_to_file = file.replace("\\", "/") image = cv2.imread(path_to_file) - image = cv2.resize(image, (width, height)) + image = cv2.resize(image, shape) new_path = path_to_file.replace( source_dataset_name, dataset_name) self.write_image(image, new_path) @@ -78,6 +106,7 @@ if __name__ == "__main__": if args.download: data_manager.download_data() data_manager.unzip_data("archive.zip", original_dataset_name) - data_manager.resize_dataset(original_dataset_name, 64, 64) + if args.resize: + data_manager.resize_dataset(args.source, tuple(args.shape)) if args.sobel: - data_manager.sobelx("resized_dataset") + data_manager.sobelx(args.source)