This commit is contained in:
Karol Cyganik 2024-03-13 23:33:36 +01:00
parent 7573648527
commit 59ef151e53
3 changed files with 131 additions and 86 deletions

170
data.py
View File

@ -1,4 +1,5 @@
import os import os
from copy import deepcopy
from glob import glob from glob import glob
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -6,17 +7,44 @@ import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
from kaggle.api.kaggle_api_extended import KaggleApi from kaggle.api.kaggle_api_extended import KaggleApi
from PIL import Image from PIL import Image
from sklearn.model_selection import train_test_split from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
class FootballSegDataset(Dataset): class FootballSegDataset(Dataset):
def __init__(self, image_paths, label_paths, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], transform=None): def __init__(
self,
image_paths,
label_paths,
mean=[0.3468, 0.3885, 0.3321],
std=[0.2081, 0.2054, 0.2093],
):
self.image_paths = image_paths self.image_paths = image_paths
self.label_paths = label_paths self.label_paths = label_paths
self.mean = mean self.mean = mean
self.std = std self.std = std
self.transform = transform self.transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
),
]
)
self.num_classes = 11
self.class_names = [
"Goal Bar",
"Referee",
"Advertisement",
"Ground",
"Ball",
"Coaches & Officials",
"Team A",
"Team B",
"Goalkeeper A",
"Goalkeeper B",
"Audience",
]
def __len__(self): def __len__(self):
return len(self.image_paths) return len(self.image_paths)
@ -32,6 +60,45 @@ class FootballSegDataset(Dataset):
return image, label_map return image, label_map
class SegmentationStatistics:
def __init__(self, label_paths, train_loader, test_loader, mean, std):
self.label_paths = label_paths
self.color_counts = []
self.train_loader = train_loader
self.test_loader = test_loader
self.mean = mean
self.std = std
def count_colors(self):
for path in self.label_paths:
label_map = Image.open(path).convert("RGB")
colors = set(label_map.getdata())
num_colors = len(colors)
self.color_counts.append(num_colors - 1)
def print_statistics(self):
max_classes = (
max(self.color_counts),
self.color_counts.index(max(self.color_counts)),
)
min_classes = (
min(self.color_counts),
self.color_counts.index(min(self.color_counts)),
)
avg_classes = sum(self.color_counts) / len(self.color_counts)
print(f"Train data size: {len(self.train_loader.dataset)}")
print(f"Test data size: {len(self.test_loader.dataset)}")
print(f"Data shape: {self.train_loader.dataset[0][0].shape}")
print(f"Label shape: {self.train_loader.dataset[0][1].shape}")
print(f"Mean: {self.mean}")
print(f"Std: {self.std}")
print(f"Number of classes: {self.train_loader.dataset.num_classes}")
print(f"Class names: {self.train_loader.dataset.class_names}")
print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
print(f"Avg classes: {avg_classes}")
def get_data(): def get_data():
api = KaggleApi() api = KaggleApi()
api.authenticate() api.authenticate()
@ -45,8 +112,7 @@ def get_data():
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
all_images = glob(download_dir + "/images/*.jpg") all_images = glob(download_dir + "/images/*.jpg")
all_paths = [path.replace(".jpg", ".jpg___fuse.png") all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
for path in all_images]
return all_images, all_paths return all_images, all_paths
@ -68,86 +134,18 @@ def calculate_mean_std(image_paths):
return mean, std return mean, std
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256): def plot_random_images(indices, data_loader):
image = Image.open(path).convert("RGB") plt.figure(figsize=(8, 8))
transform = transforms.Compose([ for i, index in enumerate(indices):
transforms.Resize((SIZE, SIZE)), image, label_map = data_loader.dataset[index]
transforms.ToTensor(), image = deepcopy(image).permute(1, 2, 0)
transforms.Normalize(mean=mean, std=std) label_map = deepcopy(label_map).permute(1, 2, 0)
])
image = transform(image)
return image
plt.subplot(5, 2, i * 2 + 1)
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True): plt.imshow(image)
if calculate_stats: plt.axis("off")
mean, std = calculate_mean_std(image_paths) plt.subplot(5, 2, i * 2 + 2)
print("Mean:", mean) plt.imshow(label_map)
print("Std:", std) plt.axis("off")
else:
mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
images = []
label_maps = []
for image_path, label_path in zip(image_paths, label_paths):
image = load_image(image_path, mean, std, SIZE=SIZE)
label_map = load_image(label_path, mean, std, SIZE=SIZE)
images.append(image)
label_maps.append(label_map)
images = torch.stack(images)
label_maps = torch.stack(label_maps)
return images, label_maps
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
image_np = image.permute(1, 2, 0).cpu().numpy()
label_map_np = label_map.permute(1, 2, 0).cpu().numpy()
plt.imshow(image_np, alpha=alpha_1)
plt.imshow(label_map_np, alpha=alpha_2)
plt.axis('off')
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
n_rows, n_cols = GRID
n_images = n_rows * n_cols
plt.figure(figsize=SIZE)
i = 1
for image, label_map in zip(images, label_maps):
plt.subplot(n_rows, n_cols, i)
show_map(image, label_map)
i += 1
if i > n_images:
break
plt.show() plt.show()
def main():
all_images, all_paths = get_data()
images, label_maps = load_data(
all_images, all_paths, calculate_stats=False)
show_maps(images, label_maps, GRID=[3, 3])
image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
all_images, all_paths, test_size=0.2, random_state=42)
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[
0.2081, 0.2054, 0.2093])
])
train_dataset = FootballSegDataset(
image_train_paths, label_train_paths, transform=transform)
test_dataset = FootballSegDataset(
image_test_paths, label_test_paths, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
if __name__ == "__main__":
main()

40
main.py Normal file
View File

@ -0,0 +1,40 @@
import random
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from data import (
FootballSegDataset,
SegmentationStatistics,
calculate_mean_std,
get_data,
plot_random_images,
)
def main():
all_images, all_paths = get_data()
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
)
mean, std = calculate_mean_std(image_train_paths)
train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
train_indices = random.sample(range(len(train_loader.dataset)), 5)
test_indices = random.sample(range(len(test_loader.dataset)), 5)
plot_random_images(train_indices, train_loader)
plot_random_images(test_indices, test_loader)
statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
statistics.count_colors()
statistics.print_statistics()
if __name__ == "__main__":
main()

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
kaggle==1.6.6
matplotlib==3.6.3
Pillow==9.3.0
Pillow==10.2.0
scikit_learn==1.2.2
torch==2.0.0+cu117
torchvision==0.15.1+cu117