# File-viewer header: 154 lines, 4.7 KiB, Python
import os
|
|
from glob import glob
|
|
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
|
from PIL import Image
|
|
from sklearn.model_selection import train_test_split
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
|
class FootballSegDataset(Dataset):
    """Dataset of (image, label map) pairs read from disk on every access.

    Both files of a pair are opened with PIL and converted to RGB each time
    an item is requested; nothing is cached.

    Args:
        image_paths: Sequence of paths to the input images.
        label_paths: Sequence of paths to the label maps; ``label_paths[i]``
            is paired with ``image_paths[i]``.
        mean: Per-channel mean stored as ``self.mean``. When omitted, a fresh
            ``[0.3468, 0.3885, 0.3321]`` list is created for this instance.
        std: Per-channel standard deviation stored as ``self.std``. When
            omitted, a fresh ``[0.2081, 0.2054, 0.2093]`` list is created.
        transform: Optional callable applied to the image. It is also applied
            to the label map unless ``label_transform`` is given.
        label_transform: Optional keyword-only callable applied to the label
            map instead of ``transform`` (for example a pipeline without
            normalisation). ``None`` keeps the single-transform behaviour.

    Note:
        ``mean`` and ``std`` are only stored; this class does not apply them
        itself.
    """

    def __init__(self, image_paths, label_paths, mean=None, std=None,
                 transform=None, *, label_transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        # Build the default statistics per instance so that mutating
        # ``self.mean`` / ``self.std`` can never leak into other instances.
        self.mean = [0.3468, 0.3885, 0.3321] if mean is None else mean
        self.std = [0.2081, 0.2054, 0.2093] if std is None else std
        self.transform = transform
        self.label_transform = label_transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label_map)`` pair stored at position ``idx``."""
        # ``convert`` returns an independent, fully loaded image, so the
        # underlying file can be closed as soon as the ``with`` block exits.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.label_paths[idx]) as raw_label:
            label_map = raw_label.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Fall back to the image transform when no dedicated one was given.
        # NOTE(review): the transform is invoked separately for image and
        # label, so a randomised transform would not stay aligned - confirm
        # only deterministic transforms are passed.
        label_transform = (
            self.label_transform
            if self.label_transform is not None
            else self.transform
        )
        if label_transform:
            label_map = label_transform(label_map)

        return image, label_map
|
|
|
|
|
|
def get_data():
    """Return paired image / label-map paths, downloading the data if needed.

    The Kaggle dataset ``sadhliroomyprime/football-semantic-segmentation`` is
    downloaded and unzipped into ``./football_dataset`` only when that
    directory does not exist yet. The Kaggle client is created and
    authenticated only in that case, so an already-downloaded copy can be
    used without credentials.

    Returns:
        tuple[list[str], list[str]]: ``(all_images, all_paths)`` where
        ``all_images`` are the ``*.jpg`` files under
        ``./football_dataset/images`` in sorted order and ``all_paths[i]`` is
        ``all_images[i]`` with ``"___fuse.png"`` appended.

    Note:
        The existence of a label file is not checked here, and an existing
        but empty download directory yields two empty lists.
    """
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"

    if not os.path.exists(download_dir):
        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir, exist_ok=True)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)

    # glob makes no ordering promise; sorting keeps any later seeded split
    # reproducible across machines and runs.
    all_images = sorted(glob(os.path.join(download_dir, "images", "*.jpg")))

    # Append the suffix instead of str.replace, which would also rewrite any
    # other ".jpg" occurring earlier in the path.
    all_paths = [path + "___fuse.png" for path in all_images]
    return all_images, all_paths
|
|
|
|
|
|
def calculate_mean_std(image_paths):
    """Estimate per-channel mean and standard deviation over a set of images.

    Every image is opened, converted to RGB and turned into a tensor with
    ``transforms.ToTensor()``. Its per-channel mean and standard deviation
    (``torch.std`` with default settings) are taken over the two trailing
    dimensions, and those per-image values are averaged over all images.

    The returned ``std`` is therefore the average of the per-image standard
    deviations, not the standard deviation of all pixels pooled together.

    Args:
        image_paths: Iterable of paths to image files.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, std)``, each of shape
        ``(3,)``.

    Raises:
        ValueError: If ``image_paths`` is empty. The averages would otherwise
            be 0/0 and come back silently as NaN.
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("image_paths is empty; cannot compute mean and std")

    mean = torch.zeros(3)
    std = torch.zeros(3)
    # The converter is stateless, so build it once rather than per image.
    to_tensor = transforms.ToTensor()

    for image_path in image_paths:
        # Close the file as soon as the pixel data has been converted.
        with Image.open(image_path) as raw_image:
            image = to_tensor(raw_image.convert("RGB"))
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))

    total_images = len(image_paths)
    mean /= total_images
    std /= total_images

    return mean, std
