diff --git a/config-train.yaml b/config-train.yaml
index a239aa2..02004fb 100644
--- a/config-train.yaml
+++ b/config-train.yaml
@@ -5,10 +5,10 @@ output_model_path: "Zombely/plwiki-proto-fine-tuned-v3"
 wandb_test_name: "fiszki-ocr-fine-tune"
 checkpoint_path: "./checkpoint"
 max_length: 768
-image_size: [1920, 2560]
+image_size: [2560, 1920]
 train_config:
   max_epochs: 1
-  val_check_interval: 0.5
+  val_check_interval: 0.5
   check_val_every_n_epoch: 1
   gradient_clip_val: 1.0
   num_training_samples_per_epoch: 800
diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py
index 2d2a502..f66b05d 100644
--- a/utils/donut_dataset.py
+++ b/utils/donut_dataset.py
@@ -7,7 +7,6 @@
 import torch
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 
-
 class DonutDataset(Dataset):
     """
     DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
@@ -132,7 +131,7 @@ class DonutDataset(Dataset):
         sample = self.dataset[idx]
 
         # inputs
-        pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
+        pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
         pixel_values = pixel_values.squeeze()
 
         # targets
@@ -149,4 +148,4 @@ class DonutDataset(Dataset):
         labels = input_ids.clone()
         labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
         # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
-        return pixel_values, labels, target_sequence
\ No newline at end of file
+        return pixel_values, labels, target_sequence
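
A minimal sketch (not part of the diff) of why the `.convert("RGB")` call is added before handing images to `DonutProcessor`: scanned pages loaded via `datasets`/PIL are often grayscale (`"L"`) or `"RGBA"`, which can break the 3-channel normalisation in the image processor. The swapped `image_size` of `[2560, 1920]` also appears to match donut-base's pretraining resolution. The checkpoint name and image path below are hypothetical placeholders used only for illustration.

```python
# Sketch of the .convert("RGB") fix above; checkpoint and file names are stand-ins.
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

image = Image.open("scan.png")   # scans are frequently mode "L" (grayscale) or "RGBA"
pixel_values = processor(
    image.convert("RGB"),        # normalise to 3 channels before resize/pad/normalize
    random_padding=True,         # DonutDataset passes self.split == "train" here
    return_tensors="pt",
).pixel_values

print(pixel_values.shape)        # (1, 3, H, W), following the processor's configured size
```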