diff --git a/data.py b/data.py index 302d047..e180905 100644 --- a/data.py +++ b/data.py @@ -5,7 +5,6 @@ from glob import glob import matplotlib.pyplot as plt import torch import torchvision.transforms as transforms -from kaggle.api.kaggle_api_extended import KaggleApi from PIL import Image from torch.utils.data import Dataset @@ -100,20 +99,20 @@ class SegmentationStatistics: def get_data(): - api = KaggleApi() - api.authenticate() - dataset_slug = "sadhliroomyprime/football-semantic-segmentation" download_dir = "./football_dataset" if not os.path.exists(download_dir): + from kaggle.api.kaggle_api_extended import KaggleApi + + api = KaggleApi() + api.authenticate() os.makedirs(download_dir) api.dataset_download_files(dataset_slug, path=download_dir, unzip=True) all_images = glob(download_dir + "/images/*.jpg") - all_paths = [path.replace(".jpg", ".jpg___fuse.png") - for path in all_images] + all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images] return all_images, all_paths @@ -135,7 +134,7 @@ def calculate_mean_std(image_paths): return mean, std -def plot_random_images(indices, data_loader, suffix='train'): +def plot_random_images(indices, data_loader, suffix="train"): plt.figure(figsize=(8, 8)) for i, index in enumerate(indices): image, label_map = data_loader.dataset[index] @@ -150,5 +149,5 @@ def plot_random_images(indices, data_loader, suffix='train'): plt.axis("off") # Save the figure to a file instead of displaying it - plt.savefig(f'random_images_{suffix}.png', dpi=250, bbox_inches='tight') + plt.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight") plt.close() # Close the figure to free up memory