prepared for run by prof Filip Gralinski (done)
This commit is contained in:
parent
76a89dc236
commit
3b93a7cc8a
@ -7,7 +7,7 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from datasets import load_from_disk, DatasetDict
|
from datasets import load_from_disk, DatasetDict, load_dataset
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@ -58,7 +58,7 @@ def main():
|
|||||||
if config['extra_embeddings']:
|
if config['extra_embeddings']:
|
||||||
wandb.save(current_dir / 'tree_starencoder.py')
|
wandb.save(current_dir / 'tree_starencoder.py')
|
||||||
|
|
||||||
dataset = load_from_disk(data_dir)
|
dataset = load_dataset("patrykbart/codeparrot-clean-no-comments-starencoder-small", split='train', num_proc=16, cache_dir=data_dir.parent)
|
||||||
if config['num_samples'] > 0:
|
if config['num_samples'] > 0:
|
||||||
dataset = dataset.select(range(config['num_samples']))
|
dataset = dataset.select(range(config['num_samples']))
|
||||||
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
||||||
|
Loading…
Reference in New Issue
Block a user