iumKC/data.py

import os
from copy import deepcopy
from glob import glob

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class FootballSegDataset(Dataset):
    """Pairs football images with their color-coded segmentation maps."""

    def __init__(
        self,
        image_paths,
        label_paths,
        img_mean=[0.3468, 0.3885, 0.3321],
        img_std=[0.2081, 0.2054, 0.2093],
        label_mean=[0.3468, 0.3885, 0.3321],
        label_std=[0.2081, 0.2054, 0.2093],
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        # Images are resized to 256x256, converted to tensors, and
        # normalized channel-wise with the dataset statistics.
        self.transform_img = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=img_mean, std=img_std),
            ]
        )
        # Label maps are kept as RGB images (no class-index conversion) and go
        # through the same pipeline; note that Resize defaults to bilinear
        # interpolation, which blends label colors at class boundaries.
        self.transform_label = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=label_mean, std=label_std),
            ]
        )
        self.num_classes = 11
        self.class_names = [
            "Goal Bar",
            "Referee",
            "Advertisement",
            "Ground",
            "Ball",
            "Coaches & Officials",
            "Team A",
            "Team B",
            "Goalkeeper A",
            "Goalkeeper B",
            "Audience",
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        label_map = Image.open(self.label_paths[idx]).convert("RGB")
        image = self.transform_img(image)
        label_map = self.transform_label(label_map)
        return image, label_map


class SegmentationStatistics:
    """Collects and prints basic statistics about the segmentation data."""

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        # Each distinct color in a label map corresponds to one class;
        # subtract one, presumably to leave out the background color.
        for path in self.label_paths:
            label_map = Image.open(path).convert("RGB")
            colors = set(label_map.getdata())
            num_colors = len(colors)
            self.color_counts.append(num_colors - 1)

    def print_statistics(self):
        max_classes = (
            max(self.color_counts),
            self.color_counts.index(max(self.color_counts)),
        )
        min_classes = (
            min(self.color_counts),
            self.color_counts.index(min(self.color_counts)),
        )
        avg_classes = sum(self.color_counts) / len(self.color_counts)
        print(f"Train data size: {len(self.train_loader.dataset)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {self.train_loader.dataset[0][0].shape}")
        print(f"Label shape: {self.train_loader.dataset[0][1].shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {self.train_loader.dataset.num_classes}")
        print(f"Class names: {self.train_loader.dataset.class_names}")
        print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
        print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
        print(f"Avg classes: {avg_classes}")


def get_data():
    """Downloads the Kaggle dataset (if needed) and returns image and label paths."""
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"
    if not os.path.exists(download_dir):
        # Imported lazily so the kaggle package is only needed for the download.
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
    all_images = glob(download_dir + "/images/*.jpg")
    # Each label mask shares its image's filename plus a "___fuse.png" suffix.
    all_labels = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
    return all_images, all_labels


def calculate_mean_std(image_paths):
    """Approximates dataset statistics by averaging per-image channel means and stds."""
    mean = torch.zeros(3)
    std = torch.zeros(3)
    total_images = 0
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        image = transforms.ToTensor()(image)
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))
        total_images += 1
    mean /= total_images
    std /= total_images
    return mean, std


def plot_random_images(indices, data_loader, suffix="train"):
    """Saves a grid of image/label pairs for the given dataset indices."""
    plt.figure(figsize=(8, 8))
    for i, index in enumerate(indices):
        image, label_map = data_loader.dataset[index]
        # Move channels last for matplotlib and clamp the normalized values
        # into the displayable [0, 1] range (imshow would clip them anyway).
        image = deepcopy(image).permute(1, 2, 0).clamp(0, 1)
        label_map = deepcopy(label_map).permute(1, 2, 0).clamp(0, 1)
        # The 5x2 grid fits up to five image/label pairs side by side.
        plt.subplot(5, 2, i * 2 + 1)
        plt.imshow(image)
        plt.axis("off")
        plt.subplot(5, 2, i * 2 + 2)
        plt.imshow(label_map)
        plt.axis("off")
    # Save the figure to a file instead of displaying it.
    plt.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight")
    plt.close()  # Close the figure to free up memory
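

if __name__ == "__main__":
    # Example wiring of the helpers above into a small end-to-end run.
    # The 80/20 split, batch size of 8, and the four plotted indices are
    # illustrative assumptions, not settings taken from this module.
    import random

    from torch.utils.data import DataLoader

    image_paths, label_paths = get_data()
    mean, std = calculate_mean_std(image_paths)

    # Simple ordered split into train/test portions.
    split = int(0.8 * len(image_paths))
    train_set = FootballSegDataset(
        image_paths[:split], label_paths[:split], img_mean=mean, img_std=std
    )
    test_set = FootballSegDataset(
        image_paths[split:], label_paths[split:], img_mean=mean, img_std=std
    )
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=8, shuffle=False)

    stats = SegmentationStatistics(label_paths, train_loader, test_loader, mean, std)
    stats.count_colors()
    stats.print_statistics()

    # Save a quick visual sanity check of four random training pairs.
    plot_random_images(random.sample(range(len(train_set)), 4), train_loader)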