data file
This commit is contained in:
parent
b5f495c258
commit
7573648527
2
.gitignore
vendored
2
.gitignore
vendored
@ -174,3 +174,5 @@ ipython_config.py
|
||||
# Remove previous ipynb_checkpoints
|
||||
# git rm -r .ipynb_checkpoints/
|
||||
|
||||
football_dataset/
|
||||
venvium/
|
153
data.py
Normal file
153
data.py
Normal file
@ -0,0 +1,153 @@
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
from PIL import Image
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
class FootballSegDataset(Dataset):
|
||||
def __init__(self, image_paths, label_paths, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], transform=None):
|
||||
self.image_paths = image_paths
|
||||
self.label_paths = label_paths
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image = Image.open(self.image_paths[idx]).convert("RGB")
|
||||
label_map = Image.open(self.label_paths[idx]).convert("RGB")
|
||||
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
label_map = self.transform(label_map)
|
||||
|
||||
return image, label_map
|
||||
|
||||
|
||||
def get_data():
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
|
||||
dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
|
||||
|
||||
download_dir = "./football_dataset"
|
||||
|
||||
if not os.path.exists(download_dir):
|
||||
os.makedirs(download_dir)
|
||||
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
||||
|
||||
all_images = glob(download_dir + "/images/*.jpg")
|
||||
all_paths = [path.replace(".jpg", ".jpg___fuse.png")
|
||||
for path in all_images]
|
||||
return all_images, all_paths
|
||||
|
||||
|
||||
def calculate_mean_std(image_paths):
|
||||
mean = torch.zeros(3)
|
||||
std = torch.zeros(3)
|
||||
total_images = 0
|
||||
|
||||
for image_path in image_paths:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
image = transforms.ToTensor()(image)
|
||||
mean += torch.mean(image, dim=(1, 2))
|
||||
std += torch.std(image, dim=(1, 2))
|
||||
total_images += 1
|
||||
|
||||
mean /= total_images
|
||||
std /= total_images
|
||||
|
||||
return mean, std
|
||||
|
||||
|
||||
def load_image(path, mean=[0.3468, 0.3885, 0.3321], std=[0.2081, 0.2054, 0.2093], SIZE=256):
|
||||
image = Image.open(path).convert("RGB")
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((SIZE, SIZE)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
image = transform(image)
|
||||
return image
|
||||
|
||||
|
||||
def load_data(image_paths, label_paths, SIZE=256, calculate_stats=True):
|
||||
if calculate_stats:
|
||||
mean, std = calculate_mean_std(image_paths)
|
||||
print("Mean:", mean)
|
||||
print("Std:", std)
|
||||
else:
|
||||
mean, std = [0.3468, 0.3885, 0.3321], [0.2081, 0.2054, 0.2093]
|
||||
|
||||
images = []
|
||||
label_maps = []
|
||||
for image_path, label_path in zip(image_paths, label_paths):
|
||||
image = load_image(image_path, mean, std, SIZE=SIZE)
|
||||
label_map = load_image(label_path, mean, std, SIZE=SIZE)
|
||||
images.append(image)
|
||||
label_maps.append(label_map)
|
||||
images = torch.stack(images)
|
||||
label_maps = torch.stack(label_maps)
|
||||
return images, label_maps
|
||||
|
||||
|
||||
def show_map(image, label_map, alpha_1=1, alpha_2=0.7):
|
||||
image_np = image.permute(1, 2, 0).cpu().numpy()
|
||||
label_map_np = label_map.permute(1, 2, 0).cpu().numpy()
|
||||
|
||||
plt.imshow(image_np, alpha=alpha_1)
|
||||
plt.imshow(label_map_np, alpha=alpha_2)
|
||||
plt.axis('off')
|
||||
|
||||
|
||||
def show_maps(images, label_maps, GRID=[5, 6], SIZE=(25, 25)):
|
||||
n_rows, n_cols = GRID
|
||||
n_images = n_rows * n_cols
|
||||
plt.figure(figsize=SIZE)
|
||||
|
||||
i = 1
|
||||
for image, label_map in zip(images, label_maps):
|
||||
plt.subplot(n_rows, n_cols, i)
|
||||
show_map(image, label_map)
|
||||
|
||||
i += 1
|
||||
if i > n_images:
|
||||
break
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
all_images, all_paths = get_data()
|
||||
images, label_maps = load_data(
|
||||
all_images, all_paths, calculate_stats=False)
|
||||
show_maps(images, label_maps, GRID=[3, 3])
|
||||
image_train_paths, image_test_paths, label_train_paths, label_test_paths = train_test_split(
|
||||
all_images, all_paths, test_size=0.2, random_state=42)
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.3468, 0.3885, 0.3321], std=[
|
||||
0.2081, 0.2054, 0.2093])
|
||||
])
|
||||
|
||||
train_dataset = FootballSegDataset(
|
||||
image_train_paths, label_train_paths, transform=transform)
|
||||
test_dataset = FootballSegDataset(
|
||||
image_test_paths, label_test_paths, transform=transform)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user