This commit is contained in:
Karol Cyganik 2024-03-20 12:47:26 +01:00
commit cdf35f3e22
4 changed files with 140 additions and 87 deletions

View File

@ -1,2 +1,10 @@
# ium # Inżynieria Uczenia Maszynowego
To run the code, install all the packages listed in the requirements.txt file, then run the main.py file.
Make sure your Kaggle API token (kaggle.json) is in the .kaggle folder in your home directory.
```bash
pip install -r requirements.txt
python main.py
```

170
data.py
View File

@ -1,4 +1,5 @@
import os import os
from copy import deepcopy
from glob import glob from glob import glob
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -6,17 +7,44 @@ import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
from kaggle.api.kaggle_api_extended import KaggleApi from kaggle.api.kaggle_api_extended import KaggleApi
from PIL import Image from PIL import Image
from sklearn.model_selection import train_test_split from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
class FootballSegDataset(Dataset): class FootballSegDataset(Dataset):
def __init__(
    self,
    image_paths,
    label_paths,
    mean=[0.3468, 0.3885, 0.3321],
    std=[0.2081, 0.2054, 0.2093],
):
    """Store the sample paths and build the preprocessing pipeline.

    Args:
        image_paths: Paths of the input images.
        label_paths: Paths of the label maps, in the same order as
            ``image_paths``.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.

    The default ``mean``/``std`` lists are shared between calls and are
    only read here; callers must not mutate them in place.
    """
    self.image_paths = image_paths
    self.label_paths = label_paths
    self.mean = mean
    self.std = std
    # Normalize with the statistics passed by the caller. Previously the
    # literals were hard-coded here, so a custom mean/std was stored on the
    # instance but never applied.
    self.transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ]
    )
    self.num_classes = 11
    # One display name per class; the list has num_classes entries.
    self.class_names = [
        "Goal Bar",
        "Referee",
        "Advertisement",
        "Ground",
        "Ball",
        "Coaches & Officials",
        "Team A",
        "Team B",
        "Goalkeeper A",
        "Goalkeeper B",
        "Audience",
    ]
def __len__(self): def __len__(self):
return len(self.image_paths) return len(self.image_paths)
@ -32,6 +60,45 @@ class FootballSegDataset(Dataset):
return image, label_map return image, label_map
class SegmentationStatistics:
    """Collect and print summary statistics for a segmentation dataset.

    Args:
        label_paths: Paths of the label-map images to analyse.
        train_loader: Loader whose ``dataset`` supports ``len()``, indexing,
            and exposes ``num_classes`` and ``class_names``.
        test_loader: Loader whose ``dataset`` supports ``len()``.
        mean: Normalization mean; only echoed in the report.
        std: Normalization standard deviation; only echoed in the report.
    """

    def __init__(self, label_paths, train_loader, test_loader, mean, std):
        self.label_paths = label_paths
        # Filled by count_colors(): one entry per label path.
        self.color_counts = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mean = mean
        self.std = std

    def count_colors(self):
        """Count the distinct RGB colors of every label map.

        Each entry of ``self.color_counts`` is the number of distinct colors
        minus one (presumably to discount a background color -- confirm
        against the dataset). The list is rebuilt on every call, so calling
        this method twice no longer duplicates the entries.
        """
        counts = []
        for path in self.label_paths:
            # The context manager closes the underlying file; the original
            # left one open handle per label map.
            with Image.open(path) as label_file:
                colors = set(label_file.convert("RGB").getdata())
            counts.append(len(colors) - 1)
        self.color_counts = counts

    def print_statistics(self):
        """Print dataset sizes, sample shapes and per-image class counts.

        If ``count_colors`` has not produced any counts yet it is run first.
        When there are still no counts (no label paths), the class-count
        lines are replaced by a short notice instead of raising
        ``ValueError`` from ``max()`` on an empty list.
        """
        if not self.color_counts:
            self.count_colors()

        train_dataset = self.train_loader.dataset
        # Load the first sample once; the original indexed the dataset twice.
        first_sample = train_dataset[0]

        print(f"Train data size: {len(train_dataset)}")
        print(f"Test data size: {len(self.test_loader.dataset)}")
        print(f"Data shape: {first_sample[0].shape}")
        print(f"Label shape: {first_sample[1].shape}")
        print(f"Mean: {self.mean}")
        print(f"Std: {self.std}")
        print(f"Number of classes: {train_dataset.num_classes}")
        print(f"Class names: {train_dataset.class_names}")

        counts = self.color_counts
        if not counts:
            print("Class count statistics unavailable: no label maps were counted.")
            return

        most = max(counts)
        fewest = min(counts)
        # list.index returns the first position holding the extreme value.
        print(f"Max classes: {most} at index {counts.index(most)}")
        print(f"Min classes: {fewest} at index {counts.index(fewest)}")
        print(f"Avg classes: {sum(counts) / len(counts)}")
