workaround for .kaggle

This commit is contained in:
Karol Cyganik 2024-04-03 09:12:18 +02:00
parent 857782f0cf
commit 5d2a5fc9a4

15
data.py
View File

@ -5,7 +5,6 @@ from glob import glob
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
from kaggle.api.kaggle_api_extended import KaggleApi
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -100,20 +99,20 @@ class SegmentationStatistics:
def get_data(): def get_data():
api = KaggleApi()
api.authenticate()
dataset_slug = "sadhliroomyprime/football-semantic-segmentation" dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
download_dir = "./football_dataset" download_dir = "./football_dataset"
if not os.path.exists(download_dir): if not os.path.exists(download_dir):
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
os.makedirs(download_dir) os.makedirs(download_dir)
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
all_images = glob(download_dir + "/images/*.jpg") all_images = glob(download_dir + "/images/*.jpg")
all_paths = [path.replace(".jpg", ".jpg___fuse.png") all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
for path in all_images]
return all_images, all_paths return all_images, all_paths
@ -135,7 +134,7 @@ def calculate_mean_std(image_paths):
return mean, std return mean, std
def plot_random_images(indices, data_loader, suffix='train'): def plot_random_images(indices, data_loader, suffix="train"):
plt.figure(figsize=(8, 8)) plt.figure(figsize=(8, 8))
for i, index in enumerate(indices): for i, index in enumerate(indices):
image, label_map = data_loader.dataset[index] image, label_map = data_loader.dataset[index]
@ -150,5 +149,5 @@ def plot_random_images(indices, data_loader, suffix='train'):
plt.axis("off") plt.axis("off")
# Save the figure to a file instead of displaying it # Save the figure to a file instead of displaying it
plt.savefig(f'random_images_{suffix}.png', dpi=250, bbox_inches='tight') plt.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight")
plt.close() # Close the figure to free up memory plt.close() # Close the figure to free up memory