codeBERT training on codeparrot-clean
parent 1d8ce0c72f
commit edadd3ee02
@@ -10,15 +10,15 @@ dependencies = [
    "torch==2.5.1",
    "tqdm==4.66.5",
    "tree-sitter==0.23.1",
    "transformers[torch]>=4.46.3",
    "transformers[torch]==4.46.3",
    "datasets==3.0.1",
    "huggingface-hub==0.26.0",
    "matplotlib==3.9.2",
    "scikit-learn==1.5.2",
    "seaborn==0.13.2",
    "tree-sitter-python==0.23.4",
    "ipykernel>=6.29.5",
    "ipywidgets>=8.1.5",
    "ipykernel==6.29.5",
    "ipywidgets==8.1.5",
]
requires-python = "==3.11.*"
readme = "README.md"
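The pyproject change above pins transformers, ipykernel, and ipywidgets to exact versions instead of lower bounds, so every environment resolves the same dependency set. A minimal sketch for checking that an installed environment matches a few of these pins (the package list is copied from the block above; the check itself is not part of the commit):

    # Illustrative only: compare installed package versions against the pins above.
    from importlib.metadata import version, PackageNotFoundError

    PINS = {"torch": "2.5.1", "transformers": "4.46.3", "datasets": "3.0.1"}

    for name, expected in PINS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (expected {expected})")
            continue
        flag = "OK" if installed == expected else f"mismatch, expected {expected}"
        print(f"{name}: {installed} ({flag})")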
@@ -1,17 +1,20 @@
{
    "extra_embeddings": true,
    "run_name": "tree",
    "data_dir": "./data/CodeSearchNet-parsed/python/",
    "output_dir": "./outputs/tree",
    "extra_embeddings": false,
    "run_name": "original",
    "data_dir": "./data/codeparrot-clean-parsed/",
    "output_dir": "./outputs/original",
    "seed": 420,
    "mlm_probability": 0.15,
    "batch_size": 32,
    "epochs": 5,
    "eval_every": 1000,
    "epochs": 1,
    "eval_every": 5000,
    "learning_rate": 5e-4,
    "weight_decay": 0.1,
    "max_grad_norm": 1.0,
    "warmup_steps": 1000,
    "fp16": true,
    "logging_steps": 100
    "logging_steps": 100,
    "valid_size": 0.05,
    "test_size": 0.05,
    "num_samples": -1
}
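The training script reads this file through a `load_config(config_path: Path) -> dict` helper whose signature appears in a later hunk. A minimal sketch of such a loader, assuming the file is plain JSON as shown above (the body and the `config.json` file name are illustrative, not the committed code):

    import json
    from pathlib import Path

    def load_config(config_path: Path) -> dict:
        """Read the run configuration (run_name, data_dir, epochs, ...) as a plain dict."""
        with config_path.open() as f:
            return json.load(f)

    config = load_config(Path("config.json"))  # hypothetical file name
    print(config["run_name"], config["epochs"], config["num_samples"])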
@@ -10,7 +10,6 @@ from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
@@ -78,9 +77,9 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto

return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts

def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
def process_batch(examples, tokenizer, code_column) -> Dict[str, List[Any]]:
"""Process a batch of examples."""
contents = examples['code']
contents = examples[code_column]

processed_input_ids = []
processed_attention_mask = []
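The new `code_column` parameter reaches `process_batch` through `fn_kwargs` in the `dataset.map` call shown further down. A self-contained toy example of that mechanism (not the committed `process_batch`), using a dummy `content` column:

    # Toy demonstration of passing extra arguments to a batched map function.
    from datasets import Dataset

    def process_batch(examples, code_column):
        # Count characters per snippet in whichever column holds the code.
        return {"n_chars": [len(s) for s in examples[code_column]]}

    ds = Dataset.from_dict({"content": ["print('hi')", "x = 1 + 2"]})
    ds = ds.map(process_batch, fn_kwargs={"code_column": "content"}, batched=True)
    print(ds["n_chars"])  # [11, 9]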
@@ -119,9 +118,11 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
}

def main():
code_column = 'content'

current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python'
output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python'
input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
output_dir.mkdir(parents=True, exist_ok=True)

# Initialize tokenizer and model from scratch
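`code_column = 'content'` is hard-coded because codeparrot-clean stores source files under a `content` column, whereas the CodeSearchNet JSONL used previously exposed a `code` column. A quick, illustrative way to confirm the column layout without downloading the whole dump (not part of the commit):

    # Illustrative check of the codeparrot-clean schema via streaming.
    from datasets import load_dataset

    ds = load_dataset("codeparrot/codeparrot-clean", split="train", streaming=True)
    first = next(iter(ds))
    print(sorted(first.keys()))    # expected to include 'content'
    print(first["content"][:200])  # start of one source file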
@@ -132,56 +133,51 @@ def main():
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")

for dataset_name in ['valid', 'test', 'train']:
logger.info(f"Processing dataset: {dataset_name}")

input_file = input_dir / f'{dataset_name}.jsonl'
dataset = load_dataset('json', data_files=str(input_file))['train']
original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples")

def tokenize_function(examples):
return tokenizer(
examples['code'],
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)

tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
dataset = load_dataset("codeparrot/codeparrot-clean", cache_dir=str(input_dir), split='train')
original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_dir} with {original_dataset_size} examples")

def tokenize_function(examples):
return tokenizer(
examples[code_column],
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)

tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
)

processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer, 'code_column': code_column},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)

processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)

if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return

processed_dataset.save_to_disk(str(output_dir))
logger.info(f"Saved processed dataset to {output_dir} with {len(processed_dataset)} examples")

processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)

processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)

if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return

output_file = output_dir / f'{dataset_name}'
processed_dataset.save_to_disk(str(output_file))
logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")

logger.info("All datasets processed")
logger.info("Dataset processed")

if __name__ == "__main__":
main()
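The preprocessing script now writes a single Arrow dataset with `save_to_disk` instead of one directory per CodeSearchNet split; the training script reopens it with `load_from_disk` in the hunks below. A minimal sketch of inspecting the processed output, assuming the `./data/codeparrot-clean-parsed/` path from the config:

    # Sketch: reopen the dataset written by save_to_disk above and inspect it.
    from datasets import load_from_disk

    processed = load_from_disk("./data/codeparrot-clean-parsed/")
    print(len(processed))          # examples that survived the empty-token filter
    print(processed.column_names)  # e.g. input_ids, attention_mask, tokens, ...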
@@ -14,7 +14,11 @@ from transformers import (

from tree_model import TreeCodeBERTForPreTraining

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def load_config(config_path: Path) -> dict:
@@ -29,7 +33,7 @@ def main():
output_dir = Path(config['output_dir'])

# Initialize W&B
wandb.init(project='easy_training', config=config, name=config['run_name'])
wandb.init(project='codeparrot', config=config, name=config['run_name'])

# Upload the training files to W&B
wandb.save(__file__)
@@ -37,11 +41,25 @@ def main():
if config['extra_embeddings']:
wandb.save(current_dir / 'tree_model.py')

dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
if 'CodeSearchNet' in config['data_dir']:
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
else:

dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(test_size=config['valid_size'] / (config['test_size'] + config['valid_size']))
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})

# Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
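With `valid_size` and `test_size` both 0.05, the first `train_test_split` holds out 10% of the data and the second splits that held-out portion in half, giving roughly 90/5/5 train/test/valid. A small self-contained sketch of the same arithmetic on a toy dataset (numbers are illustrative, not the committed code):

    # Toy illustration of the two-stage split used above.
    from datasets import Dataset, DatasetDict

    dataset = Dataset.from_dict({"input_ids": [[i] for i in range(1000)]})
    test_size, valid_size = 0.05, 0.05

    train_testvalid = dataset.train_test_split(test_size=test_size + valid_size)
    test_valid = train_testvalid["test"].train_test_split(
        test_size=valid_size / (test_size + valid_size)
    )
    splits = DatasetDict({
        "train": train_testvalid["train"],  # ~900 examples
        "test": test_valid["test"],         # ~50 examples
        "valid": test_valid["train"],       # ~50 examples
    })
    print({k: len(v) for k, v in splits.items()})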
@@ -107,8 +125,9 @@ def main():

# Evaluate
logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(eval_dataset=dataset['test'])
logger.info(eval_results)
eval_results = trainer.evaluate(dataset['test'])
wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')

# Save final model
logger.info('Saving final model...')