rgb fix
This commit is contained in:
parent
b3617532f8
commit
d37f195ea6
@ -130,8 +130,12 @@ class DonutDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
sample = self.dataset[idx]
|
sample = self.dataset[idx]
|
||||||
|
|
||||||
|
# change if not 3 channels
|
||||||
|
if sample['image'].mode != "RGB":
|
||||||
|
sample['image'] = sample['image'].convert("RGB")
|
||||||
|
|
||||||
# inputs
|
# inputs
|
||||||
pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
|
pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
|
||||||
pixel_values = pixel_values.squeeze()
|
pixel_values = pixel_values.squeeze()
|
||||||
|
|
||||||
# targets
|
# targets
|
||||||
|
Loading…
Reference in New Issue
Block a user