WMICraft/algorithms/neural_network/watersandtreegrass.py

29 lines
810 B
Python
Raw Normal View History

import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_image, ImageReadMode
from common.helpers import createCSV
import os
class WaterSandTreeGrass(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    Each CSV row describes one sample: column 0 holds the image file name
    (joined onto ``img_dir``) and column 1 holds its integer class label.

    Args:
        annotations_file: Path to the CSV, read with ``pandas.read_csv``.
        img_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to each decoded image.
        target_transform: Optional callable applied to each label tensor.
    """

    def __init__(self, annotations_file, img_dir, transform=None,
                 target_transform=None):
        # NOTE(review): createCSV() takes no arguments, so it is not told
        # about ``annotations_file``; presumably it (re)generates that CSV
        # as a side effect before it is read below -- confirm in
        # common.helpers.
        createCSV()
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples (rows in the annotations CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return an ``(image, label)`` pair.

        The image is read from disk with ``read_image`` in RGB mode, and the
        label from column 1 is converted to ``int`` and wrapped in a 0-dim
        ``torch.tensor``. The optional transforms are applied afterwards.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path, mode=ImageReadMode.RGB)
        label = torch.tensor(int(self.img_labels.iloc[idx, 1]))
        # Compare against None explicitly: a valid transform object that
        # defines __len__/__bool__ (e.g. an empty container) can be falsy.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label