diff --git a/train_stream.py b/train_stream.py
index 5aed52a..21eb5ab 100644
--- a/train_stream.py
+++ b/train_stream.py
@@ -34,68 +34,72 @@ def main(config, hug_token):
     added_tokens = []
 
-    train_dataset = DonutDataset(
-        config.dataset_path,
-        processor=processor,
-        model=model,
-        max_length=config.max_length,
-        split="train",
-        task_start_token="<s_cord-v2>",
-        prompt_end_token="<s_cord-v2>",
-        added_tokens=added_tokens,
-        sort_json_key=False, # cord dataset is preprocessed, so no need for this
-    )
+    dataset = load_dataset(config.dataset_path, split="train") # base split to carve a validation set from
+    dataset = dataset.train_test_split(test_size=0.1) # returns a new DatasetDict ("train"/"test"); not in place
+    print(dataset)
 
-    val_dataset = DonutDataset(
-        config.dataset_path,
-        processor=processor,
-        model=model,
-        max_length=config.max_length,
-        split="validation",
-        task_start_token="<s_cord-v2>",
-        prompt_end_token="<s_cord-v2>",
-        added_tokens=added_tokens,
-        sort_json_key=False, # cord dataset is preprocessed, so no need for this
-    )
+    # train_dataset = DonutDataset(
+    #     dataset,
+    #     processor=processor,
+    #     model=model,
+    #     max_length=config.max_length,
+    #     split="train",
+    #     task_start_token="<s_cord-v2>",
+    #     prompt_end_token="<s_cord-v2>",
+    #     added_tokens=added_tokens,
+    #     sort_json_key=False, # cord dataset is preprocessed, so no need for this
+    # )
 
-    model.config.pad_token_id = processor.tokenizer.pad_token_id
-    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
+    # val_dataset = DonutDataset(
+    #     dataset,
+    #     processor=processor,
+    #     model=model,
+    #     max_length=config.max_length,
+    #     split="validation",
+    #     task_start_token="<s_cord-v2>",
+    #     prompt_end_token="<s_cord-v2>",
+    #     added_tokens=added_tokens,
+    #     sort_json_key=False, # cord dataset is preprocessed, so no need for this
+    # )
 
-    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
-    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
+    # model.config.pad_token_id = processor.tokenizer.pad_token_id
+    # model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
 
-    login(hug_token, True)
+    # train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
+    # val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
 
-    model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
+    # login(hug_token, True)
+
+    # model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
 
-    wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
+    # wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
 
-    checkpoint_callback = ModelCheckpoint(
-        monitor="val_metric",
-        dirpath=config.checkpoint_path,
-        filename="artifacts",
-        save_top_k=1,
-        save_last=False,
-        mode="min",
-    )
+    # checkpoint_callback = ModelCheckpoint(
+    #     monitor="val_metric",
+    #     dirpath=config.checkpoint_path,
+    #     filename="artifacts",
+    #     save_top_k=1,
+    #     save_last=False,
+    #     mode="min",
+    # )
 
-    custom_ckpt = CustomCheckpointIO()
+    # custom_ckpt = CustomCheckpointIO()
 
-    trainer = pl.Trainer(
-        accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
-        devices=1,
-        max_epochs=config.train_config.max_epochs,
-        val_check_interval=config.train_config.val_check_interval,
-        check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
-        gradient_clip_val=config.train_config.gradient_clip_val,
-        precision=16, # we'll use mixed precision
-        plugins=custom_ckpt,
-        num_sanity_val_steps=0,
-        logger=wandb_logger,
-        callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
-    )
+    # trainer = pl.Trainer(
+    #     accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
+    #     devices=1,
+    #     max_epochs=config.train_config.max_epochs,
+    #     val_check_interval=config.train_config.val_check_interval,
+    #     check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
+    #     gradient_clip_val=config.train_config.gradient_clip_val,
+    #     precision=16, # we'll use mixed precision
+    #     plugins=custom_ckpt,
+    #     num_sanity_val_steps=0,
+    #     logger=wandb_logger,
+    #     callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
+    # )
 
-    trainer.fit(model_module)
+    # trainer.fit(model_module)
 
 
 if __name__ == "__main__":
diff --git a/utils/donut_dataset_stream.py b/utils/donut_dataset_stream.py
index e10a0fa..8abc1f7 100644
--- a/utils/donut_dataset_stream.py
+++ b/utils/donut_dataset_stream.py
@@ -24,7 +24,7 @@ class DonutDataset(Dataset):
 
     def __init__(
         self,
-        dataset_name_or_path: str,
+        dataset: Dataset,
         max_length: int,
        processor: DonutProcessor,
         model: VisionEncoderDecoderModel,
@@ -47,8 +47,7 @@ class DonutDataset(Dataset):
         self.sort_json_key = sort_json_key
         self.added_tokens = added_tokens
 
-        self.dataset = load_dataset(dataset_name_or_path, split=self.split, streaming=True).with_format("torch")
-        print(self.dataset)
+        self.dataset = dataset
 
         self.dataset_length = len(self.dataset)
         self.gt_token_sequences = []
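
For context once the commented-out blocks are restored: below is a minimal sketch of how the refactored pieces could be wired together. It assumes the hub dataset exposes a "train" split and reuses the names from the diff (config, processor, model, added_tokens, DonutDataset); the <s_cord-v2> task token follows the standard CORD fine-tuning setup. Note that Dataset.train_test_split names the held-out portion "test", not "validation", and that handing DonutDataset a materialized Dataset is what makes the len(self.dataset) call in __init__ valid; the previous streaming IterableDataset did not support len().

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    # Carve a 10% validation set out of the base split; train_test_split
    # returns a new DatasetDict with "train" and "test" keys.
    splits = load_dataset(config.dataset_path, split="train").train_test_split(test_size=0.1)

    train_dataset = DonutDataset(
        splits["train"],  # pre-loaded Dataset, matching the new __init__ signature
        processor=processor,
        model=model,
        max_length=config.max_length,
        split="train",
        task_start_token="<s_cord-v2>",
        prompt_end_token="<s_cord-v2>",
        added_tokens=added_tokens,
        sort_json_key=False,
    )
    val_dataset = DonutDataset(
        splits["test"],  # train_test_split names the held-out split "test"
        processor=processor,
        model=model,
        max_length=config.max_length,
        split="validation",
        task_start_token="<s_cord-v2>",
        prompt_end_token="<s_cord-v2>",
        added_tokens=added_tokens,
        sort_json_key=False,
    )

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)

Doing the split once in main() and passing each DonutDataset a ready Dataset keeps the class free of I/O and makes the 90/10 split explicit in a single place.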