# Source listing: 41 lines, 1.0 KiB, Python.
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torchvision.datasets import ImageFolder
|
|
import timm
|
|
|
|
import matplotlib.pyplot as plt # For data viz
|
|
import pandas as pd
|
|
import numpy as np
|
|
import sys
|
|
|
|
class LabelDataset(Dataset):
    """Map-style dataset that delegates everything to a torchvision ``ImageFolder``.

    The wrapped folder is kept on ``self.data``. Length, indexing and the
    class list are all forwarded to it unchanged.

    Args:
        dataDirectory: Root directory passed straight to ``ImageFolder``.
        transform: Optional callable forwarded to ``ImageFolder`` as its
            ``transform`` argument (``None`` means no transform).
    """

    def __init__(self, dataDirectory, transform=None):
        folder = ImageFolder(dataDirectory, transform=transform)
        self.data = folder

    def __len__(self):
        """Return the number of samples held by the wrapped folder."""
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        """Return exactly what the wrapped folder yields for index *idx*."""
        wrapped = self.data
        return wrapped[idx]

    @property
    def classes(self):
        """Class names as reported by the wrapped folder's ``classes``."""
        wrapped = self.data
        return wrapped.classes


# Dataset locations and loader settings, kept in one place so they are easy to change.
TRAIN_DIR = './train'
VAL_DIR = './val'
BATCH_SIZE = 4

# Resize every image to a fixed 128x128 and convert it to a tensor.
# NOTE(review): there is no Normalize step here; if a pretrained timm model is
# used downstream it may expect its own mean/std normalization - confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = LabelDataset(TRAIN_DIR, transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

valset = LabelDataset(VAL_DIR, transform)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

# Index -> class-name lookup: the inverse of the training folder's class_to_idx.
# Derived from the ImageFolder already wrapped by `dataset` rather than by
# constructing a second ImageFolder over the same directory, which repeated
# the whole directory scan just to read this mapping.
classname = {idx: name for name, idx in dataset.data.class_to_idx.items()}

# Pull a single (shuffled) training batch for quick inspection.
images, labels = next(iter(dataloader))