import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm
import matplotlib.pyplot as plt  # For data viz
import pandas as pd
import numpy as np
import sys

# Data locations and loader settings, kept in one place so they are easy to change.
TRAIN_DIR = './train'
VAL_DIR = './val'
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 4


class LabelDataset(Dataset):
    """Thin ``Dataset`` wrapper around :class:`torchvision.datasets.ImageFolder`.

    Every sample is exactly what the wrapped ``ImageFolder`` returns for that
    index, with ``transform`` (if given) applied to the image.
    """

    def __init__(self, dataDirectory, transform=None):
        """Index the images found under *dataDirectory*.

        Args:
            dataDirectory: Root folder handed to ``ImageFolder``.
            transform: Optional callable forwarded to ``ImageFolder`` and
                applied to each loaded image.
        """
        self.data = ImageFolder(dataDirectory, transform=transform)

    def __len__(self):
        """Return the number of samples in the wrapped ``ImageFolder``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the wrapped ``ImageFolder``'s sample at *idx*."""
        return self.data[idx]

    @property
    def classes(self):
        """Class names as reported by the wrapped ``ImageFolder``."""
        return self.data.classes

    @property
    def class_to_idx(self):
        """Class-name -> integer-label mapping from the wrapped ``ImageFolder``."""
        return self.data.class_to_idx


transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

dataset = LabelDataset(TRAIN_DIR, transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

valset = LabelDataset(VAL_DIR, transform)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

# Integer label -> class name. This reuses the mapping the training dataset
# already computed instead of building (and re-scanning) a second ImageFolder.
classname = {idx: name for name, idx in dataset.class_to_idx.items()}

# Pull a single (shuffled) training batch for quick inspection.
images, labels = next(iter(dataloader))