64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
|
import torch
|
||
|
from torchvision import datasets
|
||
|
import pandas as pd
|
||
|
import torchvision.io as io
|
||
|
import os
|
||
|
|
||
|
class GlassesDataset(torch.utils.data.Dataset):
    """Map-style dataset of face images labelled glasses (1) / no glasses (0).

    Sample ``index`` is the image ``face-{index}.png`` inside ``img_dir``.
    Its label is 1 when ``index`` appears in the ``id`` column of a CSV row
    whose ``glasses`` column equals 1, and 0 otherwise.

    Args:
        img_dir: Directory containing the ``face-{index}.png`` images.
        labels_dir: Path to the CSV file; it must provide ``id`` and
            ``glasses`` columns.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, img_dir, labels_dir, transform=None, target_transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        data = pd.read_csv(labels_dir, low_memory=False)

        # Public attribute kept as before: ids of the rows flagged with glasses.
        self.glasses = data.loc[data['glasses'] == 1, 'id'].values

        # Built once so each label lookup is O(1) instead of scanning the array.
        self._glasses_ids = frozenset(self.glasses.tolist())

        # The dataset serves every labelled row, not only the positive ones,
        # so its length is the row count of the CSV.
        self._num_samples = len(data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is read with ``torchvision.io.read_image`` in ``UNCHANGED``
        mode; ``transform`` / ``target_transform`` are applied when set.
        """
        # Normalise NumPy / tensor scalars to a plain int so that the set
        # lookup below hashes consistently.
        sample_id = int(index)

        image_path = os.path.join(self.img_dir, f'face-{sample_id}.png')
        image = io.read_image(image_path, mode=io.ImageReadMode.UNCHANGED)
        label = 1 if sample_id in self._glasses_ids else 0

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        """Return the number of labelled samples (rows in the CSV)."""
        return self._num_samples
|
||
|
|
||
|
# Smoke check: build the training dataset and print the sample at index 1.
gd = GlassesDataset(
    'classes\\Jimmy_Neuron\\train',
    'classes\\Jimmy_Neuron\\sdg.csv',
)

# Subscript access dispatches to GlassesDataset.__getitem__.
dat = gd[1]

print(dat)
|
||
|
|
||
|
# ts = GlassesDataset('classes\\Jimmy_Neuron\\test', 'classes\\Jimmy_Neuron\\set.csv')
|
||
|
|
||
|
# dat = ts.__getitem__(1)
|
||
|
|
||
|
# print(dat)
|
||
|
|
||
|
# from PIL import Image
|
||
|
# import PIL as pil
|
||
|
# import os
|
||
|
|
||
|
# def resize_image(image_path, output_path, size):
|
||
|
# image = Image.open(image_path)
|
||
|
# image = image.resize(size, Image.Resampling.LANCZOS)
|
||
|
# image.save(output_path)
|
||
|
|
||
|
|
||
|
# # rename every file in folder train from face-x.png to face-{x-1}.png for x in 1..4299
|
||
|
|
||
|
# for i in range(1, 4300):
|
||
|
# old_name = os.path.join('classes\\Jimmy_Neuron\\train', f'face-{i}.png')
|
||
|
# new_name = os.path.join('classes\\Jimmy_Neuron\\train', f'face-{i-1}.png')
|
||
|
# os.rename(old_name, new_name)
|