data file

This commit is contained in:
Karol Cyganik 2024-03-13 15:15:34 +01:00
parent b5f495c258
commit 7573648527
2 changed files with 155 additions and 0 deletions

2
.gitignore vendored
View File

@ -174,3 +174,5 @@ ipython_config.py
# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/
football_dataset/
venvium/

153
data.py Normal file
View File

@ -0,0 +1,153 @@
import os
from glob import glob
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from kaggle.api.kaggle_api_extended import KaggleApi
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
class FootballSegDataset(Dataset):
    """Dataset of (image, label map) pairs read lazily from file paths.

    Both files of a pair are opened with PIL and converted to RGB on every
    access; nothing is cached between calls.

    Args:
        image_paths: Sequence of paths to the input images.
        label_paths: Sequence of paths to the label maps; entry ``i`` is
            paired with ``image_paths[i]``.
        mean: Per-channel mean, stored as ``self.mean``. The dataset itself
            does not apply it.
        std: Per-channel standard deviation, stored as ``self.std``. The
            dataset itself does not apply it.
        transform: Optional callable applied to the image. It is also
            applied to the label map unless ``label_transform`` is given.
        label_transform: Optional callable applied to the label map instead
            of ``transform``. Defaults to ``transform`` so existing callers
            keep the same behaviour.

    Raises:
        ValueError: If ``image_paths`` and ``label_paths`` differ in length.
    """

    def __init__(self, image_paths, label_paths,
                 mean=(0.3468, 0.3885, 0.3321),
                 std=(0.2081, 0.2054, 0.2093),
                 transform=None, label_transform=None):
        if len(image_paths) != len(label_paths):
            raise ValueError(
                "image_paths and label_paths must have the same length, got "
                f"{len(image_paths)} and {len(label_paths)}")
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.mean = mean
        self.std = std
        self.transform = transform
        # Fall back to the image transform so the default behaviour matches
        # the original single-transform implementation.
        self.label_transform = (
            label_transform if label_transform is not None else transform)

    def __len__(self):
        """Return the number of (image, label map) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label_map)``."""
        # convert() returns a new in-memory image, so the file handle can be
        # released as soon as the ``with`` block ends.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.label_paths[idx]) as raw_label:
            label_map = raw_label.convert("RGB")
        if self.transform:
            image = self.transform(image)
        if self.label_transform:
            label_map = self.label_transform(label_map)
        return image, label_map
def get_data():
    """Make sure the football dataset is on disk and list its files.

    If ``./football_dataset`` does not exist, the Kaggle dataset
    ``sadhliroomyprime/football-semantic-segmentation`` is downloaded and
    unzipped into it. Kaggle credentials are only needed for that download;
    an existing local copy is used as is.

    Returns:
        tuple: ``(all_images, all_paths)`` where ``all_images`` is the sorted
        list of ``*.jpg`` files under ``./football_dataset/images`` and
        ``all_paths`` holds, at the same index, the image path with
        ``___fuse.png`` appended (the label-map file name).
    """
    dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
    download_dir = "./football_dataset"
    if not os.path.exists(download_dir):
        # Authenticate before creating the directory: a failed login must
        # not leave an empty folder that would suppress later downloads.
        api = KaggleApi()
        api.authenticate()
        os.makedirs(download_dir)
        api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
    # glob() yields files in filesystem order; sorting keeps any seeded
    # split of these lists reproducible across machines.
    all_images = sorted(glob(download_dir + "/images/*.jpg"))
    # Append the suffix instead of str.replace(), which would also rewrite
    # a ".jpg" occurring earlier in the path.
    all_paths = [path + "___fuse.png" for path in all_images]
    return all_images, all_paths
