update
This commit is contained in:
parent
7573648527
commit
59ef151e53
170
data.py
170
data.py
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -6,17 +7,44 @@ import torch
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sklearn.model_selection import train_test_split
|
from torch.utils.data import Dataset
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class FootballSegDataset(Dataset):
|
class FootballSegDataset(Dataset):
|
||||||
def __init__(self, image_paths, label_paths, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], transform=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_paths,
|
||||||
|
label_paths,
|
||||||
|
mean=[0.3468, 0.3885, 0.3321],
|
||||||
|
std=[0.2081, 0.2054, 0.2093],
|
||||||
|
):
|
||||||
self.image_paths = image_paths
|
self.image_paths = image_paths
|
||||||
self.label_paths = label_paths
|
self.label_paths = label_paths
|
||||||
self.mean = mean
|
self.mean = mean
|
||||||
self.std = std
|
self.std = std
|
||||||
self.transform = transform
|
self.transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize((256, 256)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.num_classes = 11
|
||||||
|
self.class_names = [
|
||||||
|
"Goal Bar",
|
||||||
|
"Referee",
|
||||||
|
"Advertisement",
|
||||||
|
"Ground",
|
||||||
|
"Ball",
|
||||||
|
"Coaches & Officials",
|
||||||
|
"Team A",
|
||||||
|
"Team B",
|
||||||
|
"Goalkeeper A",
|
||||||
|
"Goalkeeper B",
|
||||||
|
"Audience",
|
||||||
|
]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.image_paths)
|
return len(self.image_paths)
|
||||||
@ -32,6 +60,45 @@ class FootballSegDataset(Dataset):
|
|||||||
return image, label_map
|
return image, label_map
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationStatistics:
|
||||||
|
def __init__(self, label_paths, train_loader, test_loader, mean, std):
|
||||||
|
self.label_paths = label_paths
|
||||||
|
self.color_counts = []
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.test_loader = test_loader
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
|
||||||
|
def count_colors(self):
|
||||||
|
for path in self.label_paths:
|
||||||
|
label_map = Image.open(path).convert("RGB")
|
||||||
|
colors = set(label_map.getdata())
|
||||||
|
num_colors = len(colors)
|
||||||
|
self.color_counts.append(num_colors - 1)
|
||||||
|
|
||||||
|
def print_statistics(self):
|
||||||
|
max_classes = (
|
||||||
|
max(self.color_counts),
|
||||||
|
self.color_counts.index(max(self.color_counts)),
|
||||||
|
)
|
||||||
|
min_classes = (
|
||||||
|
min(self.color_counts),
|
||||||
|
self.color_counts.index(min(self.color_counts)),
|
||||||
|
)
|
||||||
|
avg_classes = sum(self.color_counts) / len(self.color_counts)
|
||||||
|
print(f"Train data size: {len(self.train_loader.dataset)}")
|
||||||
|
print(f"Test data size: {len(self.test_loader.dataset)}")
|
||||||
|
print(f"Data shape: {self.train_loader.dataset[0][0].shape}")
|
||||||
|
print(f"Label shape: {self.train_loader.dataset[0][1].shape}")
|
||||||
|
print(f"Mean: {self.mean}")
|
||||||
|
print(f"Std: {self.std}")
|
||||||
|
print(f"Number of classes: {self.train_loader.dataset.num_classes}")
|
||||||
|
print(f"Class names: {self.train_loader.dataset.class_names}")
|
||||||
|
print(f"Max classes: {max_classes[0]} at index {max_classes[1]}")
|
||||||
|
print(f"Min classes: {min_classes[0]} at index {min_classes[1]}")
|
||||||
|
print(f"Avg classes: {avg_classes}")
|
||||||
|
|
||||||
|
|
||||||
def get_data():
|
def get_data():
|
||||||
api = KaggleApi()
|
api = KaggleApi()
|
||||||
api.authenticate()
|
api.authenticate()
|
||||||
@ -45,8 +112,7 @@ def get_data():
|
|||||||
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
||||||
|
|
||||||
all_images = glob(download_dir + "/images/*.jpg")
|
all_images = glob(download_dir + "/images/*.jpg")
|
||||||
all_paths = [path.replace(".jpg", ".jpg___fuse.png")
|
all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
|
||||||
for path in all_images]
|
|
||||||
return all_images, all_paths
|
return all_images, all_paths
|
||||||
|
|
||||||
|
|
||||||
@ -68,86 +134,18 @@ def calculate_mean_std(image_paths):
|
|||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256):
|
def plot_random_images(indices, data_loader):
|
||||||
image = Image.open(path).convert("RGB")
|
plt.figure(figsize=(8, 8))
|
||||||
transform = transforms.Compose([
|
for i, index in enumerate(indices):
|
||||||
transforms.Resize((SIZE, SIZE)),
|
image, label_map = data_loader.dataset[index]
|
||||||
transforms.ToTensor(),
|
image = deepcopy(image).permute(1, 2, 0)
|
||||||
transforms.Normalize(mean=mean, std=std)
|
label_map = deepcopy(label_map).permute(1, 2, 0)
|
||||||
])
|
|
||||||
image = transform(image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
plt.subplot(5, 2, i * 2 + 1)
|
||||||
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True):
|
plt.imshow(image)
|
||||||
if calculate_stats:
|
plt.axis("off")
|
||||||
mean, std = calculate_mean_std(image_paths)
|
plt.subplot(5, 2, i * 2 + 2)
|
||||||
print("Mean:", mean)
|
plt.imshow(label_map)
|
||||||
print("Std:", std)
|
plt.axis("off")
|
||||||
else:
|
|
||||||
mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
|
|
||||||
|
|
||||||
images = []
|
|
||||||
label_maps = []
|
|
||||||
for image_path, label_path in zip(image_paths, label_paths):
|
|
||||||
image = load_image(image_path, mean, std, SIZE=SIZE)
|
|
||||||
label_map = load_image(label_path, mean, std, SIZE=SIZE)
|
|
||||||
images.append(image)
|
|
||||||
label_maps.append(label_map)
|
|
||||||
images = torch.stack(images)
|
|
||||||
label_maps = torch.stack(label_maps)
|
|
||||||
return images, label_maps
|
|
||||||
|
|
||||||
|
|
||||||
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
|
|
||||||
image_np = image.permute(1, 2, 0).cpu().numpy()
|
|
||||||
label_map_np = label_map.permute(1, 2, 0).cpu().numpy()
|
|
||||||
|
|
||||||
plt.imshow(image_np, alpha=alpha_1)
|
|
||||||
plt.imshow(label_map_np, alpha=alpha_2)
|
|
||||||
plt.axis('off')
|
|
||||||
|
|
||||||
|
|
||||||
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
|
|
||||||
n_rows, n_cols = GRID
|
|
||||||
n_images = n_rows * n_cols
|
|
||||||
plt.figure(figsize=SIZE)
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for image, label_map in zip(images, label_maps):
|
|
||||||
plt.subplot(n_rows, n_cols, i)
|
|
||||||
show_map(image, label_map)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
if i > n_images:
|
|
||||||
break
|
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
all_images, all_paths = get_data()
|
|
||||||
images, label_maps = load_data(
|
|
||||||
all_images, all_paths, calculate_stats=False)
|
|
||||||
show_maps(images, label_maps, GRID=[3, 3])
|
|
||||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
|
|
||||||
all_images, all_paths, test_size=0.2, random_state=42)
|
|
||||||
|
|
||||||
transform = transforms.Compose([
|
|
||||||
transforms.Resize((256, 256)),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[
|
|
||||||
0.2081, 0.2054, 0.2093])
|
|
||||||
])
|
|
||||||
|
|
||||||
train_dataset = FootballSegDataset(
|
|
||||||
image_train_paths, label_train_paths, transform=transform)
|
|
||||||
test_dataset = FootballSegDataset(
|
|
||||||
image_test_paths, label_test_paths, transform=transform)
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
|
||||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
40
main.py
Normal file
40
main.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from data import (
|
||||||
|
FootballSegDataset,
|
||||||
|
SegmentationStatistics,
|
||||||
|
calculate_mean_std,
|
||||||
|
get_data,
|
||||||
|
plot_random_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
all_images, all_paths = get_data()
|
||||||
|
image_train_paths, image_test_paths, label_train_paths, label_test_paths = (
|
||||||
|
train_test_split(all_images, all_paths, test_size=0.2, random_state=42)
|
||||||
|
)
|
||||||
|
|
||||||
|
mean, std = calculate_mean_std(image_train_paths)
|
||||||
|
|
||||||
|
train_dataset = FootballSegDataset(image_train_paths, label_train_paths, mean, std)
|
||||||
|
test_dataset = FootballSegDataset(image_test_paths, label_test_paths, mean, std)
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||||
|
|
||||||
|
train_indices = random.sample(range(len(train_loader.dataset)), 5)
|
||||||
|
test_indices = random.sample(range(len(test_loader.dataset)), 5)
|
||||||
|
|
||||||
|
plot_random_images(train_indices, train_loader)
|
||||||
|
plot_random_images(test_indices, test_loader)
|
||||||
|
statistics = SegmentationStatistics(all_paths, train_loader, test_loader, mean, std)
|
||||||
|
statistics.count_colors()
|
||||||
|
statistics.print_statistics()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
kaggle==1.6.6
|
||||||
|
matplotlib==3.6.3
|
||||||
|
Pillow==9.3.0
|
||||||
|
Pillow==10.2.0
|
||||||
|
scikit_learn==1.2.2
|
||||||
|
torch==2.0.0+cu117
|
||||||
|
torchvision==0.15.1+cu117
|
Loading…
Reference in New Issue
Block a user