torch format for dataset

This commit is contained in:
Michał Kozłowski 2023-01-24 18:18:50 +01:00
parent 5fca831ccd
commit 0466f8f60e

View File

@@ -48,7 +48,7 @@ class DonutDataset(Dataset):
self.added_tokens = added_tokens
self.dataset = dataset
-self.dataset_length = len(self.dataset)
+self.dataset_length = len(self.dataset.with_format("torch"))
self.gt_token_sequences = []
for sample in self.dataset: