From 1c22eaabf95232686c9834555568462c8d266432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Koz=C5=82owski?= Date: Tue, 24 Jan 2023 18:23:25 +0100 Subject: [PATCH] back to normal --- train.py | 2 +- utils/donut_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 36e4953..8dc8d20 100644 --- a/train.py +++ b/train.py @@ -34,7 +34,7 @@ def main(config, hug_token): added_tokens = [] - dataset = load_dataset(config.dataset_path, split='train', streaming=True) + dataset = load_dataset(config.dataset_path, split='train') validation_dataset = dataset.take(100) train_dataset = dataset.skip(10000) diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index 3c4651c..8abc1f7 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -47,8 +47,8 @@ class DonutDataset(Dataset): self.sort_json_key = sort_json_key self.added_tokens = added_tokens - self.dataset = dataset.with_format("torch") - # self.dataset_length = len(self.dataset.with_format("torch")) + self.dataset = dataset + self.dataset_length = len(self.dataset) self.gt_token_sequences = [] for sample in self.dataset: