diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index f66b05d..dc40e6c 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -130,8 +130,12 @@ class DonutDataset(Dataset): """ sample = self.dataset[idx] + # change if not 3 channels + if sample['image'].mode != "RGB": + sample['image'] = sample['image'].convert("RGB") + # inputs - pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values + pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values pixel_values = pixel_values.squeeze() # targets