28 lines
718 B
Python
28 lines
718 B
Python
import torch
|
|
from torch.utils.data import Dataset
|
|
import pandas as pd
|
|
from torchvision.io import read_image, ImageReadMode
|
|
from common.helpers import createCSV
|
|
from PIL import Image
|
|
|
|
|
|
class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by a CSV annotations table.

    For every row of the table, the first column is handed to
    ``PIL.Image.open`` to load the image and the second column is converted
    to an integer class label.
    """

    def __init__(self, annotations_file, transform=None):
        """Read the annotations table and store the optional transform.

        Args:
            annotations_file: Location of the CSV file, passed unchanged to
                ``pandas.read_csv``.
            transform: Optional callable applied to each RGB PIL image
                before ``__getitem__`` returns it.
        """
        # NOTE(review): createCSV() is a project helper invoked before the
        # annotations are read; presumably it (re)generates the CSV that
        # annotations_file points to -- confirm against common.helpers.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        Args:
            idx: Positional row index into the annotations table.

        Returns:
            A tuple ``(image, label)``: ``image`` is the RGB PIL image, or
            whatever ``transform`` returns for it when a transform was
            supplied; ``label`` is a 0-dim integer ``torch.Tensor`` built
            from the second column.
        """
        image_source = self.img_labels.iloc[idx, 0]
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(image_source) as opened:
            image = opened.convert('RGB')

        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        # Compare against None explicitly so a transform object that happens
        # to be falsy (e.g. an empty container module) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
|
|