feature/basic-model-setup #3

Merged
s495727 merged 9 commits from feature/basic-model-setup into main 2024-05-11 20:00:07 +02:00
2 changed files with 46 additions and 14 deletions
Showing only changes of commit 7fb9902340 - Show all commits

View File

@@ -1,7 +1,10 @@
# Dataset pipeline: download -> resize -> sobel. Each target shells out to
# the data manager CLI; targets are .PHONY because they produce directories,
# not files named after the target.
.PHONY: download-dataset resize-dataset sobel-dataset

download-dataset:
	python3 ./file_manager/data_manager.py --download

resize-dataset:
	python3 ./file_manager/data_manager.py --resize --shape 64 64 --source "original_dataset"

sobel-dataset:
	python3 ./file_manager/data_manager.py --sobel --source "resized_dataset"

View File

@@ -1,23 +1,28 @@
import argparse
import glob
import os
import shutil
from pathlib import Path
from zipfile import ZipFile

import cv2
import wget
# All datasets live side by side under data/<dataset_name>/.
main_path = Path("data/")
# Glob template: %s is filled with a dataset name; ** recurses through the
# train/ and valid/ category subdirectories.
path_to_train_and_valid = main_path / "%s/**/*.*"
path_to_test_dataset = main_path / "test"
# Directory name the downloaded archive is extracted into.
original_dataset_name = "original_dataset"
# CLI flags for the data-manager script; parsed at import time, so this
# module is intended to be run as a script rather than imported.
parser = argparse.ArgumentParser()
parser.add_argument("--download", action="store_true",
help="Download the data")
parser.add_argument("--resize", action="store_true",
help="Resize the dataset")
parser.add_argument("--shape", type=int, nargs="+", default=(64, 64),
help="Shape of the resized images. Applied only for resize option. Default: (64, 64)")
parser.add_argument("--sobel", action="store_true",
help="Apply Sobel filter to the dataset")
parser.add_argument("--source", type=str, default="original_dataset",
help="Name of the source dataset. Applied for all arguments except download. Default: original_dataset")
args = parser.parse_args()
@@ -45,18 +50,41 @@ class DataManager:
shutil.rmtree(
full_path_to_extract / "new plant diseases dataset(augmented)"
)
shutil.rmtree(full_path_to_extract / "test")
self.get_test_ds_from_validation()
def write_image(self, image, path):
    """Write *image* (an OpenCV/BGR array) to *path*, creating parent dirs.

    Callers pass forward-slash string paths (they normalize with
    str.replace("\\", "/") beforehand).
    """
    # Path(path).parent is robust where the old path.rsplit('/', 1)[0]
    # was not: for a bare filename it yields "." (which exists) instead
    # of trying to create a directory named after the file itself.
    os.makedirs(Path(path).parent, exist_ok=True)
    cv2.imwrite(path, image)
def get_test_ds_from_validation(self, files_per_category: int = 2):
    """Carve a small test split out of the original dataset's validation set.

    Moves the first *files_per_category* files (sorted for determinism) of
    every category from ``valid/<category>/`` into ``test/<category>/``.
    If ``test/`` already exists the method does nothing, so repeated runs
    do not keep draining the validation set (NOTE(review): the diff lost
    indentation; this guard-clause reading of the original is the one that
    makes reruns safe — confirm against the repository).
    """
    path_to_extract = main_path / original_dataset_name
    test_dir = path_to_extract / "test"
    if os.path.exists(test_dir):
        return
    # Group validation file paths by their category directory name
    # (second-to-last path component). setdefault replaces the original
    # two-pass set-then-dict construction.
    category_lists = {}
    for file_path in glob.glob(str(path_to_extract / "valid/*/*")):
        category = file_path.split("/")[-2]
        category_lists.setdefault(category, []).append(file_path)
    os.makedirs(test_dir, exist_ok=True)
    for category, files in category_lists.items():
        os.makedirs(test_dir / category, exist_ok=True)
        for file in sorted(files)[:files_per_category]:
            shutil.move(file, test_dir / category)
def resize_dataset(self, source_dataset_name, shape):
dataset_name = "resized_dataset"
if not os.path.exists(main_path / dataset_name):
for file in glob.glob(str(path_to_train_and_valid) % source_dataset_name, recursive=True):
path_to_file = file.replace("\\", "/")
image = cv2.imread(path_to_file)
image = cv2.resize(image, (width, height))
image = cv2.resize(image, shape)
new_path = path_to_file.replace(
source_dataset_name, dataset_name)
self.write_image(image, new_path)
@@ -78,6 +106,7 @@ if __name__ == "__main__":
# Dispatch on the CLI flags; the flags are independent, so a single run can
# chain download -> resize -> sobel.
if args.download:
    data_manager.download_data()
    data_manager.unzip_data("archive.zip", original_dataset_name)
if args.resize:
    # argparse delivers --shape as a list; resize_dataset wants a tuple.
    data_manager.resize_dataset(args.source, tuple(args.shape))
if args.sobel:
    data_manager.sobelx(args.source)