diff --git a/donut-eval.py b/donut-eval.py index bc83f27..bdf7c93 100644 --- a/donut-eval.py +++ b/donut-eval.py @@ -15,14 +15,12 @@ from sconf import Config def main(config): - # max_length = 768 - # image_size = [1920, 2560] - # config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path) - # config_vision.encoder.image_size = image_size # (height, width) - # config_vision.decoder.max_length = max_length + config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path) + config_vision.encoder.image_size = [1920, 2560] # (height, width) + config_vision.decoder.max_length = 768 processor = DonutProcessor.from_pretrained(config.pretrained_processor_path) - model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path) + model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config_vision) dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split) device = "cuda" if torch.cuda.is_available() else "cpu" model.eval()