fixing conversion of channels, config update
parent fe0b8c0397
commit b3617532f8
@@ -5,10 +5,10 @@ output_model_path: "Zombely/plwiki-proto-fine-tuned-v3"
 wandb_test_name: "fiszki-ocr-fine-tune"
 checkpoint_path: "./checkpoint"
 max_length: 768
-image_size: [1920, 2560]
+image_size: [2560, 1920]
 train_config:
   max_epochs: 1
   val_check_interval: 0.5
   check_val_every_n_epoch: 1
   gradient_clip_val: 1.0
   num_training_samples_per_epoch: 800
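A plausible reading of the image_size swap, not stated in the commit itself: the pretrained naver-clova-ix/donut-base encoder stores image_size as [height, width] = [2560, 1920], so the new ordering matches the checkpoint's convention of a 2560-pixel-tall, 1920-pixel-wide page. A minimal sketch of how such config values are typically applied before fine-tuning (standard transformers API; the repo's actual loading code is not shown in this diff):

from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel

image_size = [2560, 1920]   # [height, width], the updated config value
max_length = 768            # decoder sequence length from the same config

# Override the encoder resolution and decoder length on the base Donut config,
# then instantiate the model with that modified config.
config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
config.encoder.image_size = image_size
config.decoder.max_length = max_length
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config)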
@@ -7,7 +7,6 @@ import torch
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 
 
-
 class DonutDataset(Dataset):
     """
     DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
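The imports above and the docstring's pointer to the datasets library give the context for this class: a Hugging Face dataset of page images plus ground truth, wrapped for a DonutProcessor / VisionEncoderDecoderModel pair. A small sketch of those surrounding objects (the dataset name below is a public example, not this repo's data, and DonutDataset's own constructor is not visible in the diff):

from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Pretrained Donut checkpoint of the kind the repo fine-tunes from.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# Any dataset in Hugging Face datasets format with an "image" column works;
# cord-v2 is only a stand-in for illustration.
dataset = load_dataset("naver-clova-ix/cord-v2", split="train")
print(dataset[0]["image"].mode)   # image mode varies by source, e.g. "RGB" or "L"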
@@ -132,7 +131,7 @@ class DonutDataset(Dataset):
         sample = self.dataset[idx]
 
         # inputs
-        pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
+        pixel_values = self.processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
         pixel_values = pixel_values.squeeze()
 
         # targets
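This is the channel fix the commit message refers to: .convert("RGB") normalizes pages stored as grayscale ("L") or with an alpha/palette channel ("RGBA", "P") to three channels before DonutProcessor builds pixel_values, since the Swin encoder expects 3-channel input. A self-contained illustration of the same conversion on a synthetic page (PIL and NumPy only, no dataset required):

import numpy as np
from PIL import Image

# A synthetic grayscale "scan" standing in for a dataset sample.
page = Image.fromarray(np.random.randint(0, 256, (2560, 1920), dtype=np.uint8), mode="L")

print(np.asarray(page).shape)        # (2560, 1920)    -> single channel
rgb_page = page.convert("RGB")       # replicate the channel three times
print(np.asarray(rgb_page).shape)    # (2560, 1920, 3) -> what the processor expects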
@@ -149,4 +148,4 @@ class DonutDataset(Dataset):
         labels = input_ids.clone()
         labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
         # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
         return pixel_values, labels, target_sequence
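For context, the pad masking shown here follows the usual Hugging Face seq2seq pattern: label positions set to the ignore index are skipped by the cross-entropy loss. A tiny standalone illustration (the ignore_id value of -100 and the pad id below are placeholders based on the transformers convention; their real values come from the repo's config and tokenizer):

import torch

ignore_id = -100      # CrossEntropyLoss ignore_index, the transformers default
pad_token_id = 1      # placeholder; in the dataset it is processor.tokenizer.pad_token_id

input_ids = torch.tensor([57, 842, 99, 7, pad_token_id, pad_token_id])
labels = input_ids.clone()
labels[labels == pad_token_id] = ignore_id   # no loss is computed on padding positions
print(labels)                                # tensor([57, 842, 99, 7, -100, -100])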