From fe0b8c0397f24a20122f0cb5f2a0bb694e3b36ae Mon Sep 17 00:00:00 2001 From: s444415 Date: Wed, 4 Jan 2023 11:33:50 +0100 Subject: [PATCH] fixes, missing args --- donut-train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/donut-train.py b/donut-train.py index 139823b..02c3ff5 100644 --- a/donut-train.py +++ b/donut-train.py @@ -64,7 +64,7 @@ def main(config, hug_token): login(hug_token, True) - model_module = DonutModelPLModule(config.train_config.toDict(), processor, model) + model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader) wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name) @@ -90,7 +90,7 @@ def main(config, hug_token): plugins=custom_ckpt, num_sanity_val_steps=0, logger=wandb_logger, - callbacks=[PushToHubCallback(), checkpoint_callback], + callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], ) trainer.fit(model_module)