2022-05-17 22:54:56 +02:00
|
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
import pandas as pd
|
|
|
|
from torchvision.io import read_image, ImageReadMode
|
|
|
|
from common.helpers import createCSV
|
|
|
|
|
|
|
|
|
|
|
|
class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    Rows of the CSV are read positionally: column 0 holds the path of an
    image file and column 1 holds its integer class label.
    """

    def __init__(self, annotations_file, transform=None, target_transform=None):
        """Load the annotations table.

        Args:
            annotations_file: Anything accepted by ``pandas.read_csv``
                (typically a path to the annotations CSV).
            transform: Optional callable applied to each decoded image tensor.
            target_transform: Optional callable applied to each label tensor.
                Defaults to ``None`` so existing callers are unaffected.
        """
        # NOTE(review): createCSV() is an opaque helper from common.helpers and
        # is called without arguments (it does not receive annotations_file).
        # Presumably it (re)generates the annotations CSV before it is read
        # below; confirm, since it runs on every construction.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at row ``idx``.

        The image is decoded with ``ImageReadMode.RGB`` so every sample has
        three channels regardless of the file's own mode; the label is a
        0-d integer tensor built from column 1.
        """
        image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        # Compare against None rather than relying on truthiness: a transform
        # object that defines __len__ (e.g. an empty container) would be
        # falsy and silently skipped.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label
|
|
|
|
|