restore donut

This commit is contained in:
mkozlowskiAzimuthe 2023-01-25 17:44:25 +01:00
parent 1c22eaabf9
commit ab051e21b1
2 changed files with 4 additions and 8 deletions

View File

@@ -34,12 +34,8 @@ def main(config, hug_token):
added_tokens = []
dataset = load_dataset(config.dataset_path, split='train')
validation_dataset = dataset.take(100)
train_dataset = dataset.skip(10000)
train_dataset = DonutDataset(
train_dataset,
config.dataset_path,
processor=processor,
model=model,
max_length=config.max_length,
@@ -51,7 +47,7 @@ def main(config, hug_token):
)
val_dataset = DonutDataset(
validation_dataset,
config.dataset_path,
processor=processor,
model=model,
max_length=config.max_length,

View File

@@ -24,7 +24,7 @@ class DonutDataset(Dataset):
def __init__(
self,
dataset: Dataset,
dataset_name_or_path: str,
max_length: int,
processor: DonutProcessor,
model: VisionEncoderDecoderModel,
@@ -47,7 +47,7 @@ class DonutDataset(Dataset):
self.sort_json_key = sort_json_key
self.added_tokens = added_tokens
self.dataset = dataset
self.dataset = load_dataset(dataset_name_or_path, split=self.split)
self.dataset_length = len(self.dataset)
self.gt_token_sequences = []