def calculate_mean_std(image_paths):
    """Estimate normalisation statistics from a collection of image files.

    Each image is opened, converted to RGB and turned into a tensor with
    ``transforms.ToTensor()``. Its mean and standard deviation are taken
    over dimensions 1 and 2, and the per-image values are averaged over all
    images.

    Note that the returned ``std`` is the average of the per-image standard
    deviations, not the standard deviation of all pixels pooled together.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        tuple: ``(mean, std)``, each a tensor with 3 elements.

    Raises:
        ValueError: If ``image_paths`` is empty (the averages would
            otherwise be NaN from a division by zero).
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("image_paths is empty; cannot compute mean and std")

    # The converter is stateless, so build it once instead of per image.
    to_tensor = transforms.ToTensor()
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for image_path in image_paths:
        # convert() returns an in-memory copy, so the file can be closed
        # right away.
        with Image.open(image_path) as raw_image:
            image = to_tensor(raw_image.convert("RGB"))
        mean += torch.mean(image, dim=(1, 2))
        std += torch.std(image, dim=(1, 2))
    total_images = len(image_paths)
    mean /= total_images
    std /= total_images
    return mean, std
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256):
    """Read one image file and return it as a resized, normalised tensor.

    Args:
        path: Path of the image file to open.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to
            ``transforms.Normalize``.
        SIZE: Edge length; the image is resized to ``(SIZE, SIZE)``.

    Returns:
        The result of applying resize, ``ToTensor`` and ``Normalize`` to the
        RGB-converted image.
    """
    rgb = Image.open(path).convert("RGB")
    steps = [
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(rgb)
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True):
    """Load image/label-map pairs into two stacked tensors.

    Args:
        image_paths: Paths of the input images.
        label_paths: Paths of the label maps, paired by position with
            ``image_paths``.
        SIZE: Edge length forwarded to ``load_image`` for both files.
        calculate_stats: When true, the normalisation mean/std are computed
            from ``image_paths`` with ``calculate_mean_std`` and printed;
            otherwise fixed built-in values are used.

    Returns:
        tuple: ``(images, label_maps)``, each produced by ``torch.stack``
        over the individually loaded tensors. Label maps go through the same
        ``load_image`` call (same mean/std and size) as the images.
    """
    if not calculate_stats:
        mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
    else:
        mean, std = calculate_mean_std(image_paths)
        print("Mean:", mean)
        print("Std:", std)

    # Load each pair together (image first, then its label map) so files
    # are read in the same order as they are paired.
    loaded_pairs = [
        (
            load_image(img_file, mean, std, SIZE=SIZE),
            load_image(lbl_file, mean, std, SIZE=SIZE),
        )
        for img_file, lbl_file in zip(image_paths, label_paths)
    ]
    images = torch.stack([pair[0] for pair in loaded_pairs])
    label_maps = torch.stack([pair[1] for pair in loaded_pairs])
    return images, label_maps
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
    """Draw an image with its label map overlaid on the current pyplot axes.

    Both tensors are permuted with ``(1, 2, 0)``, moved to the CPU and
    converted to NumPy arrays before being passed to ``plt.imshow``. The
    axes are switched off afterwards.

    Args:
        image: Tensor drawn first, with opacity ``alpha_1``.
        label_map: Tensor drawn on top, with opacity ``alpha_2``.
        alpha_1: Opacity of the image layer.
        alpha_2: Opacity of the label-map layer.
    """
    # Convert both layers before drawing anything.
    layers = [
        tensor.permute(1, 2, 0).cpu().numpy()
        for tensor in (image, label_map)
    ]
    for layer, opacity in zip(layers, (alpha_1, alpha_2)):
        plt.imshow(layer, alpha=opacity)
    plt.axis('off')
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
    """Show image/label-map overlays in a grid of subplots.

    Args:
        images: Iterable of image tensors.
        label_maps: Iterable of label-map tensors, paired by position with
            ``images``.
        GRID: ``[rows, columns]`` of the subplot grid.
        SIZE: Figure size passed to ``plt.figure``.

    Pairs are drawn with ``show_map`` into consecutive subplot slots
    starting at 1; drawing stops once ``rows * columns`` slots are filled
    or the inputs run out. The figure is then displayed with ``plt.show``.
    """
    n_rows, n_cols = GRID
    capacity = n_rows * n_cols
    plt.figure(figsize=SIZE)
    paired = zip(images, label_maps)
    for slot, (image, label_map) in enumerate(paired, start=1):
        plt.subplot(n_rows, n_cols, slot)
        show_map(image, label_map)
        # Stop before pulling another pair once the grid is full.
        if slot >= capacity:
            break
    plt.show()
def main():
    """Run the demo pipeline: fetch data, preview samples, build loaders.

    Steps:
        1. Obtain the image and label-map paths with ``get_data``.
        2. Load just enough pairs to fill a 3x3 grid and display them.
        3. Split the paths 80/20 (``random_state=42``) and wrap each split
           in a ``FootballSegDataset`` with resize, ``ToTensor`` and
           ``Normalize`` transforms.
        4. Build the training and test ``DataLoader`` objects.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    all_images, all_paths = get_data()

    # Only the pairs that fit in the preview grid are loaded. The statistics
    # are fixed (calculate_stats=False), so the picture is the same as when
    # loading everything, without holding the whole dataset in memory.
    preview_grid = [3, 3]
    preview_count = preview_grid[0] * preview_grid[1]
    images, label_maps = load_data(
        all_images[:preview_count], all_paths[:preview_count],
        calculate_stats=False)
    show_maps(images, label_maps, GRID=preview_grid)

    image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
        all_images, all_paths, test_size=0.2, random_state=42)

    norm_mean = [0.3468, 0.3885, 0.3321]
    norm_std = [0.2081, 0.2054, 0.2093]
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    train_dataset = FootballSegDataset(
        image_train_paths, label_train_paths, transform=transform)
    test_dataset = FootballSegDataset(
        image_test_paths, label_test_paths, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    # Hand the loaders back instead of discarding them, so importing code
    # can reuse this setup.
    return train_loader, test_loader
# Script entry point: run the demo pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()