fixing conversion of channels, config update

This commit is contained in:
s444415 2023-01-05 13:42:46 +00:00
parent fe0b8c0397
commit b3617532f8
2 changed files with 4 additions and 5 deletions

View File

@ -5,7 +5,7 @@ output_model_path: "Zombely/plwiki-proto-fine-tuned-v3"
wandb_test_name: "fiszki-ocr-fine-tune" wandb_test_name: "fiszki-ocr-fine-tune"
checkpoint_path: "./checkpoint" checkpoint_path: "./checkpoint"
max_length: 768 max_length: 768
image_size: [1920, 2560] image_size: [2560, 1920]
train_config: train_config:
max_epochs: 1 max_epochs: 1
val_check_interval: 0.5 val_check_interval: 0.5

View File

@ -7,7 +7,6 @@ import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel from transformers import DonutProcessor, VisionEncoderDecoderModel
class DonutDataset(Dataset): class DonutDataset(Dataset):
""" """
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets) DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
@ -132,7 +131,7 @@ class DonutDataset(Dataset):
sample = self.dataset[idx] sample = self.dataset[idx]
# inputs # inputs
pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
pixel_values = pixel_values.squeeze() pixel_values = pixel_values.squeeze()
# targets # targets