added more config to eval

This commit is contained in:
Michał Kozłowski 2022-12-17 10:33:03 +01:00
parent bdb7f5ef7e
commit 296647793e

View File

@ -15,12 +15,17 @@ from sconf import Config
def main(config): def main(config):
image_size = [1920, 2560]
config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path) config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
config_vision.encoder.image_size = [1920, 2560] # (height, width) config_vision.encoder.image_size = image_size # (height, width)
config_vision.decoder.max_length = 768 config_vision.decoder.max_length = 768
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path) processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config_vision) model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config_vision)
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split) dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval() model.eval()