donut eval fix

This commit is contained in:
Michał Kozłowski 2022-12-17 10:27:14 +01:00
parent cfcb15e999
commit 31981cfc51

View File

@ -22,7 +22,7 @@ def main(config):
# config_vision.decoder.max_length = max_length # config_vision.decoder.max_length = max_length
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path) processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config) model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split) dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval() model.eval()