diff --git a/code/src/training.py b/code/src/training.py index dc343f9..0fa7b41 100644 --- a/code/src/training.py +++ b/code/src/training.py @@ -7,7 +7,7 @@ import logging import numpy as np from pathlib import Path from safetensors.torch import load_file -from datasets import load_from_disk, DatasetDict +from datasets import load_from_disk, DatasetDict, load_dataset from transformers import ( AutoConfig, AutoTokenizer, @@ -58,7 +58,7 @@ def main(): if config['extra_embeddings']: wandb.save(current_dir / 'tree_starencoder.py') - dataset = load_from_disk(data_dir) + dataset = load_dataset("patrykbart/codeparrot-clean-no-comments-starencoder-small", split='train', num_proc=16, cache_dir=data_dir.parent) if config['num_samples'] > 0: dataset = dataset.select(range(config['num_samples'])) train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])