fixes, missing args
This commit is contained in:
parent
c263335830
commit
fe0b8c0397
@ -64,7 +64,7 @@ def main(config, hug_token):
|
|||||||
|
|
||||||
login(hug_token, True)
|
login(hug_token, True)
|
||||||
|
|
||||||
model_module = DonutModelPLModule(config.train_config.toDict(), processor, model)
|
model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
|
||||||
|
|
||||||
wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
|
wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ def main(config, hug_token):
|
|||||||
plugins=custom_ckpt,
|
plugins=custom_ckpt,
|
||||||
num_sanity_val_steps=0,
|
num_sanity_val_steps=0,
|
||||||
logger=wandb_logger,
|
logger=wandb_logger,
|
||||||
callbacks=[PushToHubCallback(), checkpoint_callback],
|
callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(model_module)
|
trainer.fit(model_module)
|
||||||
|
Loading…
Reference in New Issue
Block a user