155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
import os
|
|
from copy import deepcopy
|
|
from glob import glob
|
|
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class FootballSegDataset(Dataset):
    """Dataset of football frames paired with their colour-coded label maps.

    Each item is an ``(image, label_map)`` pair. Both files are opened as RGB
    and sent through the same pipeline: resize to 256x256, conversion to a
    tensor, and per-channel normalisation with ``mean`` / ``std``.

    Args:
        image_paths: Sequence of paths to the input images.
        label_paths: Sequence of paths to the label maps; ``label_paths[i]``
            is loaded together with ``image_paths[i]``.
        mean: Per-channel mean handed to ``transforms.Normalize``.
        std: Per-channel standard deviation handed to
            ``transforms.Normalize``.
    """

    def __init__(
        self,
        image_paths,
        label_paths,
        # Tuples instead of lists: a mutable default would be shared by every
        # instance that relies on it.
        mean=(0.3468, 0.3885, 0.3321),
        std=(0.2081, 0.2054, 0.2093),
    ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.mean = mean
        self.std = std
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                # Use the statistics supplied by the caller; previously the
                # constants were hard-coded here and the arguments ignored.
                transforms.Normalize(mean=self.mean, std=self.std),
            ]
        )
        self.num_classes = 11
        self.class_names = [
            "Goal Bar",
            "Referee",
            "Advertisement",
            "Ground",
            "Ball",
            "Coaches & Officials",
            "Team A",
            "Team B",
            "Goalkeeper A",
            "Goalkeeper B",
            "Audience",
        ]

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and transform the image / label-map pair at position ``idx``.

        Returns:
            A tuple ``(image, label_map)`` holding the transformed image and
            the transformed label map.
        """
        # ``convert`` returns an independent in-memory image, so the file
        # handle can be released as soon as the ``with`` block ends.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert("RGB")
        with Image.open(self.label_paths[idx]) as label_file:
            label_map = label_file.convert("RGB")

        if self.transform:
            image = self.transform(image)
            # NOTE(review): the label map goes through the same
            # Resize/Normalize steps as the photo, so its colours are
            # resampled and normalised rather than kept as exact class
            # colours -- confirm this is intended before training on it.
            label_map = self.transform(label_map)

        return image, label_map
|
|
|
|
|
|
class SegmentationStatistics:
    """Collect and print summary statistics for a segmentation dataset.

    Args:
        label_paths: Paths of the label-map images whose colours are counted.
        train_loader: Loader whose ``dataset`` attribute is the training set.
        test_loader: Loader whose ``dataset`` attribute is the test set.
        mean: Value reported as ``Mean`` by :meth:`print_statistics`.
        std: Value reported as ``Std`` by :meth:`print_statistics`.
    """

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        # One entry per label map, filled in by ``count_colors``.
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        """Count the distinct RGB colours in every label map.

        ``color_counts`` is rebuilt from scratch on each call, so calling this
        method repeatedly does not duplicate entries. Each entry is the number
        of distinct colours in the corresponding label map minus one.
        """
        counts = []
        for path in self.label_paths:
            # Close the file as soon as its pixels have been read.
            with Image.open(path) as label_file:
                colors = set(label_file.convert("RGB").getdata())
            # One colour is discounted -- presumably the background/unlabelled
            # colour; TODO confirm against the dataset's colour legend.
            counts.append(len(colors) - 1)
        self.color_counts = counts

    def print_statistics(self):
        """Print dataset sizes, tensor shapes and per-image class counts.

        Raises:
            ValueError: If no colour counts are available, i.e.
                :meth:`count_colors` has not been run or ``label_paths`` is
                empty.
        """
        if not self.color_counts:
            raise ValueError(
                "no colour counts available; call count_colors() with a "
                "non-empty label_paths first"
            )

        most = max(self.color_counts)
        least = min(self.color_counts)
        # ``index`` reports the first label map that reaches the extreme.
        max_classes = (most, self.color_counts.index(most))
        min_classes = (least, self.color_counts.index(least))
        avg_classes = sum(self.color_counts) / len(self.color_counts)

        train_dataset = self.train_loader.dataset
        # Fetch the first sample once; both shapes are read from it.
        sample_image, sample_label = train_dataset[0]

        print(f"Train data size: {len(train_dataset)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {sample_image.shape}")
        print(f"Label shape: {sample_label.shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {train_dataset.num_classes}")
        print(f"Class names: {train_dataset.class_names}")
        print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
        print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
        print(f"Avg classes: {avg_classes}")
|
|
|
|
|
|
def get_data():
    """Return image and label-map paths, downloading the dataset if needed.

    The Kaggle dataset is downloaded and unzipped into ``./football_dataset``
    only when that directory does not exist yet. The ``.jpg`` files under its
    ``images`` sub-directory are then listed, and each label path is the image
    path with ``___fuse.png`` appended.

    Returns:
        A tuple ``(all_images, all_paths)`` of two equally long lists: the
        sorted image paths and the matching label-map paths.
    """
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"

    if not os.path.exists(download_dir):
        # Authenticate only when a download is really required, so that an
        # already-downloaded copy can be used without Kaggle credentials.
        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir, exist_ok=True)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)

    # glob returns entries in arbitrary order; sorting keeps the listing (and
    # anything derived from it, such as a train/test split) reproducible.
    all_images = sorted(glob(download_dir + "/images/*.jpg"))
    # Append the suffix instead of str.replace: replace would also rewrite a
    # ".jpg" occurring earlier in the path, not just the extension.
    all_paths = [path + "___fuse.png" for path in all_images]
    return all_images, all_paths
|
|
|
|
|
|
def calculate_mean_std(image_paths):
    """Return per-channel mean and std averaged over a collection of images.

    Every image is opened as RGB and converted to a tensor; its per-channel
    mean and standard deviation over the two spatial dimensions are
    accumulated and finally divided by the number of images.

    Note:
        The returned std is the average of the per-image standard deviations,
        not the standard deviation of all pixels pooled together.

    Args:
        image_paths: Iterable of paths to the images.

    Returns:
        A tuple ``(mean, std)`` of two tensors with three elements each.

    Raises:
        ValueError: If ``image_paths`` is empty (the averages would otherwise
            silently become NaN from a 0/0 division).
    """
    paths = list(image_paths)
    if not paths:
        raise ValueError("image_paths must contain at least one image")

    # Build the converter once instead of once per image.
    to_tensor = transforms.ToTensor()
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for image_path in paths:
        # Release the file handle as soon as the pixels are in a tensor.
        with Image.open(image_path) as image_file:
            image = to_tensor(image_file.convert("RGB"))
        # Reduce over dims 1 and 2, leaving one value per channel.
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))

    mean /= len(paths)
    std /= len(paths)

    return mean, std
|
|
|
|
|
|
def plot_random_images(indices, data_loader, suffix='train'):
    """Save a grid of dataset samples next to their label maps.

    Each selected sample occupies one row: the image in the left column and
    its label map in the right column. The figure is written to
    ``random_images_{suffix}.png`` rather than shown on screen.

    Args:
        indices: Iterable of positions in ``data_loader.dataset`` to plot.
        data_loader: Loader whose ``dataset`` yields ``(image, label_map)``
            pairs that support ``permute``.
        suffix: Tag inserted into the output file name.
    """
    indices = list(indices)
    # At least five rows keeps the layout unchanged for up to five samples;
    # extra rows are added so that longer index lists no longer overflow the
    # subplot grid.
    n_rows = max(5, len(indices))

    fig = plt.figure(figsize=(8, 8))
    try:
        for row, index in enumerate(indices):
            image, label_map = data_loader.dataset[index]

            # Move the first axis last for imshow. ``permute`` does not modify
            # its input, so no defensive copy is needed.
            # NOTE(review): if the dataset normalises its tensors, imshow will
            # clip values outside the displayable range -- confirm whether the
            # samples should be de-normalised before plotting.
            plt.subplot(n_rows, 2, row * 2 + 1)
            plt.imshow(image.permute(1, 2, 0))
            plt.axis("off")

            plt.subplot(n_rows, 2, row * 2 + 2)
            plt.imshow(label_map.permute(1, 2, 0))
            plt.axis("off")

        # Save the figure to a file instead of displaying it
        fig.savefig(f'random_images_{suffix}.png', dpi=250, bbox_inches='tight')
    finally:
        # Always close the figure, even if plotting failed, to free memory.
        plt.close(fig)
|