diff --git a/train.py b/train.py index f37d48a..36e4953 100644 --- a/train.py +++ b/train.py @@ -35,8 +35,8 @@ def main(config, hug_token): added_tokens = [] dataset = load_dataset(config.dataset_path, split='train', streaming=True) - train_dataset = dataset.skip(100) validation_dataset = dataset.take(100) + train_dataset = dataset.skip(10000) train_dataset = DonutDataset( train_dataset,