diff --git a/train_stream.py b/train_stream.py index 21eb5ab..95c8fd7 100644 --- a/train_stream.py +++ b/train_stream.py @@ -35,71 +35,70 @@ def main(config, hug_token): added_tokens = [] dataset = load_dataset(config.dataset_path) - dataset.train_test_split(test_size=0.1) - print(dataset) + dataset = dataset.train_test_split(test_size=0.1) - # train_dataset = DonutDataset( - # dataset, - # processor=processor, - # model=model, - # max_length=config.max_length, - # split="train", - # task_start_token="", - # prompt_end_token="", - # added_tokens=added_tokens, - # sort_json_key=False, # cord dataset is preprocessed, so no need for this - # ) + train_dataset = DonutDataset( + dataset, + processor=processor, + model=model, + max_length=config.max_length, + split="train", + task_start_token="", + prompt_end_token="", + added_tokens=added_tokens, + sort_json_key=False, # cord dataset is preprocessed, so no need for this + ) - # val_dataset = DonutDataset( - # dataset, - # processor=processor, - # model=model, - # max_length=config.max_length, - # split="validation", - # task_start_token="", - # prompt_end_token="", - # added_tokens=added_tokens, - # sort_json_key=False, # cord dataset is preprocessed, so no need for this - # ) + val_dataset = DonutDataset( + dataset, + processor=processor, + model=model, + max_length=config.max_length, + split="test", + task_start_token="", + prompt_end_token="", + added_tokens=added_tokens, + sort_json_key=False, # cord dataset is preprocessed, so no need for this + ) - # model.config.pad_token_id = processor.tokenizer.pad_token_id - # model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0] + model.config.pad_token_id = processor.tokenizer.pad_token_id + model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0] - # train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1) - # val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) + train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1) + val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) - # login(hug_token, True) + login(hug_token, True) - # model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader) + model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader) - # wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name) + wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name) - # checkpoint_callback = ModelCheckpoint( - # monitor="val_metric", - # dirpath=config.checkpoint_path, - # filename="artifacts", - # save_top_k=1, - # save_last=False, - # mode="min", - # ) + checkpoint_callback = ModelCheckpoint( + monitor="val_metric", + dirpath=config.checkpoint_path, + filename="artifacts", + save_top_k=1, + save_last=False, + mode="min", + ) - # custom_ckpt = CustomCheckpointIO() + custom_ckpt = CustomCheckpointIO() - # trainer = pl.Trainer( - # accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu - # devices=1, - # max_epochs=config.train_config.max_epochs, - # val_check_interval=config.train_config.val_check_interval, - # check_val_every_n_epoch=config.train_config.check_val_every_n_epoch, - # gradient_clip_val=config.train_config.gradient_clip_val, - # precision=16, # we'll use mixed precision - # plugins=custom_ckpt, - # num_sanity_val_steps=0, - # logger=wandb_logger, - # callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], - # ) + trainer = pl.Trainer( + accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu + devices=1, + max_epochs=config.train_config.max_epochs, + val_check_interval=config.train_config.val_check_interval, + check_val_every_n_epoch=config.train_config.check_val_every_n_epoch, + gradient_clip_val=config.train_config.gradient_clip_val, + precision=16, # we'll use mixed precision + plugins=custom_ckpt, + num_sanity_val_steps=0, + logger=wandb_logger, + callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], + ) - # trainer.fit(model_module) + trainer.fit(model_module) if __name__ == "__main__": diff --git a/utils/donut_dataset_stream.py b/utils/donut_dataset_stream.py index 8abc1f7..46ddaaa 100644 --- a/utils/donut_dataset_stream.py +++ b/utils/donut_dataset_stream.py @@ -47,7 +47,7 @@ class DonutDataset(Dataset): self.sort_json_key = sort_json_key self.added_tokens = added_tokens - self.dataset = dataset + self.dataset = dataset[self.split] self.dataset_length = len(self.dataset) self.gt_token_sequences = []