import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_image, ImageReadMode
from common.helpers import createCSV
from PIL import Image


class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    Each row of the CSV is read positionally: column 0 is used as the
    image path passed to ``PIL.Image.open`` and column 1 as an integer
    class label. (NOTE(review): column meaning is inferred from the
    ``iloc`` indexing below, not from a header — confirm against the CSV
    produced by ``createCSV``.)
    """

    def __init__(self, annotations_file, transform=None):
        """Build the dataset.

        Args:
            annotations_file: Path (or buffer) accepted by ``pd.read_csv``.
            transform: Optional callable applied to each RGB PIL image
                before it is returned.
        """
        # NOTE(review): createCSV() is a project helper called with no
        # arguments; presumably it (re)generates ``annotations_file`` —
        # confirm, since it runs on every construction.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the RGB PIL image, passed through ``transform`` when
        one was given; ``label`` is a scalar ``torch.int64`` tensor.
        """
        # DataFrame.iloc rejects tensor indices; accept a 0-d / single
        # element tensor by converting it to a plain Python int.
        if torch.is_tensor(idx):
            idx = idx.item()

        img_path = self.img_labels.iloc[idx, 0]
        # Open in a context manager so the file handle is closed right
        # away; convert() returns a new, fully loaded image, so it stays
        # valid after the source file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        # Compare against None rather than relying on truthiness: a
        # transform that defines __len__ (e.g. an empty container) would
        # otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)

        return image, label