From d37f195ea6ec81231b8424e81fa5cfb965ffdffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?= Date: Thu, 5 Jan 2023 14:50:49 +0100 Subject: [PATCH] rgb fix --- utils/donut_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index f66b05d..dc40e6c 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -130,8 +130,12 @@ class DonutDataset(Dataset): """ sample = self.dataset[idx] + # change if not 3 channels + if sample['image'].mode != "RGB": + sample['image'] = sample['image'].convert("RGB") + # inputs - pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values + pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values pixel_values = pixel_values.squeeze() # targets