diff --git a/train.py b/train.py index 8dc8d20..3dc2d91 100644 --- a/train.py +++ b/train.py @@ -34,12 +34,8 @@ def main(config, hug_token): added_tokens = [] - dataset = load_dataset(config.dataset_path, split='train') - validation_dataset = dataset.take(100) - train_dataset = dataset.skip(10000) - train_dataset = DonutDataset( - train_dataset, + config.dataset_path, processor=processor, model=model, max_length=config.max_length, @@ -51,7 +47,7 @@ def main(config, hug_token): ) val_dataset = DonutDataset( - validation_dataset, + config.dataset_path, processor=processor, model=model, max_length=config.max_length, diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index 8abc1f7..200070c 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -24,7 +24,7 @@ class DonutDataset(Dataset): def __init__( self, - dataset: Dataset, + dataset_name_or_path: str, max_length: int, processor: DonutProcessor, model: VisionEncoderDecoderModel, @@ -47,7 +47,7 @@ class DonutDataset(Dataset): self.sort_json_key = sort_json_key self.added_tokens = added_tokens - self.dataset = dataset + self.dataset = load_dataset(dataset_name_or_path, split=self.split) self.dataset_length = len(self.dataset) self.gt_token_sequences = []