rgb fix
This commit is contained in:
parent
b3617532f8
commit
d37f195ea6
@ -130,8 +130,12 @@ class DonutDataset(Dataset):
|
||||
"""
|
||||
sample = self.dataset[idx]
|
||||
|
||||
# change if not 3 channels
|
||||
if sample['image'].mode != "RGB":
|
||||
sample['image'] = sample['image'].convert("RGB")
|
||||
|
||||
# inputs
|
||||
pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
|
||||
pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
|
||||
pixel_values = pixel_values.squeeze()
|
||||
|
||||
# targets
|
||||
|
Loading…
Reference in New Issue
Block a user