update
This commit is contained in:
parent
7573648527
commit
59ef151e53
170
data.py
170
data.py
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from glob import glob
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -6,17 +7,44 @@ import torch
|
||||
import torchvision.transforms as transforms
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
from PIL import Image
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class FootballSegDataset(Dataset):
|
||||
def __init__(self, image_paths, label_paths, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], transform=None):
|
||||
def __init__(
|
||||
self,
|
||||
image_paths,
|
||||
label_paths,
|
||||
mean=[0.3468, 0.3885, 0.3321],
|
||||
std=[0.2081, 0.2054, 0.2093],
|
||||
):
|
||||
self.image_paths = image_paths
|
||||
self.label_paths = label_paths
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.transform = transform
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
|
||||
),
|
||||
]
|
||||
)
|
||||
self.num_classes = 11
|
||||
self.class_names = [
|
||||
"Goal Bar",
|
||||
"Referee",
|
||||
"Advertisement",
|
||||
"Ground",
|
||||
"Ball",
|
||||
"Coaches & Officials",
|
||||
"Team A",
|
||||
"Team B",
|
||||
"Goalkeeper A",
|
||||
"Goalkeeper B",
|
||||
"Audience",
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
@ -32,6 +60,45 @@ class FootballSegDataset(Dataset):
|
||||
return image, label_map
|
||||
|
||||
|
||||
class SegmentationStatistics:
|
||||
def __init__(self, label_paths, train_loader, test_loader, mean, std):
|
||||
self.label_paths = label_paths
|
||||
self.color_counts = []
|
||||
self.train_loader = train_loader
|
||||
self.test_loader = test_loader
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def count_colors(self):
|
||||
for path in self.label_paths:
|
||||
label_map = Image.open(path).convert("RGB")
|
||||
colors = set(label_map.getdata())
|
||||
num_colors = len(colors)
|
||||
self.color_counts.append(num_colors - 1)
|
||||
|
||||
def print_statistics(self):
|
||||
max_classes = (
|
||||
max(self.color_counts),
|
||||
self.color_counts.index(max(self.color_counts)),
|
||||
)
|
||||
min_classes = (
|
||||
min(self.color_counts),
|
||||
self.color_counts.index(min(self.color_counts)),
|
||||
)
|
||||
avg_classes = sum(self.color_counts) / len(self.color_counts)
|
||||
print(f"Train data size: {len(self.train_loader.dataset)}")
|
||||
print(f"Test data size: {len(self.test_loader.dataset)}")
|
||||
print(f"Data shape: {self.train_loader.dataset[0][0].shape}")
|
||||
print(f"Label shape: {self.train_loader.dataset[0][1].shape}")
|
||||
print(f"Mean: {self.mean}")
|
||||
print(f"Std: {self.std}")
|
||||
print(f"Number of classes: {self.train_loader.dataset.num_classes}")
|
||||
print(f"Class names: {self.train_loader.dataset.class_names}")
|
||||
print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
|
||||
print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
|
||||
print(f"Avg classes: {avg_classes}")
|
||||
|
||||
|
||||
def get_data():
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
@ -45,8 +112,7 @@ def get_data():
|
||||
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
||||
|
||||
all_images = glob(download_dir + "/images/*.jpg")
|
||||
all_paths = [path.replace(".jpg", ".jpg___fuse.png")
|
||||
for path in all_images]
|
||||
all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
|
||||
return all_images, all_paths
|
||||
|
||||
|
||||
@ -68,86 +134,18 @@ def calculate_mean_std(image_paths):
|
||||
return mean, std
|
||||
|
||||
|
||||
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256):
|
||||
image = Image.open(path).convert("RGB")
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((SIZE, SIZE)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
image = transform(image)
|
||||
return image
|
||||
def plot_random_images(indices, data_loader):
|
||||
plt.figure(figsize=(8, 8))
|
||||
for i, index in enumerate(indices):
|
||||
image, label_map = data_loader.dataset[index]
|
||||
image = deepcopy(image).permute(1, 2, 0)
|
||||
label_map = deepcopy(label_map).permute(1, 2, 0)
|
||||
|
||||
|
||||
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True):
|
||||
if calculate_stats:
|
||||
mean, std = calculate_mean_std(image_paths)
|
||||
print("Mean:", mean)
|
||||
print("Std:", std)
|
||||
else:
|
||||
mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
|
||||
|
||||
images = []
|
||||
label_maps = []
|
||||
for image_path, label_path in zip(image_paths, label_paths):
|
||||
image = load_image(image_path, mean, std, SIZE=SIZE)
|
||||
label_map = load_image(label_path, mean, std, SIZE=SIZE)
|
||||
images.append(image)
|
||||
label_maps.append(label_map)
|
||||
images = torch.stack(images)
|
||||
label_maps = torch.stack(label_maps)
|
||||
return images, label_maps
|
||||
|
||||
|
||||
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
|
||||
image_np = image.permute(1, 2, 0).cpu().numpy()
|
||||
label_map_np = label_map.permute(1, 2, 0).cpu().numpy()
|
||||
|
||||
plt.imshow(image_np, alpha=alpha_1)
|
||||
plt.imshow(label_map_np, alpha=alpha_2)
|
||||
plt.axis('off')
|
||||
|
||||
|
||||
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
|
||||
n_rows, n_cols = GRID
|
||||
n_images = n_rows * n_cols
|
||||
plt.figure(figsize=SIZE)
|
||||
|
||||
i = 1
|
||||
for image, label_map in zip(images, label_maps):
|
||||
plt.subplot(n_rows, n_cols, i)
|
||||
show_map(image, label_map)
|
||||
|
||||
i += 1
|
||||
if i > n_images:
|
||||
break
|
||||
plt.subplot(5, 2, i * 2 + 1)
|
||||
plt.imshow(image)
|
||||
plt.axis("off")
|
||||
plt.subplot(5, 2, i * 2 + 2)
|
||||
plt.imshow(label_map)
|
||||
plt.axis("off")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
all_images, all_paths = get_data()
|
||||
images, label_maps = load_data(
|
||||
all_images, all_paths, calculate_stats=False)
|
||||
show_maps(images, label_maps, GRID=[3, 3])
|
||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
|
||||
all_images, all_paths, test_size=0.2, random_state=42)
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[
|
||||
0.2081, 0.2054, 0.2093])
|
||||
])
|
||||
|
||||
train_dataset = FootballSegDataset(
|
||||
image_train_paths, label_train_paths, transform=transform)
|
||||
test_dataset = FootballSegDataset(
|
||||
image_test_paths, label_test_paths, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
40
main.py
Normal file
40
main.py
Normal file
@ -0,0 +1,40 @@
|
||||
import random
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from data import (
|
||||
FootballSegDataset,
|
||||
SegmentationStatistics,
|
||||
calculate_mean_std,
|
||||
get_data,
|
||||
plot_random_images,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
all_images, all_paths = get_data()
|
||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
|
||||
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
|
||||
)
|
||||
|
||||
mean, std = calculate_mean_std(image_train_paths)
|
||||
|
||||
train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
|
||||
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||
|
||||
train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
||||
test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
||||
|
||||
plot_random_images(train_indices, train_loader)
|
||||
plot_random_images(test_indices, test_loader)
|
||||
statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
|
||||
statistics.count_colors()
|
||||
statistics.print_statistics()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
kaggle==1.6.6
|
||||
matplotlib==3.6.3
|
||||
Pillow==9.3.0
|
||||
Pillow==10.2.0
|
||||
scikit_learn==1.2.2
|
||||
torch==2.0.0+cu117
|
||||
torchvision==0.15.1+cu117
|
Loading…
Reference in New Issue
Block a user