diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index 6319a0e..ed9474a 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -48,7 +48,7 @@ class DonutDataset(Dataset): self.added_tokens = added_tokens self.dataset = dataset - self.dataset_length = len(self.dataset) + self.dataset_length = len(self.dataset.with_format("torch")) self.gt_token_sequences = [] for sample in self.dataset: