codeBERT training on codeparrot-clean

This commit is contained in:
Patryk Bartkowiak 2024-12-06 12:10:03 +00:00
parent 1d8ce0c72f
commit edadd3ee02
4 changed files with 90 additions and 72 deletions

View File

@@ -10,15 +10,15 @@ dependencies = [
"torch==2.5.1", "torch==2.5.1",
"tqdm==4.66.5", "tqdm==4.66.5",
"tree-sitter==0.23.1", "tree-sitter==0.23.1",
"transformers[torch]>=4.46.3", "transformers[torch]==4.46.3",
"datasets==3.0.1", "datasets==3.0.1",
"huggingface-hub==0.26.0", "huggingface-hub==0.26.0",
"matplotlib==3.9.2", "matplotlib==3.9.2",
"scikit-learn==1.5.2", "scikit-learn==1.5.2",
"seaborn==0.13.2", "seaborn==0.13.2",
"tree-sitter-python==0.23.4", "tree-sitter-python==0.23.4",
"ipykernel>=6.29.5", "ipykernel==6.29.5",
"ipywidgets>=8.1.5", "ipywidgets==8.1.5",
] ]
requires-python = "==3.11.*" requires-python = "==3.11.*"
readme = "README.md" readme = "README.md"

View File

@@ -1,17 +1,20 @@
{ {
"extra_embeddings": true, "extra_embeddings": false,
"run_name": "tree", "run_name": "original",
"data_dir": "./data/CodeSearchNet-parsed/python/", "data_dir": "./data/codeparrot-clean-parsed/",
"output_dir": "./outputs/tree", "output_dir": "./outputs/original",
"seed": 420, "seed": 420,
"mlm_probability": 0.15, "mlm_probability": 0.15,
"batch_size": 32, "batch_size": 32,
"epochs": 5, "epochs": 1,
"eval_every": 1000, "eval_every": 5000,
"learning_rate": 5e-4, "learning_rate": 5e-4,
"weight_decay": 0.1, "weight_decay": 0.1,
"max_grad_norm": 1.0, "max_grad_norm": 1.0,
"warmup_steps": 1000, "warmup_steps": 1000,
"fp16": true, "fp16": true,
"logging_steps": 100 "logging_steps": 100,
"valid_size": 0.05,
"test_size": 0.05,
"num_samples": -1
} }

View File

@@ -10,7 +10,6 @@ from transformers import AutoTokenizer
import warnings import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning) warnings.filterwarnings("ignore", category=SyntaxWarning)
# Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
@@ -78,9 +77,9 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto
return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]: def process_batch(examples, tokenizer, code_column) -> Dict[str, List[Any]]:
"""Process a batch of examples.""" """Process a batch of examples."""
contents = examples['code'] contents = examples[code_column]
processed_input_ids = [] processed_input_ids = []
processed_attention_mask = [] processed_attention_mask = []
@@ -119,9 +118,11 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
} }
def main(): def main():
code_column = 'content'
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python' input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python' output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Initialize tokenizer and model from scratch # Initialize tokenizer and model from scratch
@@ -132,56 +133,51 @@ def main():
num_proc = min(multiprocessing.cpu_count() - 1, 32) num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing") logger.info(f"Using {num_proc} processes for dataset processing")
for dataset_name in ['valid', 'test', 'train']: dataset = load_dataset("codeparrot/codeparrot-clean", cache_dir=str(input_dir), split='train')
logger.info(f"Processing dataset: {dataset_name}") original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_dir} with {original_dataset_size} examples")
input_file = input_dir / f'{dataset_name}.jsonl'
dataset = load_dataset('json', data_files=str(input_file))['train'] def tokenize_function(examples):
original_dataset_size = len(dataset) return tokenizer(
logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples") examples[code_column],
padding='max_length',
def tokenize_function(examples): truncation=True,
return tokenizer( max_length=512,
examples['code'], return_special_tokens_mask=True
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
) )
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer, 'code_column': code_column},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return
processed_dataset.save_to_disk(str(output_dir))
logger.info(f"Saved processed dataset to {output_dir} with {len(processed_dataset)} examples")
processed_dataset = tokenized_dataset.map( logger.info("Dataset processed")
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return
output_file = output_dir / f'{dataset_name}'
processed_dataset.save_to_disk(str(output_file))
logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")
logger.info("All datasets processed")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -14,7 +14,11 @@ from transformers import (
from tree_model import TreeCodeBERTForPreTraining from tree_model import TreeCodeBERTForPreTraining
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_config(config_path: Path) -> dict: def load_config(config_path: Path) -> dict:
@@ -29,7 +33,7 @@ def main():
output_dir = Path(config['output_dir']) output_dir = Path(config['output_dir'])
# Initialize W&B # Initialize W&B
wandb.init(project='easy_training', config=config, name=config['run_name']) wandb.init(project='codeparrot', config=config, name=config['run_name'])
# Upload the training files to W&B # Upload the training files to W&B
wandb.save(__file__) wandb.save(__file__)
@@ -37,11 +41,25 @@ def main():
if config['extra_embeddings']: if config['extra_embeddings']:
wandb.save(current_dir / 'tree_model.py') wandb.save(current_dir / 'tree_model.py')
dataset = DatasetDict({ if 'CodeSearchNet' in config['data_dir']:
'train': load_from_disk(data_dir / 'train'), dataset = DatasetDict({
'valid': load_from_disk(data_dir / 'valid'), 'train': load_from_disk(data_dir / 'train'),
'test': load_from_disk(data_dir / 'test') 'valid': load_from_disk(data_dir / 'valid'),
}) 'test': load_from_disk(data_dir / 'test')
})
else:
dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(test_size=config['valid_size'] / (config['test_size'] + config['valid_size']))
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
# Continue with the rest of processing # Continue with the rest of processing
columns_to_remove = dataset['train'].column_names columns_to_remove = dataset['train'].column_names
@@ -107,8 +125,9 @@ def main():
# Evaluate # Evaluate
logger.info('Evaluating on test set...') logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(eval_dataset=dataset['test']) eval_results = trainer.evaluate(dataset['test'])
logger.info(eval_results) wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')
# Save final model # Save final model
logger.info('Saving final model...') logger.info('Saving final model...')