iumKC/data.py

163 lines
5.2 KiB
Python
Raw Normal View History

2024-03-13 15:15:34 +01:00
import os
2024-03-13 23:33:36 +01:00
from copy import deepcopy
2024-03-13 15:15:34 +01:00
from glob import glob
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image
2024-03-13 23:33:36 +01:00
from torch.utils.data import Dataset
2024-03-13 15:15:34 +01:00
class FootballSegDataset(Dataset):
    """Paired football images and their colour-coded segmentation maps.

    Each item is an ``(image, label_map)`` tuple. Both files are opened with
    PIL, converted to RGB, resized to 256x256, turned into tensors with
    ``ToTensor`` and normalised per channel.

    Args:
        image_paths: Sequence of paths to the input images.
        label_paths: Sequence of paths to the label maps; entry ``i`` is
            paired with ``image_paths[i]``.
        img_mean: Per-channel mean used to normalise the images.
        img_std: Per-channel standard deviation used to normalise the images.
        label_mean: Per-channel mean used to normalise the label maps.
        label_std: Per-channel standard deviation used to normalise the
            label maps.
    """

    def __init__(
        self,
        image_paths,
        label_paths,
        # Tuples instead of lists: default arguments are created once and
        # shared between calls, so they must not be mutable.
        img_mean=(0.3468, 0.3885, 0.3321),
        img_std=(0.2081, 0.2054, 0.2093),
        label_mean=(0.3468, 0.3885, 0.3321),
        label_std=(0.2081, 0.2054, 0.2093),
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths

        self.transform_img = self._build_transform(img_mean, img_std)
        # NOTE(review): the label maps go through the same Resize (default
        # interpolation) and Normalize as the photos. If exact label colours
        # matter downstream, nearest-neighbour resizing may be preferable --
        # confirm against how the labels are consumed.
        self.transform_label = self._build_transform(label_mean, label_std)

        self.num_classes = 11
        self.class_names = [
            "Goal Bar",
            "Referee",
            "Advertisement",
            "Ground",
            "Ball",
            "Coaches & Officials",
            "Team A",
            "Team B",
            "Goalkeeper A",
            "Goalkeeper B",
            "Audience",
        ]

    @staticmethod
    def _build_transform(mean, std):
        """Return the resize -> tensor -> normalise pipeline for one input."""
        return transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, label_map)`` pair at *idx*."""
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed as soon as the ``with`` block ends.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.label_paths[idx]) as raw_label:
            label_map = raw_label.convert("RGB")

        image = self.transform_img(image)
        label_map = self.transform_label(label_map)
        return image, label_map
2024-03-13 23:33:36 +01:00
class SegmentationStatistics:
    """Collect and print summary statistics about the segmentation data.

    Args:
        label_paths: Paths of the label-map images inspected by
            ``count_colors``.
        train_loader: Loader whose ``dataset`` attribute is the training set.
        test_loader: Loader whose ``dataset`` attribute is the test set.
        mean: Value printed as "Mean" (stored as given).
        std: Value printed as "Std" (stored as given).
    """

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        # One entry per label map, filled in by ``count_colors``.
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        """Record, for every label map, its number of distinct RGB colours minus one.

        The list is rebuilt on each call, so calling this method repeatedly
        does not duplicate entries.
        """
        counts = []
        for path in self.label_paths:
            with Image.open(path) as raw_label:
                label_map = raw_label.convert("RGB")
            distinct_colors = set(label_map.getdata())
            # One colour is discounted per image -- presumably the
            # background/unlabelled colour; confirm against the dataset.
            counts.append(len(distinct_colors) - 1)
        self.color_counts = counts

    def _summarize_counts(self):
        """Return ``((max, index), (min, index), average)``, or None if empty."""
        counts = self.color_counts
        if not counts:
            return None
        highest = max(counts)
        lowest = min(counts)
        return (
            (highest, counts.index(highest)),
            (lowest, counts.index(lowest)),
            sum(counts) / len(counts),
        )

    def print_statistics(self):
        """Print dataset sizes, tensor shapes, normalisation and class counts.

        If ``count_colors`` has not produced any counts yet it is run first.
        When there are still no counts (no label paths), the class-count
        lines are printed as ``n/a`` instead of raising.
        """
        if not self.color_counts:
            self.count_colors()
        summary = self._summarize_counts()

        train_set = self.train_loader.dataset
        # Fetch the first sample once; every access re-reads files from disk.
        first_sample = train_set[0]

        print(f"Train data size: {len(train_set)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {first_sample[0].shape}")
        print(f"Label shape: {first_sample[1].shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {train_set.num_classes}")
        print(f"Class names: {train_set.class_names}")

        if summary is None:
            print("Max classes: n/a")
            print("Min classes: n/a")
            print("Avg classes: n/a")
            return

        max_classes, min_classes, avg_classes = summary
        print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
        print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
        print(f"Avg classes: {avg_classes}")
2024-03-13 15:15:34 +01:00
def get_data():
    """Return the dataset's image paths and the matching label-map paths.

    Images are looked up under ``./football_dataset/images``. When none are
    found, the dataset is downloaded and unzipped there through the Kaggle
    API (which requires the ``kaggle`` package and valid credentials).

    Returns:
        A ``(image_paths, label_paths)`` tuple of equally long lists, sorted
        by image path. Each label path is the image path with
        ``___fuse.png`` appended.
    """
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"
    image_pattern = download_dir + "/images/*.jpg"

    all_images = glob(image_pattern)
    # Decide on the presence of images, not of the directory: an interrupted
    # download leaves the directory behind and would otherwise never be retried.
    if not all_images:
        # Imported lazily so the kaggle package is only needed for downloads.
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir, exist_ok=True)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
        all_images = glob(image_pattern)

    # glob returns entries in arbitrary order; sort for reproducible results.
    all_images = sorted(all_images)
    # Append the suffix rather than using str.replace, which would also
    # rewrite any other ".jpg" occurring earlier in the path.
    all_paths = [path + "___fuse.png" for path in all_images]
    return all_images, all_paths
def calculate_mean_std(image_paths):
    """Compute per-channel normalisation statistics over a set of images.

    Every image is opened, converted to RGB and turned into a tensor with
    ``ToTensor``. The per-channel mean and standard deviation of each image
    are then averaged over all images. Note that the returned ``std`` is the
    mean of the per-image standard deviations, not the standard deviation
    of all pixels pooled together.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        A ``(mean, std)`` tuple of tensors with three elements each.

    Raises:
        ValueError: If *image_paths* is empty; dividing by zero images would
            otherwise silently yield NaN statistics.
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("image_paths must contain at least one image")

    # Build the converter once instead of once per image.
    to_tensor = transforms.ToTensor()
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for image_path in image_paths:
        with Image.open(image_path) as raw_image:
            image = to_tensor(raw_image.convert("RGB"))
        # Reduce over height and width, leaving one value per channel.
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))

    total_images = len(image_paths)
    mean /= total_images
    std /= total_images
    return mean, std
