29 lines
810 B
Python
29 lines
810 B
Python
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
import pandas as pd
|
||
|
from torchvision.io import read_image, ImageReadMode
|
||
|
from common.helpers import createCSV
|
||
|
import os
|
||
|
|
||
|
|
||
|
class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by a CSV annotations file.

    Rows of the CSV are read positionally: column 0 holds an image
    filename relative to ``img_dir`` and column 1 holds an integer
    class label.

    Args:
        annotations_file: Path (or buffer) handed to ``pandas.read_csv``.
        img_dir: Directory that the filenames in column 0 are joined onto.
        transform: Optional callable applied to each decoded image.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        # NOTE(review): createCSV() takes no arguments and runs on every
        # construction, before annotations_file is read. Presumably it
        # (re)generates that CSV -- confirm it writes the same path the
        # caller passes in here.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the annotations.

        The image is decoded with ``read_image(..., ImageReadMode.RGB)``
        and passed through ``transform`` when one was supplied. The label
        is the integer from column 1 wrapped in a 0-d ``torch.tensor``.
        """
        # Some samplers yield one-element tensors; iloc needs a plain int.
        if torch.is_tensor(idx):
            idx = int(idx)

        # str() so a purely numeric filename column (parsed by pandas as
        # ints) can still be joined onto the directory.
        filename = str(self.img_labels.iloc[idx, 0])
        img_path = os.path.join(self.img_dir, filename)
        image = read_image(img_path, mode=ImageReadMode.RGB)
        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

        # Explicit None check: a transform object may define __len__ or
        # __bool__ and be falsy while still being a valid callable.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
||
|
|