def get_data(): def get_data():
api = KaggleApi() api = KaggleApi()
api.authenticate() api.authenticate()
@ -45,8 +112,7 @@ def get_data():
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
all_images = glob(download_dir + "/images/*.jpg") all_images = glob(download_dir + "/images/*.jpg")
all_paths = [path.replace(".jpg", ".jpg___fuse.png") all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
for path in all_images]
return all_images, all_paths return all_images, all_paths
@ -68,86 +134,18 @@ def calculate_mean_std(image_paths):
return mean, std return mean, std
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256): def plot_random_images(indices, data_loader):
image = Image.open(path).convert("RGB") plt.figure(figsize=(8, 8))
transform = transforms.Compose([ for i, index in enumerate(indices):
transforms.Resize((SIZE, SIZE)), image, label_map = data_loader.dataset[index]
transforms.ToTensor(), image = deepcopy(image).permute(1, 2, 0)
transforms.Normalize(mean=mean, std=std) label_map = deepcopy(label_map).permute(1, 2, 0)
])
image = transform(image)
return image
plt.subplot(5, 2, i * 2 + 1)
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True): plt.imshow(image)
if calculate_stats: plt.axis("off")
mean, std = calculate_mean_std(image_paths) plt.subplot(5, 2, i * 2 + 2)
print("Mean:", mean) plt.imshow(label_map)
print("Std:", std) plt.axis("off")
else:
mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
images = []
label_maps = []
for image_path, label_path in zip(image_paths, label_paths):
image = load_image(image_path, mean, std, SIZE=SIZE)
label_map = load_image(label_path, mean, std, SIZE=SIZE)
images.append(image)
label_maps.append(label_map)
images = torch.stack(images)
label_maps = torch.stack(label_maps)
return images, label_maps
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
image_np = image.permute(1, 2, 0).cpu().numpy()
label_map_np = label_map.permute(1, 2, 0).cpu().numpy()
plt.imshow(image_np, alpha=alpha_1)
plt.imshow(label_map_np, alpha=alpha_2)
plt.axis('off')
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
n_rows, n_cols = GRID
n_images = n_rows * n_cols
plt.figure(figsize=SIZE)
i = 1
for image, label_map in zip(images, label_maps):
plt.subplot(n_rows, n_cols, i)
show_map(image, label_map)
i += 1
if i > n_images:
break
plt.show() plt.show()
def main():
all_images, all_paths = get_data()
images, label_maps = load_data(
all_images, all_paths, calculate_stats=False)
show_maps(images, label_maps, GRID=[3, 3])
image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
all_images, all_paths, test_size=0.2, random_state=42)
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[
0.2081, 0.2054, 0.2093])
])
train_dataset = FootballSegDataset(
image_train_paths, label_train_paths, transform=transform)
test_dataset = FootballSegDataset(
image_test_paths, label_test_paths, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
if __name__ == "__main__":
main()

40
main.py Normal file
View File

@ -0,0 +1,40 @@
import random
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from data import (
FootballSegDataset,
SegmentationStatistics,
calculate_mean_std,
get_data,
plot_random_images,
)
def main():
    """Download the data, build the loaders, plot samples and print statistics.

    The image/label paths are split 80/20 with a fixed seed, the
    normalization statistics are computed on the training images only, and a
    handful of random samples from each split are plotted before the dataset
    statistics are printed.
    """
    all_images, all_paths = get_data()

    image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
        train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
    )

    mean, std = calculate_mean_std(image_train_paths)

    train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
    test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # plot_random_images draws on a 5-row grid, so 5 is the upper bound.
    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at the size of each split.
    max_samples = 5
    train_indices = random.sample(
        range(len(train_dataset)), min(max_samples, len(train_dataset))
    )
    test_indices = random.sample(
        range(len(test_dataset)), min(max_samples, len(test_dataset))
    )

    plot_random_images(train_indices, train_loader)
    plot_random_images(test_indices, test_loader)

    statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
    statistics.count_colors()
    statistics.print_statistics()


if __name__ == "__main__":
    main()

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
# The +cu117 builds of torch/torchvision are not published on PyPI; they are
# served from the PyTorch wheel index for CUDA 11.7.
--extra-index-url https://download.pytorch.org/whl/cu117
kaggle==1.6.6
matplotlib==3.6.3
# Pillow was pinned twice (9.3.0 and 10.2.0); pip rejects a duplicate
# requirement, so only the newer pin is kept.
Pillow==10.2.0
scikit_learn==1.2.2
torch==2.0.0+cu117
torchvision==0.15.1+cu117