diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index 4db3a1e..3c4651c 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -47,7 +47,7 @@ class DonutDataset(Dataset): self.sort_json_key = sort_json_key self.added_tokens = added_tokens - self.dataset = dataset + self.dataset = dataset.with_format("torch") # self.dataset_length = len(self.dataset.with_format("torch")) self.gt_token_sequences = []