fixes, missing args

This commit is contained in:
s444415 2023-01-04 11:33:50 +01:00
parent c263335830
commit fe0b8c0397

View File

@ -64,7 +64,7 @@ def main(config, hug_token):
login(hug_token, True) login(hug_token, True)
model_module = DonutModelPLModule(config.train_config.toDict(), processor, model) model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name) wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
@ -90,7 +90,7 @@ def main(config, hug_token):
plugins=custom_ckpt, plugins=custom_ckpt,
num_sanity_val_steps=0, num_sanity_val_steps=0,
logger=wandb_logger, logger=wandb_logger,
callbacks=[PushToHubCallback(), checkpoint_callback], callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
) )
trainer.fit(model_module) trainer.fit(model_module)