import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_image, ImageReadMode

from common.helpers import createCSV


class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by a CSV annotations file.

    Each CSV row supplies an image path (column 0) and a class label
    (column 1). Images are decoded with torchvision's ``read_image`` in
    RGB mode; labels are converted to ``int`` and wrapped in a tensor.

    Args:
        annotations_file: Path (or buffer) handed to ``pandas.read_csv``.
            It is read with pandas' default header handling, so the first
            row is treated as column names, not as a sample.
        transform: Optional callable applied to the decoded image tensor.
        target_transform: Optional callable applied to the label tensor.
            Defaults to ``None``, meaning the label is returned unchanged.
    """

    def __init__(self, annotations_file, transform=None, target_transform=None):
        # NOTE(review): createCSV() is called with no arguments before the
        # annotations file is read, so it presumably (re)generates that CSV.
        # Confirm it actually writes to ``annotations_file``. It runs on
        # every construction of this dataset.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples, i.e. the number of CSV data rows."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at row position ``idx``.

        The image is read from the path in column 0 with
        ``ImageReadMode.RGB``. Per the torchvision docs, this yields a
        3-channel uint8 tensor. The label is ``int(column 1)`` wrapped in
        ``torch.tensor``. ``transform`` / ``target_transform`` are applied
        when they were provided (not ``None``).
        """
        # Samplers normally yield Python ints. A single-element tensor index
        # is also accepted and converted, because ``iloc`` expects a plain
        # integer position.
        if torch.is_tensor(idx):
            idx = idx.item()

        image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        # Compare against None explicitly so that a valid callable which
        # happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label