import os
from copy import deepcopy  # noqa: F401 -- no longer used below; kept so the module namespace is unchanged
from glob import glob

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class FootballSegDataset(Dataset):
    """Paired image / label-map dataset for football semantic segmentation.

    Each sample is an ``(image, label_map)`` tuple. Both files are opened as
    RGB, resized to 256x256, converted to a tensor and normalized per channel.

    Args:
        image_paths: Sequence of paths to the input images.
        label_paths: Sequence of paths to the label maps; ``label_paths[i]``
            is paired with ``image_paths[i]``.
        img_mean: Per-channel mean used to normalize the image.
        img_std: Per-channel standard deviation used to normalize the image.
        label_mean: Per-channel mean used to normalize the label map.
        label_std: Per-channel standard deviation used to normalize the label
            map.
    """

    def __init__(
        self,
        image_paths,
        label_paths,
        img_mean=(0.3468, 0.3885, 0.3321),
        img_std=(0.2081, 0.2054, 0.2093),
        label_mean=(0.3468, 0.3885, 0.3321),
        label_std=(0.2081, 0.2054, 0.2093),
    ):
        # Defaults are tuples rather than lists so that no mutable object is
        # shared between instances.
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform_img = self._build_transform(img_mean, img_std)
        # NOTE(review): Resize is used with its default interpolation, which
        # blends neighbouring colors in the label map. If the labels are later
        # decoded into class ids, nearest-neighbour may be required -- confirm.
        self.transform_label = self._build_transform(label_mean, label_std)
        self.num_classes = 11
        self.class_names = [
            "Goal Bar",
            "Referee",
            "Advertisement",
            "Ground",
            "Ball",
            "Coaches & Officials",
            "Team A",
            "Team B",
            "Goalkeeper A",
            "Goalkeeper B",
            "Audience",
        ]

    @staticmethod
    def _build_transform(mean, std):
        """Return the resize -> tensor -> normalize pipeline for one input."""
        return transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __len__(self):
        """Return the number of samples, i.e. the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, label_map)`` pair at *idx*."""
        # convert() yields an independent in-memory image, so the underlying
        # file handles can be released right away instead of leaking.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.label_paths[idx]) as raw_label:
            label_map = raw_label.convert("RGB")
        return self.transform_img(image), self.transform_label(label_map)


class SegmentationStatistics:
    """Collect and print summary statistics about the segmentation data.

    Args:
        label_paths: Paths of the label maps whose distinct colors are counted.
        train_loader: Loader whose ``dataset`` is reported as the train data.
        test_loader: Loader whose ``dataset`` is reported as the test data.
        mean: Mean value that is printed as-is.
        std: Standard deviation value that is printed as-is.
    """

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        """Store, per label map, the number of distinct RGB colors minus one.

        The result replaces the content of ``self.color_counts``, so calling
        this method repeatedly does not accumulate duplicate entries.
        """
        counts = []
        for path in self.label_paths:
            with Image.open(path) as raw_label:
                label_map = raw_label.convert("RGB")
            # One color is discounted -- presumably the background; confirm.
            counts.append(len(set(label_map.getdata())) - 1)
        # Slice assignment keeps the identity of the list and only publishes
        # the result once every file has been read successfully.
        self.color_counts[:] = counts

    def print_statistics(self):
        """Print dataset sizes, sample shapes and per-label color statistics.

        Color counts are computed on demand when ``count_colors`` has not been
        called yet.

        Raises:
            ValueError: If there are no label maps to compute statistics from.
        """
        if not self.color_counts:
            self.count_colors()
        if not self.color_counts:
            raise ValueError(
                "cannot compute class statistics: label_paths is empty"
            )

        most = max(self.color_counts)
        least = min(self.color_counts)
        max_classes = (most, self.color_counts.index(most))
        min_classes = (least, self.color_counts.index(least))
        avg_classes = sum(self.color_counts) / len(self.color_counts)

        train_dataset = self.train_loader.dataset
        # Load the first sample once; indexing the dataset reads and
        # transforms both files every time.
        sample_image, sample_label = train_dataset[0]

        print(f"Train data size: {len(train_dataset)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {sample_image.shape}")
        print(f"Label shape: {sample_label.shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {train_dataset.num_classes}")
        print(f"Class names: {train_dataset.class_names}")
        print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
        print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
        print(f"Avg classes: {avg_classes}")


def get_data():
    """Download the dataset if needed and return image and label paths.

    The dataset is fetched through the Kaggle API only when the download
    directory does not exist yet.

    Returns:
        A tuple ``(all_images, all_paths)``: the sorted ``.jpg`` image paths
        and, in the same order, the matching ``___fuse.png`` label paths.
    """
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"
    if not os.path.exists(download_dir):
        # Imported lazily so the kaggle package is only required when a
        # download actually has to happen.
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)

    # glob() returns files in an OS-dependent order; sorting makes any
    # index-based split of the result reproducible.
    all_images = sorted(glob(download_dir + "/images/*.jpg"))
    # Append the suffix instead of str.replace(): replace() would rewrite
    # every ".jpg" occurring in the path and silently do nothing for a
    # case-insensitive match such as "X.JPG".
    all_paths = [path + "___fuse.png" for path in all_images]
    return all_images, all_paths


def calculate_mean_std(image_paths):
    """Return the per-channel mean and std averaged over all images.

    Each image contributes its own per-channel mean and standard deviation
    with equal weight; the returned std is therefore the average of the
    per-image values, not the standard deviation over all pixels.

    Args:
        image_paths: Iterable of paths to the images.

    Returns:
        A tuple ``(mean, std)`` of tensors with three elements each.

    Raises:
        ValueError: If *image_paths* is empty.
    """
    image_paths = list(image_paths)
    if not image_paths:
        # Dividing by zero images would silently produce NaN statistics.
        raise ValueError("image_paths must contain at least one image")

    to_tensor = transforms.ToTensor()  # built once instead of per image
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for image_path in image_paths:
        with Image.open(image_path) as raw_image:
            image = to_tensor(raw_image.convert("RGB"))
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))

    total_images = len(image_paths)
    mean /= total_images
    std /= total_images
    return mean, std


def plot_random_images(indices, data_loader, suffix="train"):
    """Save a two-column grid of images and label maps to a PNG file.

    Row ``i`` shows the image (left) and label map (right) of
    ``data_loader.dataset[indices[i]]``. The figure is written to
    ``random_images_<suffix>.png`` instead of being displayed.

    Args:
        indices: Dataset indices to plot, one row per index.
        data_loader: Loader whose ``dataset`` provides the samples.
        suffix: Tag inserted into the output file name.
    """
    indices = list(indices)
    # Keep the original five-row layout and only grow the grid when more
    # samples are requested, which previously raised an error.
    n_rows = max(5, len(indices))
    fig = plt.figure(figsize=(8, 8))
    try:
        for row, index in enumerate(indices):
            image, label_map = data_loader.dataset[index]
            # permute() returns a new view in height x width x channel order;
            # the tensors from the dataset are not modified.
            plt.subplot(n_rows, 2, row * 2 + 1)
            plt.imshow(image.permute(1, 2, 0))
            plt.axis("off")
            plt.subplot(n_rows, 2, row * 2 + 2)
            plt.imshow(label_map.permute(1, 2, 0))
            plt.axis("off")
        # Save the figure to a file instead of displaying it
        fig.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight")
    finally:
        # Close the figure to free up memory, even when plotting failed
        plt.close(fig)