|
|
|
|
|
|
def load_image(path, mean=(0.3468, 0.3885, 0.3321),
               std=(0.2081, 0.2054, 0.2093), SIZE=256):
    """Load one image file as a resized, normalised tensor.

    The file is opened with PIL, converted to RGB, resized to
    ``SIZE x SIZE``, converted with ``transforms.ToTensor()`` and finally
    normalised with ``transforms.Normalize(mean, std)``.

    Args:
        path: Path of the image file to read.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to
            ``transforms.Normalize``.
        SIZE: Edge length used for both sides in ``transforms.Resize``.

    Returns:
        The tensor produced by the transform pipeline.
    """
    transform = transforms.Compose([
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    # ``convert`` yields an independent, fully loaded image, so the file
    # handle can be released before the transforms run.
    with Image.open(path) as raw_image:
        image = raw_image.convert("RGB")
    return transform(image)
|
|
|
|
|
|
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True):
    """Load all images and label maps into two stacked tensors.

    Args:
        image_paths: Paths of the input images.
        label_paths: Paths of the label maps, paired positionally with
            ``image_paths`` (pairing stops at the shorter of the two).
        SIZE: Edge length forwarded to ``load_image``.
        calculate_stats: When truthy, the normalisation statistics are
            computed from ``image_paths`` via ``calculate_mean_std`` and
            printed; otherwise the fixed values
            ``[0.3468, 0.3885, 0.3321]`` / ``[0.2081, 0.2054, 0.2093]`` are
            used.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(images, label_maps)``, each
        built with ``torch.stack`` from the per-file tensors.

    Note:
        Label maps go through ``load_image`` with the same ``mean`` / ``std``
        as the images.
    """
    if not calculate_stats:
        mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
    else:
        mean, std = calculate_mean_std(image_paths)
        print("Mean:", mean)
        print("Std:", std)

    # Each tuple is evaluated left to right: image first, then its label.
    loaded_pairs = [
        (
            load_image(img_file, mean, std, SIZE=SIZE),
            load_image(lbl_file, mean, std, SIZE=SIZE),
        )
        for img_file, lbl_file in zip(image_paths, label_paths)
    ]

    images = torch.stack([pair[0] for pair in loaded_pairs])
    label_maps = torch.stack([pair[1] for pair in loaded_pairs])
    return images, label_maps
|
|
|
|
|
|
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
    """Draw ``label_map`` over ``image`` on the current matplotlib axes.

    Both inputs have their dimensions reordered with ``permute(1, 2, 0)``,
    are moved to the CPU and converted to NumPy arrays before being drawn.

    Args:
        image: Tensor drawn first, with opacity ``alpha_1``.
        label_map: Tensor drawn on top, with opacity ``alpha_2``.
        alpha_1: Opacity of the image layer.
        alpha_2: Opacity of the label layer.
    """
    # assumes both tensors are (channels, height, width) - TODO confirm
    base_layer, overlay_layer = (
        tensor.permute(1, 2, 0).cpu().numpy() for tensor in (image, label_map)
    )

    for layer, opacity in ((base_layer, alpha_1), (overlay_layer, alpha_2)):
        plt.imshow(layer, alpha=opacity)

    plt.axis('off')
|
|
|
|
|
|
def show_maps(images, label_maps, GRID=(5, 6), SIZE=(25, 25)):
    """Show image / label-map overlays in a grid of subplots.

    Pairs are taken positionally from ``images`` and ``label_maps`` and drawn
    with ``show_map`` until either the inputs or the grid cells run out.
    The figure is then displayed with ``plt.show()``.

    Args:
        images: Iterable of image tensors.
        label_maps: Iterable of label-map tensors, paired with ``images``.
        GRID: Two-item sequence ``(n_rows, n_cols)`` of the subplot grid.
        SIZE: Figure size passed to ``plt.figure(figsize=...)``.
    """
    from itertools import islice

    n_rows, n_cols = GRID
    n_images = n_rows * n_cols
    plt.figure(figsize=SIZE)

    # islice stops after the last grid cell without pulling an extra pair
    # from the inputs; max() guards against a negative cell count.
    visible_pairs = islice(zip(images, label_maps), max(n_images, 0))
    for position, (image, label_map) in enumerate(visible_pairs, start=1):
        plt.subplot(n_rows, n_cols, position)
        show_map(image, label_map)

    plt.show()
|
|
|
|
|
|
def main():
    """Fetch the dataset, preview a 3x3 grid and build train/test loaders.

    Steps, in order: obtain the path lists from ``get_data``, load every
    image and label map with the fixed normalisation statistics, display the
    first nine overlays, split the paths 80/20 with ``random_state=42`` and
    wrap each split in a ``FootballSegDataset`` and a ``DataLoader`` with
    batch size 32.
    """
    all_images, all_paths = get_data()

    images, label_maps = load_data(all_images, all_paths, calculate_stats=False)
    show_maps(images, label_maps, GRID=[3, 3])

    split = train_test_split(
        all_images, all_paths, test_size=0.2, random_state=42)
    image_train_paths, image_test_paths, label_train_paths, label_test_paths = split

    channel_mean = [0.3468, 0.3885, 0.3321]
    channel_std = [0.2081, 0.2054, 0.2093]
    pipeline_steps = [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    transform = transforms.Compose(pipeline_steps)

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, transform=transform)
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, transform=transform)

    # NOTE: the loaders are constructed but neither used further nor
    # returned by this function.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
|
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|