diff --git a/train_stream.py b/train_stream.py index c53100e..3079ffc 100644 --- a/train_stream.py +++ b/train_stream.py @@ -34,6 +34,9 @@ def main(config, hug_token): added_tokens = [] + dataset = load_dataset(config.dataset_path, split="train[:80%]") + dataset = dataset.train_test_split(test_size=0.1) + train_dataset_process = DonutDatasetStream( processor=processor, model=model, @@ -49,7 +52,7 @@ def main(config, hug_token): processor=processor, model=model, max_length=config.max_length, - split="validation", + split="test", task_start_token="", prompt_end_token="", added_tokens=added_tokens,