import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image

from common.helpers import createCSV


class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    Rows of the annotations file are read positionally: column 0 holds an
    image filename (joined onto ``img_dir``) and column 1 holds a label that
    must be convertible to ``int``.
    """

    def __init__(self, annotations_file, img_dir, transform=None,
                 target_transform=None):
        """Load the annotations table and store the dataset configuration.

        Args:
            annotations_file: Path to the annotations CSV. It is read with
                ``pandas.read_csv`` defaults, so the first row is treated as
                a header.
            img_dir: Directory that the filenames in column 0 are joined to.
            transform: Optional callable applied to each decoded image tensor.
            target_transform: Optional callable applied to each label tensor.
                Defaults to ``None`` (label returned unchanged), so existing
                callers are unaffected.
        """
        # NOTE(review): createCSV is project code not visible here; presumably
        # it (re)generates the annotations CSV that is read on the next line —
        # confirm. It runs on every instantiation of this dataset.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        The image is decoded with ``read_image`` in ``ImageReadMode.RGB``
        (3 channels), then passed through ``transform`` if one was given.
        The label is ``torch.tensor(int(<column 1>))``, then passed through
        ``target_transform`` if one was given.
        """
        # Samplers may yield a 0-d tensor index; pandas' positional .iloc
        # indexing expects a plain Python int.
        if torch.is_tensor(idx):
            idx = int(idx)

        # str() guards against pandas parsing a numeric-looking filename
        # column as numbers, which os.path.join would reject.
        filename = str(self.img_labels.iloc[idx, 0])
        img_path = os.path.join(self.img_dir, filename)

        image = read_image(img_path, mode=ImageReadMode.RGB)
        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label