config for model

This commit is contained in:
Michał Kozłowski 2022-12-16 15:20:10 +01:00
parent b7296bb2a9
commit de8f89ddb1

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
from transformers import DonutProcessor, VisionEncoderDecoderModel from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
from datasets import load_dataset from datasets import load_dataset
import re import re
import json import json
@ -14,8 +14,15 @@ import argparse
from sconf import Config from sconf import Config
def main(config): def main(config):
max_length = 768
image_size = [1920, 2560]
config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
config_vision.encoder.image_size = image_size # (height, width)
config_vision.decoder.max_length = max_length
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path) processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path) model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config)
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split) dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval() model.eval()