This commit is contained in:
Michał Kozłowski 2023-01-24 18:20:49 +01:00
parent 0466f8f60e
commit b7373610dc

View File

@ -48,7 +48,7 @@ class DonutDataset(Dataset):
self.added_tokens = added_tokens self.added_tokens = added_tokens
self.dataset = dataset self.dataset = dataset
self.dataset_length = len(self.dataset.with_format("torch")) # self.dataset_length = len(self.dataset.with_format("torch"))
self.gt_token_sequences = [] self.gt_token_sequences = []
for sample in self.dataset: for sample in self.dataset:
@ -116,8 +116,8 @@ class DonutDataset(Dataset):
self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer)) self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
self.added_tokens.extend(list_of_tokens) self.added_tokens.extend(list_of_tokens)
def __len__(self) -> int: # def __len__(self) -> int:
return self.dataset_length # return self.dataset_length
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """