config for model
This commit is contained in:
parent
b7296bb2a9
commit
de8f89ddb1
@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
@ -14,8 +14,15 @@ import argparse
|
|||||||
from sconf import Config
|
from sconf import Config
|
||||||
|
|
||||||
def main(config):
|
def main(config):
|
||||||
|
|
||||||
|
max_length = 768
|
||||||
|
image_size = [1920, 2560]
|
||||||
|
config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
|
||||||
|
config_vision.encoder.image_size = image_size # (height, width)
|
||||||
|
config_vision.decoder.max_length = max_length
|
||||||
|
|
||||||
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
|
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
|
||||||
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)
|
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config)
|
||||||
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
|
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model.eval()
|
model.eval()
|
||||||
|
Loading…
Reference in New Issue
Block a user