codeBERT training on codeparrot-clean

This commit is contained in:
Patryk Bartkowiak 2024-12-06 12:10:03 +00:00
parent 1d8ce0c72f
commit edadd3ee02
4 changed files with 90 additions and 72 deletions

View File

@ -10,15 +10,15 @@ dependencies = [
"torch==2.5.1",
"tqdm==4.66.5",
"tree-sitter==0.23.1",
"transformers[torch]>=4.46.3",
"transformers[torch]==4.46.3",
"datasets==3.0.1",
"huggingface-hub==0.26.0",
"matplotlib==3.9.2",
"scikit-learn==1.5.2",
"seaborn==0.13.2",
"tree-sitter-python==0.23.4",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
"ipykernel==6.29.5",
"ipywidgets==8.1.5",
]
requires-python = "==3.11.*"
readme = "README.md"

View File

@ -1,17 +1,20 @@
{
"extra_embeddings": true,
"run_name": "tree",
"data_dir": "./data/CodeSearchNet-parsed/python/",
"output_dir": "./outputs/tree",
"extra_embeddings": false,
"run_name": "original",
"data_dir": "./data/codeparrot-clean-parsed/",
"output_dir": "./outputs/original",
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 32,
"epochs": 5,
"eval_every": 1000,
"epochs": 1,
"eval_every": 5000,
"learning_rate": 5e-4,
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 1000,
"fp16": true,
"logging_steps": 100
"logging_steps": 100,
"valid_size": 0.05,
"test_size": 0.05,
"num_samples": -1
}

View File

@ -10,7 +10,6 @@ from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
@ -78,9 +77,9 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto
return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
def process_batch(examples, tokenizer, code_column) -> Dict[str, List[Any]]:
"""Process a batch of examples."""
contents = examples['code']
contents = examples[code_column]
processed_input_ids = []
processed_attention_mask = []
@ -119,9 +118,11 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
}
def main():
code_column = 'content'
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python'
output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python'
input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
output_dir.mkdir(parents=True, exist_ok=True)
# Initialize tokenizer and model from scratch
@ -132,56 +133,51 @@ def main():
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
for dataset_name in ['valid', 'test', 'train']:
logger.info(f"Processing dataset: {dataset_name}")
dataset = load_dataset("codeparrot/codeparrot-clean", cache_dir=str(input_dir), split='train')
original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_dir} with {original_dataset_size} examples")
input_file = input_dir / f'{dataset_name}.jsonl'
dataset = load_dataset('json', data_files=str(input_file))['train']
original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples")
def tokenize_function(examples):
return tokenizer(
examples['code'],
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
def tokenize_function(examples):
return tokenizer(
examples[code_column],
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)
processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer, 'code_column': code_column},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)
if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
output_file = output_dir / f'{dataset_name}'
processed_dataset.save_to_disk(str(output_file))
logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")
if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return
logger.info("All datasets processed")
processed_dataset.save_to_disk(str(output_dir))
logger.info(f"Saved processed dataset to {output_dir} with {len(processed_dataset)} examples")
logger.info("Dataset processed")
if __name__ == "__main__":
main()

View File

@ -14,7 +14,11 @@ from transformers import (
from tree_model import TreeCodeBERTForPreTraining
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def load_config(config_path: Path) -> dict:
@ -29,7 +33,7 @@ def main():
output_dir = Path(config['output_dir'])
# Initialize W&B
wandb.init(project='easy_training', config=config, name=config['run_name'])
wandb.init(project='codeparrot', config=config, name=config['run_name'])
# Upload the training files to W&B
wandb.save(__file__)
@ -37,11 +41,25 @@ def main():
if config['extra_embeddings']:
wandb.save(current_dir / 'tree_model.py')
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
if 'CodeSearchNet' in config['data_dir']:
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
else:
dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(test_size=config['valid_size'] / (config['test_size'] + config['valid_size']))
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
# Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
@ -107,8 +125,9 @@ def main():
# Evaluate
logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(eval_dataset=dataset['test'])
logger.info(eval_results)
eval_results = trainer.evaluate(dataset['test'])
wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')
# Save final model
logger.info('Saving final model...')