This commit is contained in:
mkozlowskiAzimuthe 2023-03-14 15:51:05 +01:00
commit 2f1176b3c0

View File

@@ -34,6 +34,9 @@ def main(config, hug_token):
added_tokens = [] added_tokens = []
dataset = load_dataset(config.dataset_path, split="train[:80%]")
dataset = dataset.train_test_split(test_size=0.1)
train_dataset_process = DonutDatasetStream( train_dataset_process = DonutDatasetStream(
processor=processor, processor=processor,
model=model, model=model,
@@ -49,7 +52,7 @@ def main(config, hug_token):
processor=processor, processor=processor,
model=model, model=model,
max_length=config.max_length, max_length=config.max_length,
- split="validation",
+ split="test",
task_start_token="<s_cord-v2>", task_start_token="<s_cord-v2>",
prompt_end_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
added_tokens=added_tokens, added_tokens=added_tokens,