# wozek/ai-wozek/siec/dataset.py
#
# 41 lines
# 1.0 KiB
# Python
# Raw Normal View History
#
# 2024-06-17 04:58:21 +02:00
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm
import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys
class LabelDataset(Dataset):
    """Thin ``Dataset`` wrapper around a torchvision ``ImageFolder``.

    Length, indexing and the class-name list are all delegated to the
    wrapped folder, which is kept on ``self.data``.

    Args:
        dataDirectory: Root directory handed to ``ImageFolder``.
        transform: Optional transform forwarded to ``ImageFolder`` and
            applied to each loaded image.
    """

    def __init__(self, dataDirectory, transform=None):
        folder = ImageFolder(dataDirectory, transform=transform)
        self.data = folder

    def __len__(self):
        """Return the number of samples in the wrapped folder."""
        total = len(self.data)
        return total

    def __getitem__(self, idx):
        """Return the wrapped folder's sample at position ``idx`` unchanged."""
        sample = self.data[idx]
        return sample

    @property
    def classes(self):
        """Class names as reported by the wrapped ``ImageFolder``."""
        wrapped = self.data
        return wrapped.classes


# Roots of the training and validation splits, resolved relative to the
# current working directory and laid out as ImageFolder expects
# (one sub-directory per class).
TRAIN_DIR = './train'
VAL_DIR = './val'

# Shared preprocessing for both splits: resize to 128x128, then convert
# to a tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = LabelDataset(TRAIN_DIR, transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

valset = LabelDataset(VAL_DIR, transform)
valloader = DataLoader(valset, batch_size=4, shuffle=False)

# Index -> class-name lookup. Inverting the training dataset's own
# class_to_idx avoids constructing a second ImageFolder, which would
# re-scan the whole training directory just to recover the same mapping.
classname = {idx: name for name, idx in dataset.data.class_to_idx.items()}

# Fetch one shuffled training batch; `images` and `labels` stay available
# at module level (presumably for quick inspection/visualization — confirm
# against importers before removing).
images, labels = next(iter(dataloader))