WMICraft/algorithms/neural_network/watersandtreegrass.py
2022-05-18 10:29:05 +02:00

26 lines
704 B
Python

import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_image, ImageReadMode
from common.helpers import createCSV
class WaterSandTreeGrass(Dataset):
    """Map-style dataset that serves (image, label) pairs listed in a CSV file.

    Each row of the annotations table supplies, by position, the image
    location in column 0 and a value convertible to ``int`` in column 1.
    NOTE(review): the class name suggests the labels stand for water, sand,
    tree and grass tiles -- confirm against the code that writes the CSV.
    """

    def __init__(self, annotations_file, transform=None):
        """Load the annotations table and remember the optional transform.

        Args:
            annotations_file: Path (or anything ``pandas.read_csv`` accepts)
                of the CSV describing the samples.
            transform: Optional callable applied to every image returned by
                ``__getitem__``. A falsy value disables it.
        """
        # Invoked on every construction, before the CSV is read.
        # NOTE(review): presumably this (re)generates the annotations file --
        # confirm in common.helpers.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels.index)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at row ``idx``.

        The image is decoded in RGB mode from the location in column 0, and
        the label is column 1 converted to ``int`` and wrapped in a tensor.
        The transform, when set, is applied to the image only.
        """
        table = self.img_labels
        image_path = table.iloc[idx, 0]
        image = read_image(image_path, mode=ImageReadMode.RGB)
        class_index = int(table.iloc[idx, 1])
        label = torch.tensor(class_index)

        # Guard clause: hand the raw image back when no transform is configured.
        if not self.transform:
            return image, label
        return self.transform(image), label