2024-04-03 09:12:18 +02:00
def plot_random_images(indices, data_loader, suffix="train"):
    """Save a grid of image / label-map pairs to ``random_images_{suffix}.png``.

    Each selected sample occupies one row: the image on the left and its
    label map on the right. The figure is written to the current working
    directory and then closed; nothing is displayed.

    Args:
        indices: Iterable of indices into ``data_loader.dataset``.
        data_loader: Loader whose ``dataset`` yields ``(image, label_map)``
            pairs of channel-first tensors.
        suffix: Tag inserted into the output file name.
    """
    indices = list(indices)
    # Keep the original 5x2 layout for up to five samples and add rows only
    # when more are requested (a fixed 5-row grid would raise for a 6th).
    n_rows = max(5, len(indices))

    figure = plt.figure(figsize=(8, 8))
    try:
        for row, index in enumerate(indices):
            image, label_map = data_loader.dataset[index]
            # permute returns a view and does not modify the sample, so no
            # copy is needed; imshow expects channels last.
            image = image.permute(1, 2, 0)
            label_map = label_map.permute(1, 2, 0)

            plt.subplot(n_rows, 2, row * 2 + 1)
            plt.imshow(image)
            plt.axis("off")

            plt.subplot(n_rows, 2, row * 2 + 2)
            plt.imshow(label_map)
            plt.axis("off")

        # Save the figure to a file instead of displaying it
        plt.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight")
    finally:
        # Close the figure even if plotting fails, to free its memory.
        plt.close(figure)