diff --git a/donut-train.py b/donut-train.py index 139823b..02c3ff5 100644 --- a/donut-train.py +++ b/donut-train.py @@ -64,7 +64,7 @@ def main(config, hug_token): login(hug_token, True) - model_module = DonutModelPLModule(config.train_config.toDict(), processor, model) + model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader) wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name) @@ -90,7 +90,7 @@ def main(config, hug_token): plugins=custom_ckpt, num_sanity_val_steps=0, logger=wandb_logger, - callbacks=[PushToHubCallback(), checkpoint_callback], + callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], ) trainer.fit(model_module)