iumKC/data.py
2024-03-20 12:58:03 +01:00

155 lines
4.9 KiB
Python

import os
from copy import deepcopy
from glob import glob
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from kaggle.api.kaggle_api_extended import KaggleApi
from PIL import Image
from torch.utils.data import Dataset
class FootballSegDataset(Dataset):
    """Paired (image, label map) dataset for football semantic segmentation.

    Each item is read from disk, converted to RGB, and resized to ``size``.
    Images are normalized with ``mean``/``std``. Label maps are only resized
    (nearest-neighbour) and converted to tensors, so their colours stay
    intact.

    Args:
        image_paths: Indexable sequence of image file paths.
        label_paths: Indexable sequence of label-map file paths, aligned
            index-by-index with ``image_paths``.
        mean: Per-channel mean used to normalize images.
        std: Per-channel std used to normalize images.
        size: ``(height, width)`` that both images and label maps are
            resized to.
    """

    # Per-channel RGB statistics used as the default image normalization.
    DEFAULT_MEAN = (0.3468, 0.3885, 0.3321)
    DEFAULT_STD = (0.2081, 0.2054, 0.2093)
    CLASS_NAMES = (
        "Goal Bar",
        "Referee",
        "Advertisement",
        "Ground",
        "Ball",
        "Coaches & Officials",
        "Team A",
        "Team B",
        "Goalkeeper A",
        "Goalkeeper B",
        "Audience",
    )

    def __init__(
        self,
        image_paths,
        label_paths,
        mean=DEFAULT_MEAN,
        std=DEFAULT_STD,
        size=(256, 256),
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        # Copy into fresh lists so instances never share (or mutate) defaults.
        self.mean = list(mean)
        self.std = list(std)
        self.transform = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.mean, std=self.std),
            ]
        )
        # Label maps are loaded as RGB, so classes are presumably colour-coded
        # (TODO confirm). Nearest-neighbour resizing keeps them from being
        # blended into colours that belong to no class, and image
        # normalization must not be applied to them.
        self.label_transform = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.NEAREST
                ),
                transforms.ToTensor(),
            ]
        )
        self.class_names = list(self.CLASS_NAMES)
        self.num_classes = len(self.class_names)

    def __len__(self):
        """Return the number of samples (the number of image paths)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_map)``.

        With the default transforms, both are float tensors of shape
        ``(3, *size)``.
        """
        # Context managers close the underlying file handles; convert()
        # returns an independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        with Image.open(self.label_paths[idx]) as lbl:
            label_map = lbl.convert("RGB")
        if self.transform:
            image = self.transform(image)
        if self.label_transform:
            label_map = self.label_transform(label_map)
        return image, label_map
class SegmentationStatistics:
    """Collect and print summary statistics for a segmentation dataset.

    Args:
        label_paths: Paths of the label-map images to analyse.
        train_loader: Loader whose ``.dataset`` supports ``len()`` and
            indexing, and has ``num_classes`` and ``class_names`` attributes.
        test_loader: Loader whose ``.dataset`` supports ``len()``.
        mean: Normalization mean; it is only printed.
        std: Normalization std; it is only printed.
    """

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        # One entry per label map, filled by count_colors().
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        """Recompute ``self.color_counts`` from ``self.label_paths``.

        Each entry is the number of distinct RGB colours in a label map,
        minus one. The subtracted colour is presumably a background or
        unlabeled colour (TODO confirm). The list is rebuilt on every call,
        so repeated calls do not accumulate duplicates.
        """
        counts = []
        for path in self.label_paths:
            with Image.open(path) as img:
                distinct_colors = set(img.convert("RGB").getdata())
            counts.append(len(distinct_colors) - 1)
        self.color_counts = counts

    def print_statistics(self):
        """Print dataset sizes, shapes, normalization and per-map class counts.

        Raises:
            ValueError: if there are no colour counts, i.e. ``count_colors()``
                was not run or found no label maps.
        """
        counts = self.color_counts
        if not counts:
            raise ValueError(
                "No color counts available; call count_colors() first"
            )
        max_count = max(counts)
        min_count = min(counts)
        avg_classes = sum(counts) / len(counts)

        train_dataset = self.train_loader.dataset
        # Load the first sample only once; it may involve disk I/O.
        sample_image, sample_label = train_dataset[0]

        print(f"Train data size: {len(train_dataset)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {sample_image.shape}")
        print(f"Label shape: {sample_label.shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {train_dataset.num_classes}")
        print(f"Class names: {train_dataset.class_names}")
        print(f"Max classes: {max_count} at index {counts.index(max_count)}")
        print(f"Min classes: {min_count} at index {counts.index(min_count)}")
        print(f"Avg classes: {avg_classes}")
