diff --git a/utils/donut_dataset.py b/utils/donut_dataset.py index ed9474a..4db3a1e 100644 --- a/utils/donut_dataset.py +++ b/utils/donut_dataset.py @@ -48,7 +48,7 @@ class DonutDataset(Dataset): self.added_tokens = added_tokens self.dataset = dataset - self.dataset_length = len(self.dataset.with_format("torch")) + # self.dataset_length = len(self.dataset.with_format("torch")) self.gt_token_sequences = [] for sample in self.dataset: @@ -116,8 +116,8 @@ class DonutDataset(Dataset): self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer)) self.added_tokens.extend(list_of_tokens) - def __len__(self) -> int: - return self.dataset_length + # def __len__(self) -> int: + # return self.dataset_length def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """