From dd5febad65d9075ddd8efab44a6e3ed90c4f5a7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?=
Date: Wed, 25 Jan 2023 21:29:15 +0100
Subject: [PATCH 1/5] testing

---
 train_stream.py               | 108 ++++++++++++++++++----------------
 utils/donut_dataset_stream.py |   5 +-
 2 files changed, 58 insertions(+), 55 deletions(-)

diff --git a/train_stream.py b/train_stream.py
index 5aed52a..21eb5ab 100644
--- a/train_stream.py
+++ b/train_stream.py
@@ -34,68 +34,72 @@ def main(config, hug_token):
 
     added_tokens = []
 
-    train_dataset = DonutDataset(
-        config.dataset_path,
-        processor=processor,
-        model=model,
-        max_length=config.max_length,
-        split="train",
-        task_start_token="",
-        prompt_end_token="",
-        added_tokens=added_tokens,
-        sort_json_key=False, # cord dataset is preprocessed, so no need for this
-    )
+    dataset = load_dataset(config.dataset_path)
+    dataset.train_test_split(test_size=0.1)
+    print(dataset)
 
-    val_dataset = DonutDataset(
-        config.dataset_path,
-        processor=processor,
-        model=model,
-        max_length=config.max_length,
-        split="validation",
-        task_start_token="",
-        prompt_end_token="",
-        added_tokens=added_tokens,
-        sort_json_key=False, # cord dataset is preprocessed, so no need for this
-    )
+    # train_dataset = DonutDataset(
+    #     dataset,
+    #     processor=processor,
+    #     model=model,
+    #     max_length=config.max_length,
+    #     split="train",
+    #     task_start_token="",
+    #     prompt_end_token="",
+    #     added_tokens=added_tokens,
+    #     sort_json_key=False, # cord dataset is preprocessed, so no need for this
+    # )
 
-    model.config.pad_token_id = processor.tokenizer.pad_token_id
-    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
+    # val_dataset = DonutDataset(
+    #     dataset,
+    #     processor=processor,
+    #     model=model,
+    #     max_length=config.max_length,
+    #     split="validation",
+    #     task_start_token="",
+    #     prompt_end_token="",
+    #     added_tokens=added_tokens,
+    #     sort_json_key=False, # cord dataset is preprocessed, so no need for this
+    # )
 
-    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
-    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
+    # model.config.pad_token_id = processor.tokenizer.pad_token_id
+    # model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
 
-    login(hug_token, True)
+    # train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
+    # val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
 
-    model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
+    # login(hug_token, True)
+
+    # model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
 
-    wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
+    # wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
 
-    checkpoint_callback = ModelCheckpoint(
-        monitor="val_metric",
-        dirpath=config.checkpoint_path,
-        filename="artifacts",
-        save_top_k=1,
-        save_last=False,
-        mode="min",
-    )
+    # checkpoint_callback = ModelCheckpoint(
+    #     monitor="val_metric",
+    #     dirpath=config.checkpoint_path,
+    #     filename="artifacts",
+    #     save_top_k=1,
+    #     save_last=False,
+    #     mode="min",
+    # )
 
-    custom_ckpt = CustomCheckpointIO()
+    # custom_ckpt = CustomCheckpointIO()
 
-    trainer = pl.Trainer(
accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu - devices=1, - max_epochs=config.train_config.max_epochs, - val_check_interval=config.train_config.val_check_interval, - check_val_every_n_epoch=config.train_config.check_val_every_n_epoch, - gradient_clip_val=config.train_config.gradient_clip_val, - precision=16, # we'll use mixed precision - plugins=custom_ckpt, - num_sanity_val_steps=0, - logger=wandb_logger, - callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], - ) + # trainer = pl.Trainer( + # accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu + # devices=1, + # max_epochs=config.train_config.max_epochs, + # val_check_interval=config.train_config.val_check_interval, + # check_val_every_n_epoch=config.train_config.check_val_every_n_epoch, + # gradient_clip_val=config.train_config.gradient_clip_val, + # precision=16, # we'll use mixed precision + # plugins=custom_ckpt, + # num_sanity_val_steps=0, + # logger=wandb_logger, + # callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback], + # ) - trainer.fit(model_module) + # trainer.fit(model_module) if __name__ == "__main__": diff --git a/utils/donut_dataset_stream.py b/utils/donut_dataset_stream.py index e10a0fa..8abc1f7 100644 --- a/utils/donut_dataset_stream.py +++ b/utils/donut_dataset_stream.py @@ -24,7 +24,7 @@ class DonutDataset(Dataset): def __init__( self, - dataset_name_or_path: str, + dataset: Dataset, max_length: int, processor: DonutProcessor, model: VisionEncoderDecoderModel, @@ -47,8 +47,7 @@ class DonutDataset(Dataset): self.sort_json_key = sort_json_key self.added_tokens = added_tokens - self.dataset = load_dataset(dataset_name_or_path, split=self.split, streaming=True).with_format("torch") - print(self.dataset) + self.dataset = dataset self.dataset_length = len(self.dataset) self.gt_token_sequences = [] From ecce4427a5b0865263e2ebe5a9afcca115cf5d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?= Date: Wed, 25 Jan 2023 21:43:13 +0100 Subject: [PATCH 2/5] teesting train2 2 --- train_stream.py | 107 +++++++++++++++++----------------- utils/donut_dataset_stream.py | 2 +- 2 files changed, 54 insertions(+), 55 deletions(-) diff --git a/train_stream.py b/train_stream.py index 21eb5ab..95c8fd7 100644 --- a/train_stream.py +++ b/train_stream.py @@ -35,71 +35,70 @@ def main(config, hug_token): added_tokens = [] dataset = load_dataset(config.dataset_path) - dataset.train_test_split(test_size=0.1) - print(dataset) + dataset = dataset.train_test_split(test_size=0.1) - # train_dataset = DonutDataset( - # dataset, - # processor=processor, - # model=model, - # max_length=config.max_length, - # split="train", - # task_start_token="", - # prompt_end_token="", - # added_tokens=added_tokens, - # sort_json_key=False, # cord dataset is preprocessed, so no need for this - # ) + train_dataset = DonutDataset( + dataset, + processor=processor, + model=model, + max_length=config.max_length, + split="train", + task_start_token="", + prompt_end_token="", + added_tokens=added_tokens, + sort_json_key=False, # cord dataset is preprocessed, so no need for this + ) - # val_dataset = DonutDataset( - # dataset, - # processor=processor, - # model=model, - # max_length=config.max_length, - # split="validation", - # task_start_token="", - # prompt_end_token="", - # added_tokens=added_tokens, - # sort_json_key=False, # cord dataset is preprocessed, so no need for this - # ) + val_dataset = DonutDataset( + 
+        dataset,
+        processor=processor,
+        model=model,
+        max_length=config.max_length,
+        split="test",
+        task_start_token="",
+        prompt_end_token="",
+        added_tokens=added_tokens,
+        sort_json_key=False, # cord dataset is preprocessed, so no need for this
+    )
 
-    # model.config.pad_token_id = processor.tokenizer.pad_token_id
-    # model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
+    model.config.pad_token_id = processor.tokenizer.pad_token_id
+    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
 
-    # train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
-    # val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
+    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
+    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)
 
-    # login(hug_token, True)
+    login(hug_token, True)
 
-    # model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
+    model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
 
-    # wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
+    wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)
 
-    # checkpoint_callback = ModelCheckpoint(
-    #     monitor="val_metric",
-    #     dirpath=config.checkpoint_path,
-    #     filename="artifacts",
-    #     save_top_k=1,
-    #     save_last=False,
-    #     mode="min",
-    # )
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val_metric",
+        dirpath=config.checkpoint_path,
+        filename="artifacts",
+        save_top_k=1,
+        save_last=False,
+        mode="min",
+    )
 
-    # custom_ckpt = CustomCheckpointIO()
+    custom_ckpt = CustomCheckpointIO()
 
-    # trainer = pl.Trainer(
-    #     accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
-    #     devices=1,
-    #     max_epochs=config.train_config.max_epochs,
-    #     val_check_interval=config.train_config.val_check_interval,
-    #     check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
-    #     gradient_clip_val=config.train_config.gradient_clip_val,
-    #     precision=16, # we'll use mixed precision
-    #     plugins=custom_ckpt,
-    #     num_sanity_val_steps=0,
-    #     logger=wandb_logger,
-    #     callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
-    # )
+    trainer = pl.Trainer(
+        accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
+        devices=1,
+        max_epochs=config.train_config.max_epochs,
+        val_check_interval=config.train_config.val_check_interval,
+        check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
+        gradient_clip_val=config.train_config.gradient_clip_val,
+        precision=16, # we'll use mixed precision
+        plugins=custom_ckpt,
+        num_sanity_val_steps=0,
+        logger=wandb_logger,
+        callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
+    )
 
-    # trainer.fit(model_module)
+    trainer.fit(model_module)
 
 if __name__ == "__main__":
diff --git a/utils/donut_dataset_stream.py b/utils/donut_dataset_stream.py
index 8abc1f7..46ddaaa 100644
--- a/utils/donut_dataset_stream.py
+++ b/utils/donut_dataset_stream.py
@@ -47,7 +47,7 @@ class DonutDataset(Dataset):
         self.sort_json_key = sort_json_key
         self.added_tokens = added_tokens
 
-        self.dataset = dataset
+        self.dataset = dataset[self.split]
         self.dataset_length = len(self.dataset)
         self.gt_token_sequences = []

From c4fec90d135289715c379704dc4a3696a4c65ce0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?=
Date: Wed, 25 Jan 2023 21:45:22 +0100
Subject: [PATCH 3/5] fix len

---
 utils/donut_dataset_stream.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/utils/donut_dataset_stream.py b/utils/donut_dataset_stream.py
index 46ddaaa..27b7971 100644
--- a/utils/donut_dataset_stream.py
+++ b/utils/donut_dataset_stream.py
@@ -116,8 +116,8 @@ class DonutDataset(Dataset):
         self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
         self.added_tokens.extend(list_of_tokens)
 
-    # def __len__(self) -> int:
-    #     return self.dataset_length
+    def __len__(self) -> int:
+        return self.dataset_length
 
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """

From a436a38ddc647329a9a4d1b57095007fee0a5c4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?=
Date: Wed, 25 Jan 2023 21:55:57 +0100
Subject: [PATCH 4/5] config

---
 config-train.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/config-train.yaml b/config-train.yaml
index 67f01c4..19917ef 100644
--- a/config-train.yaml
+++ b/config-train.yaml
@@ -1,6 +1,6 @@
 dataset_path: "Zombely/wikisource-small"
-pretrained_model_path: "Zombely/plwiki-proto-fine-tuned-v3.2"
-start_model_path: "Zombely/plwiki-proto-fine-tuned-v3.2"
+pretrained_model_path: "Zombely/plwiki-fine-tuned-v4"
+start_model_path: "Zombely/plwiki-fine-tuned-v4"
 output_model_path: "Zombely/pl-donut"
 wandb_test_name: "wikisource-small"
 checkpoint_path: "./checkpoint"

From dda04a1d5cf9f04144e57976bd5a0300b2c9688f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?=
Date: Wed, 25 Jan 2023 22:18:39 +0100
Subject: [PATCH 5/5] 80% split

---
 train_stream.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_stream.py b/train_stream.py
index 95c8fd7..561828d 100644
--- a/train_stream.py
+++ b/train_stream.py
@@ -34,7 +34,7 @@ def main(config, hug_token):
 
     added_tokens = []
 
-    dataset = load_dataset(config.dataset_path)
+    dataset = load_dataset(config.dataset_path, split="train[:80%]")
     dataset = dataset.train_test_split(test_size=0.1)
 
     train_dataset = DonutDataset(
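After this series, the dataset handling in train_stream.py reduces to the two datasets calls sketched below. This is a minimal reference sketch, not part of the patches themselves: the dataset path is taken from config-train.yaml above, and the DonutDataset/Trainer wiring shown in the diffs is omitted.

    from datasets import load_dataset

    # PATCH 5/5: load only the first 80% of the source "train" split.
    dataset = load_dataset("Zombely/wikisource-small", split="train[:80%]")

    # PATCH 2/5: hold out 10% for validation. The result is a DatasetDict with
    # "train" and "test" keys, which is why DonutDataset is built with
    # split="train" / split="test" and indexes dataset[self.split] internally
    # (and why __len__ could be restored in PATCH 3/5: the dataset is no
    # longer streamed, so it has a defined length).
    dataset = dataset.train_test_split(test_size=0.1)

    print(len(dataset["train"]), len(dataset["test"]))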