This commit is contained in:
Michał Kozłowski 2023-01-05 14:50:49 +01:00
parent b3617532f8
commit d37f195ea6

View File

@ -130,8 +130,12 @@ class DonutDataset(Dataset):
""" """
sample = self.dataset[idx] sample = self.dataset[idx]
# change if not 3 channels
if sample['image'].mode != "RGB":
sample['image'] = sample['image'].convert("RGB")
# inputs # inputs
pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
pixel_values = pixel_values.squeeze() pixel_values = pixel_values.squeeze()
# targets # targets