donut eval fix
This commit is contained in:
parent
cfcb15e999
commit
31981cfc51
@ -22,7 +22,7 @@ def main(config):
|
|||||||
# config_vision.decoder.max_length = max_length
|
# config_vision.decoder.max_length = max_length
|
||||||
|
|
||||||
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
|
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
|
||||||
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config)
|
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)
|
||||||
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
|
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model.eval()
|
model.eval()
|
||||||
|
Loading…
Reference in New Issue
Block a user