workaround for .kaggle
This commit is contained in:
parent
857782f0cf
commit
5d2a5fc9a4
15
data.py
15
data.py
@ -5,7 +5,6 @@ from glob import glob
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
@ -100,20 +99,20 @@ class SegmentationStatistics:
|
|||||||
|
|
||||||
|
|
||||||
def get_data():
|
def get_data():
|
||||||
api = KaggleApi()
|
|
||||||
api.authenticate()
|
|
||||||
|
|
||||||
dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
|
dataset_slug = "sadhliroomyprime/football-semantic-segmentation"
|
||||||
|
|
||||||
download_dir = "./football_dataset"
|
download_dir = "./football_dataset"
|
||||||
|
|
||||||
if not os.path.exists(download_dir):
|
if not os.path.exists(download_dir):
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
os.makedirs(download_dir)
|
os.makedirs(download_dir)
|
||||||
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
api.dataset_download_files(dataset_slug, path=download_dir, unzip=True)
|
||||||
|
|
||||||
all_images = glob(download_dir + "/images/*.jpg")
|
all_images = glob(download_dir + "/images/*.jpg")
|
||||||
all_paths = [path.replace(".jpg", ".jpg___fuse.png")
|
all_paths = [path.replace(".jpg", ".jpg___fuse.png") for path in all_images]
|
||||||
for path in all_images]
|
|
||||||
return all_images, all_paths
|
return all_images, all_paths
|
||||||
|
|
||||||
|
|
||||||
@ -135,7 +134,7 @@ def calculate_mean_std(image_paths):
|
|||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
def plot_random_images(indices, data_loader, suffix='train'):
|
def plot_random_images(indices, data_loader, suffix="train"):
|
||||||
plt.figure(figsize=(8, 8))
|
plt.figure(figsize=(8, 8))
|
||||||
for i, index in enumerate(indices):
|
for i, index in enumerate(indices):
|
||||||
image, label_map = data_loader.dataset[index]
|
image, label_map = data_loader.dataset[index]
|
||||||
@ -150,5 +149,5 @@ def plot_random_images(indices, data_loader, suffix='train'):
|
|||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
|
|
||||||
# Save the figure to a file instead of displaying it
|
# Save the figure to a file instead of displaying it
|
||||||
plt.savefig(f'random_images_{suffix}.png', dpi=250, bbox_inches='tight')
|
plt.savefig(f"random_images_{suffix}.png", dpi=250, bbox_inches="tight")
|
||||||
plt.close() # Close the figure to free up memory
|
plt.close() # Close the figure to free up memory
|
||||||
|
Loading…
Reference in New Issue
Block a user