From 7573648527a0a08738f85c9229bb892e070bcecb Mon Sep 17 00:00:00 2001 From: Karol Cyganik Date: Wed, 13 Mar 2024 15:15:34 +0100 Subject: [PATCH] data file --- .gitignore | 2 + data.py | 153 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 data.py diff --git a/.gitignore b/.gitignore index 33e7851..1cdc336 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,5 @@ ipython_config.py # Remove previous ipynb_checkpoints # git rm -r .ipynb_checkpoints/ +football_dataset/ +venvium/ \ No newline at end of file diff --git a/data.py b/data.py new file mode 100644 index 0000000..a6ac3d2 --- /dev/null +++ b/data.py @@ -0,0 +1,153 @@ +import os +from glob import glob + +import matplotlib.pyplot as plt +import torch +import torchvision.transforms as transforms +from kaggle.api.kaggle_api_extended import KaggleApi +from PIL import Image +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader, Dataset + + +class FootballSegDataset(Dataset): + def __init__(self, image_paths, label_paths, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], transform=None): + self.image_paths = image_paths + self.label_paths = label_paths + self.mean = mean + self.std = std + self.transform = transform + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + image = Image.open(self.image_paths[idx]).convert("RGB") + label_map = Image.open(self.label_paths[idx]).convert("RGB") + + if self.transform: + image = self.transform(image) + label_map = self.transform(label_map) + + return image, label_map + + +def get_data(): + api = KaggleApi() + api.authenticate() + + dataset_slug = "sadhliroomyprime/football-semantic-segmentation" + + download_dir = "./football_dataset" + + if not os.path.exists(download_dir): + os.makedirs(download_dir) + api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) + + all_images = glob(download_dir + "/images/*.jpg") + all_paths = [path.replace(".jpg", ".jpg___fuse.png") + for path in all_images] + return all_images, all_paths + + +def calculate_mean_std(image_paths): + mean = torch.zeros(3) + std = torch.zeros(3) + total_images = 0 + + for image_path in image_paths: + image = Image.open(image_path).convert("RGB") + image = transforms.ToTensor()(image) + mean += torch.mean(image, dim=(1, 2)) + std += torch.std(image, dim=(1, 2)) + total_images += 1 + + mean /= total_images + std /= total_images + + return mean, std + + +def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256): + image = Image.open(path).convert("RGB") + transform = transforms.Compose([ + transforms.Resize((SIZE, SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + image = transform(image) + return image + + +def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True): + if calculate_stats: + mean, std = calculate_mean_std(image_paths) + print("Mean:", mean) + print("Std:", std) + else: + mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093] + + images = [] + label_maps = [] + for image_path, label_path in zip(image_paths, label_paths): + image = load_image(image_path, mean, std, SIZE=SIZE) + label_map = load_image(label_path, mean, std, SIZE=SIZE) + images.append(image) + label_maps.append(label_map) + images = torch.stack(images) + label_maps = torch.stack(label_maps) + return images, label_maps + + +def show_map(image, label_map, alpha_1=1, alpha_2=0.7): + image_np = image.permute(1, 2, 0).cpu().numpy() + label_map_np = label_map.permute(1, 2, 0).cpu().numpy() + + plt.imshow(image_np, alpha=alpha_1) + plt.imshow(label_map_np, alpha=alpha_2) + plt.axis('off') + + +def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)): + n_rows, n_cols = GRID + n_images = n_rows * n_cols + plt.figure(figsize=SIZE) + + i = 1 + for image, label_map in zip(images, label_maps): + plt.subplot(n_rows, n_cols, i) + show_map(image, label_map) + + i += 1 + if i > n_images: + break + + plt.show() + + +def main(): + all_images, all_paths = get_data() + images, label_maps = load_data( + all_images, all_paths, calculate_stats=False) + show_maps(images, label_maps, GRID=[3, 3]) + image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split( + all_images, all_paths, test_size=0.2, random_state=42) + + transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[ + 0.2081, 0.2054, 0.2093]) + ]) + + train_dataset = FootballSegDataset( + image_train_paths, label_train_paths, transform=transform) + test_dataset = FootballSegDataset( + image_test_paths, label_test_paths, transform=transform) + + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) + + +if __name__ == "__main__": + main()