59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
|
|
def mkdir_if_not_exists(path):
    """Create the directory ``path`` unless something already exists there.

    Missing parent directories are created as well, so a nested target such
    as ``images/train/<class>`` works even when ``images/train`` has not been
    created yet.

    Args:
        path: Directory to create (``str`` or path-like).
    """
    try:
        # makedirs (rather than mkdir) also creates missing parents.
        os.makedirs(path)
    except FileExistsError:
        # The final path component already exists; nothing to do.
        pass
|
|
|
|
|
|
def bulk_copy(file_names: list[str], input_dir, output):
    """Copy files from ``input_dir`` into ``output`` under sequential names.

    The i-th entry of ``file_names`` is copied to ``<output>/<i><ext>``, where
    ``<ext>`` is the final extension of the source name including its dot.
    For example ``holiday.photo.jpg`` at index 3 becomes ``3.jpg``. A name
    without an extension is copied to just ``<i>``.

    Args:
        file_names: File names relative to ``input_dir``.
        input_dir: Directory containing the source files.
        output: Existing directory that receives the copies.
    """
    for index, file_name in enumerate(file_names):
        # splitext keeps only the last suffix and yields '' when there is
        # none. The earlier ``split('.')[1]`` picked the wrong segment for
        # multi-dot names and raised IndexError for extension-less ones.
        extension = os.path.splitext(file_name)[1]
        shutil.copy(
            os.path.join(input_dir, file_name),
            os.path.join(output, f'{index}{extension}'),
        )
|
|
|
|
|
|
def split_houzz_dataset(
    raw_path: str,
    train_out_folder: str,
    test_out_folder: str,
    train_test_ratio: float = 0.8
):
    """Split a class-per-folder image dataset into train and test copies.

    Every immediate sub-directory of ``raw_path`` is treated as one class.
    For each class, a sub-directory of the same name is created under both
    output folders, and the class's entries are copied (renamed to
    sequential numbers by ``bulk_copy``) into the train or test side.

    Args:
        raw_path: Directory whose sub-directories each hold one class.
        train_out_folder: Root directory for the training copies.
        test_out_folder: Root directory for the test copies.
        train_test_ratio: Fraction of each class that goes to the training
            set; the remainder goes to the test set.
    """
    image_dir = Path(raw_path)

    # Collect the class directories once and derive the class name from each
    # path, so a name can never be paired with the wrong directory. Sorting
    # makes the processing order independent of the filesystem's listing
    # order.
    class_dirs = sorted(entry for entry in image_dir.iterdir() if entry.is_dir())
    print(f'Found {len(class_dirs)} classes')

    for class_dir in class_dirs:
        mkdir_if_not_exists(os.path.join(train_out_folder, class_dir.name))
        mkdir_if_not_exists(os.path.join(test_out_folder, class_dir.name))

    for class_dir in class_dirs:
        # os.listdir returns entries in arbitrary order; sort so that the
        # same input always produces the same train/test assignment.
        raw_files = sorted(os.listdir(class_dir))
        print(f'{class_dir}: {len(raw_files)}')

        split_point = round(len(raw_files) * train_test_ratio)
        train_files = raw_files[:split_point]
        print(f'\tTrain files: {len(train_files)}')
        # The test slice starts exactly at split_point; starting one later
        # would leave the file at split_point out of both sets.
        test_files = raw_files[split_point:]
        print(f'\tTest files: {len(test_files)}')

        print('Copying... ', end='')
        # Write to the folders passed in as arguments, not to module-level
        # globals, so the function also works when imported and called with
        # other destinations.
        bulk_copy(test_files, class_dir, os.path.join(test_out_folder, class_dir.name))
        bulk_copy(train_files, class_dir, os.path.join(train_out_folder, class_dir.name))
        print('Done.')
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: split the raw Houzz images with the default
    # train/test ratio. All three paths are relative, so they resolve
    # against the current working directory.
    HOUZZ_DATASET_PATH = 'images/raw/houzz'
    TRAIN_OUTPUT = 'images/train'
    TEST_OUTPUT = 'images/test'

    split_houzz_dataset(
        raw_path=HOUZZ_DATASET_PATH,
        train_out_folder=TRAIN_OUTPUT,
        test_out_folder=TEST_OUTPUT,
    )
|