diff --git a/train_stream.py b/train_stream.py index 95c8fd7..561828d 100644 --- a/train_stream.py +++ b/train_stream.py @@ -34,7 +34,7 @@ def main(config, hug_token): added_tokens = [] - dataset = load_dataset(config.dataset_path) + dataset = load_dataset(config.dataset_path, split="train[:80%]") dataset = dataset.train_test_split(test_size=0.1) train_dataset = DonutDataset(