def get_data(
    download_dir="./football_dataset",
    dataset_slug="sadhliroomyprime/football-semantic-segmentation",
):
    """Return aligned lists of image paths and label-map paths.

    The Kaggle dataset is downloaded and unzipped into ``download_dir`` only
    when ``download_dir/images`` does not already exist. Kaggle credentials
    are therefore needed only when a download actually happens.

    Args:
        download_dir: Directory that holds (or will hold) the dataset.
        dataset_slug: Kaggle ``owner/dataset`` identifier.

    Returns:
        ``(image_paths, label_paths)``, where ``image_paths`` holds the
        ``images/*.jpg`` files sorted by path, and ``label_paths[i]`` is
        ``image_paths[i] + "___fuse.png"``.
    """
    images_dir = os.path.join(download_dir, "images")
    # Checking for images/ rather than download_dir means a directory left
    # empty by an interrupted download is retried instead of silently
    # producing no data.
    if not os.path.isdir(images_dir):
        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir, exist_ok=True)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
    # Sort so the order, and any train/test split built from it, is
    # reproducible across filesystems.
    image_paths = sorted(glob(os.path.join(images_dir, "*.jpg")))
    # Append the suffix instead of str.replace(), which would also rewrite
    # ".jpg" occurring in directory names.
    label_paths = [path + "___fuse.png" for path in image_paths]
    return image_paths, label_paths
def calculate_mean_std(image_paths):
    """Return per-channel ``(mean, std)`` tensors (shape ``(3,)``) over images.

    Each image is loaded as RGB and scaled to ``[0, 1]`` by ``ToTensor``.
    The result is the unweighted average of the per-image channel means and
    of the per-image channel stds. Note that averaging per-image stds
    approximates, but is not equal to, the pooled pixel std of the dataset.

    Args:
        image_paths: Iterable of image file paths.

    Raises:
        ValueError: if ``image_paths`` is empty (the averages would be NaN).
    """
    # Materialize so generators work and emptiness can be checked up front.
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("calculate_mean_std() needs at least one image path")

    to_tensor = transforms.ToTensor()
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for image_path in image_paths:
        # Close each file promptly instead of leaking one handle per image.
        with Image.open(image_path) as img:
            image = to_tensor(img.convert("RGB"))
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))

    count = len(image_paths)
    return mean / count, std / count
def plot_random_images(indices, data_loader, suffix='train'):
    """Save a grid of (image, label map) pairs to ``random_images_{suffix}.png``.

    Each index in ``indices`` becomes one row: the image on the left and its
    label map on the right. The grid has at least 5 rows, matching the
    original layout, and grows when more indices are given.

    Args:
        indices: Dataset indices to plot.
        data_loader: Loader whose ``.dataset[i]`` returns ``(image, label_map)``
            tensors in channel-first ``(C, H, W)`` layout.
        suffix: Inserted into the output file name.
    """
    indices = list(indices)
    nrows = max(5, len(indices))
    # Keep the original 8x8 figure for 5 rows; scale height for more rows.
    fig = plt.figure(figsize=(8, 8 * nrows / 5))
    try:
        for row, index in enumerate(indices):
            image, label_map = data_loader.dataset[index]
            for col, tensor in enumerate((image, label_map)):
                plt.subplot(nrows, 2, row * 2 + col + 1)
                # permute() returns a new view, so the dataset tensor is not
                # modified. Clamp to [0, 1], which matplotlib would clip to
                # anyway, to avoid clipping warnings for normalized images.
                plt.imshow(tensor.permute(1, 2, 0).clamp(0, 1))
                plt.axis("off")
        # Save the figure to a file instead of displaying it.
        fig.savefig(f'random_images_{suffix}.png', dpi=250, bbox_inches='tight')
    finally:
        # Always release the figure, even if loading or plotting fails.
        plt.close(fig)