AI-Tech-WKO-Projekt/data/prepare_alpha_dataset.py

59 lines
1.9 KiB
Python
Raw Permalink Normal View History

2023-02-03 16:09:03 +01:00
import os
import shutil
from pathlib import Path
def mkdir_if_not_exists(path):
    """Create directory *path*, including any missing parent directories.

    If something already exists at *path*, this silently does nothing.

    Args:
        path: Directory to create.
    """
    try:
        # makedirs (not mkdir) so that e.g. 'images/train/<class>' can be
        # created even when 'images/train' does not exist yet.
        os.makedirs(path)
    except FileExistsError:
        # Already present -- nothing to do.
        pass
def bulk_copy(file_names: list[str], input_dir, output):
    """Copy files into *output*, renaming them to sequential numbers.

    The i-th entry of *file_names* is copied to ``output/<i><ext>``, where
    ``<ext>`` is the file's last extension including the dot
    (``photo.v2.jpg`` -> ``0.jpg``). A file without an extension is copied
    to ``output/<i>``.

    Args:
        file_names: Names of the files to copy, relative to *input_dir*.
        input_dir: Directory containing the source files.
        output: Destination directory; it must already exist.
    """
    for index, file_name in enumerate(file_names):
        # splitext returns the *last* suffix. The previous split('.')[1]
        # picked the wrong part for names with several dots and raised
        # IndexError for names with no dot at all.
        extension = os.path.splitext(file_name)[1]
        shutil.copy(
            os.path.join(input_dir, file_name),
            os.path.join(output, f'{index}{extension}'),
        )
def split_houzz_dataset(
    raw_path: str,
    train_out_folder: str,
    test_out_folder: str,
    train_test_ratio: float = 0.8,
):
    """Split a class-per-folder image dataset into train and test folders.

    Every sub-directory of *raw_path* is treated as one class. Within each
    class, files are sorted by name. The first
    ``round(n_files * train_test_ratio)`` files are copied to
    ``train_out_folder/<class>`` and all remaining files to
    ``test_out_folder/<class>``. Copies are renamed to sequential numbers
    by ``bulk_copy``. Progress is printed to stdout.

    Args:
        raw_path: Directory containing one sub-directory per class.
        train_out_folder: Destination root for the training split.
        test_out_folder: Destination root for the test split.
        train_test_ratio: Fraction of each class's files that goes to the
            training split; must lie in [0, 1].

    Raises:
        ValueError: If *train_test_ratio* is outside [0, 1].
    """
    if not 0.0 <= train_test_ratio <= 1.0:
        raise ValueError(
            f'train_test_ratio must be in [0, 1], got {train_test_ratio}'
        )

    image_dir = Path(raw_path)
    # Sorted so the resulting split does not depend on filesystem listing order.
    classes = sorted(entry.name for entry in image_dir.iterdir() if entry.is_dir())
    print(f'Found {len(classes)} classes')

    for cls in classes:
        train_dir = os.path.join(train_out_folder, cls)
        test_dir = os.path.join(test_out_folder, cls)
        mkdir_if_not_exists(train_dir)
        mkdir_if_not_exists(test_dir)

        # Derive the folder from the class name instead of zipping two
        # independently listed sequences that may not line up.
        raw_directory = image_dir / cls
        # Only regular files can be copied with shutil.copy; skip nested dirs.
        raw_files = sorted(
            entry.name for entry in raw_directory.iterdir() if entry.is_file()
        )
        print(f'{raw_directory}: {len(raw_files)}')

        split_point = round(len(raw_files) * train_test_ratio)
        train_files = raw_files[:split_point]
        print(f'\tTrain files: {len(train_files)}')
        # Start exactly at split_point; `split_point + 1` would drop one file.
        test_files = raw_files[split_point:]
        print(f'\tTest files: {len(test_files)}')

        print('Copying... ', end='')
        # Use the caller's output folders, not module-level globals.
        bulk_copy(test_files, raw_directory, test_dir)
        bulk_copy(train_files, raw_directory, train_dir)
        print('Done.')
if __name__ == '__main__':
    # Default locations, resolved relative to the current working directory.
    HOUZZ_DATASET_PATH = 'images/raw/houzz'
    TEST_OUTPUT = 'images/test'
    TRAIN_OUTPUT = 'images/train'

    split_houzz_dataset(
        raw_path=HOUZZ_DATASET_PATH,
        train_out_folder=TRAIN_OUTPUT,
        test_out_folder=TEST_OUTPUT,
    )