codeBERT training on codeparrot-clean

parent 1d8ce0c72f
commit edadd3ee02
@@ -10,15 +10,15 @@ dependencies = [
     "torch==2.5.1",
     "tqdm==4.66.5",
     "tree-sitter==0.23.1",
-    "transformers[torch]>=4.46.3",
+    "transformers[torch]==4.46.3",
     "datasets==3.0.1",
     "huggingface-hub==0.26.0",
     "matplotlib==3.9.2",
     "scikit-learn==1.5.2",
     "seaborn==0.13.2",
     "tree-sitter-python==0.23.4",
-    "ipykernel>=6.29.5",
-    "ipywidgets>=8.1.5",
+    "ipykernel==6.29.5",
+    "ipywidgets==8.1.5",
 ]
 requires-python = "==3.11.*"
 readme = "README.md"
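The loose version ranges on transformers, ipykernel, and ipywidgets are now exact pins. A minimal sketch of a runtime sanity check for the pinned environment; the EXPECTED mapping and the check itself are illustrative, not part of the commit:

# Hypothetical check that the environment matches the pins above; not part of the commit.
from importlib.metadata import version

EXPECTED = {"transformers": "4.46.3", "datasets": "3.0.1", "torch": "2.5.1"}

for package, expected in EXPECTED.items():
    installed = version(package).split("+")[0]  # drop local build tags such as "+cu121"
    assert installed == expected, f"{package}: expected {expected}, found {installed}"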
@@ -1,17 +1,20 @@
 {
-    "extra_embeddings": true,
-    "run_name": "tree",
-    "data_dir": "./data/CodeSearchNet-parsed/python/",
-    "output_dir": "./outputs/tree",
+    "extra_embeddings": false,
+    "run_name": "original",
+    "data_dir": "./data/codeparrot-clean-parsed/",
+    "output_dir": "./outputs/original",
     "seed": 420,
     "mlm_probability": 0.15,
     "batch_size": 32,
-    "epochs": 5,
-    "eval_every": 1000,
+    "epochs": 1,
+    "eval_every": 5000,
     "learning_rate": 5e-4,
     "weight_decay": 0.1,
     "max_grad_norm": 1.0,
     "warmup_steps": 1000,
     "fp16": true,
-    "logging_steps": 100
+    "logging_steps": 100,
+    "valid_size": 0.05,
+    "test_size": 0.05,
+    "num_samples": -1
 }
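The new valid_size, test_size, and num_samples keys feed the splitting logic added to the training script further down: test_size + valid_size is held out first, then that holdout is divided so that valid_size / (test_size + valid_size) of it becomes the validation portion. A small sketch of the arithmetic with the values above (illustrative only):

# Illustrative arithmetic for the new split keys; mirrors the training-script change below.
config = {"valid_size": 0.05, "test_size": 0.05, "num_samples": -1}

holdout = config["test_size"] + config["valid_size"]   # 0.10 leaves the training split
valid_share = config["valid_size"] / holdout            # 0.50 of the holdout becomes "valid"
print(f"train={1 - holdout:.2f}  test={holdout * (1 - valid_share):.2f}  valid={holdout * valid_share:.2f}")
# -> train=0.90  test=0.05  valid=0.05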
@@ -10,7 +10,6 @@ from transformers import AutoTokenizer
 import warnings
 warnings.filterwarnings("ignore", category=SyntaxWarning)
 
-# Configure logging
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s',
@@ -78,9 +77,9 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto
 
     return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
 
-def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
+def process_batch(examples, tokenizer, code_column) -> Dict[str, List[Any]]:
     """Process a batch of examples."""
-    contents = examples['code']
+    contents = examples[code_column]
 
     processed_input_ids = []
     processed_attention_mask = []
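process_batch no longer hard-codes the 'code' column; the caller passes the column name through datasets.map via fn_kwargs (see the main() hunk below, where code_column is 'content' for codeparrot-clean). A minimal sketch of that call pattern; the toy dataset and the microsoft/codebert-base tokenizer are stand-ins, not part of the commit:

# Sketch of the new call pattern only; the toy data and the tokenizer choice are assumptions.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
toy = Dataset.from_dict({"content": ["def add(a, b):\n    return a + b"]})

processed = toy.map(
    process_batch,  # defined above in this file
    fn_kwargs={"tokenizer": tokenizer, "code_column": "content"},
    batched=True,
)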
@@ -119,9 +118,11 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
     }
 
 def main():
+    code_column = 'content'
+
     current_dir = Path(__file__).parent
-    input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python'
-    output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python'
+    input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
+    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Initialize tokenizer and model from scratch
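code_column is set to 'content' because codeparrot-clean stores the source text in a content field, whereas the CodeSearchNet files used code. A quick hedged way to confirm the column layout without materializing the full dump, using datasets streaming mode; this check is illustrative, not part of the commit:

# Illustrative peek at codeparrot-clean's schema via streaming; not part of the commit.
from datasets import load_dataset

stream = load_dataset("codeparrot/codeparrot-clean", split="train", streaming=True)
first = next(iter(stream))
print(sorted(first))  # expect a 'content' key holding the raw source code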
@@ -132,56 +133,51 @@ def main():
     num_proc = min(multiprocessing.cpu_count() - 1, 32)
     logger.info(f"Using {num_proc} processes for dataset processing")
 
-    for dataset_name in ['valid', 'test', 'train']:
-        logger.info(f"Processing dataset: {dataset_name}")
-
-        input_file = input_dir / f'{dataset_name}.jsonl'
-        dataset = load_dataset('json', data_files=str(input_file))['train']
-        original_dataset_size = len(dataset)
-        logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples")
-
-        def tokenize_function(examples):
-            return tokenizer(
-                examples['code'],
-                padding='max_length',
-                truncation=True,
-                max_length=512,
-                return_special_tokens_mask=True
-            )
-
-        tokenized_dataset = dataset.map(
-            tokenize_function,
-            batched=True,
-            desc="Tokenizing",
-            num_proc=num_proc,
-            load_from_cache_file=False
-        )
-
-        processed_dataset = tokenized_dataset.map(
-            process_batch,
-            fn_kwargs={'tokenizer': tokenizer},
-            batched=True,
-            desc="Processing",
-            num_proc=num_proc,
-            load_from_cache_file=False
-        )
-
-        processed_dataset = processed_dataset.filter(
-            lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
-            batched=True,
-            num_proc=num_proc,
-            desc="Filtering invalid examples"
-        )
-
-        if len(processed_dataset) == 0:
-            logger.error("No valid examples found in dataset!")
-            return
-
-        output_file = output_dir / f'{dataset_name}'
-        processed_dataset.save_to_disk(str(output_file))
-        logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")
-
-    logger.info("All datasets processed")
+    dataset = load_dataset("codeparrot/codeparrot-clean", cache_dir=str(input_dir), split='train')
+    original_dataset_size = len(dataset)
+    logger.info(f"Loaded dataset from {input_dir} with {original_dataset_size} examples")
+
+    def tokenize_function(examples):
+        return tokenizer(
+            examples[code_column],
+            padding='max_length',
+            truncation=True,
+            max_length=512,
+            return_special_tokens_mask=True
+        )
+
+    tokenized_dataset = dataset.map(
+        tokenize_function,
+        batched=True,
+        desc="Tokenizing",
+        num_proc=num_proc,
+        load_from_cache_file=False
+    )
+
+    processed_dataset = tokenized_dataset.map(
+        process_batch,
+        fn_kwargs={'tokenizer': tokenizer, 'code_column': code_column},
+        batched=True,
+        desc="Processing",
+        num_proc=num_proc,
+        load_from_cache_file=False
+    )
+
+    processed_dataset = processed_dataset.filter(
+        lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
+        batched=True,
+        num_proc=num_proc,
+        desc="Filtering invalid examples"
+    )
+
+    if len(processed_dataset) == 0:
+        logger.error("No valid examples found in dataset!")
+        return
+
+    processed_dataset.save_to_disk(str(output_dir))
+    logger.info(f"Saved processed dataset to {output_dir} with {len(processed_dataset)} examples")
+
+    logger.info("Dataset processed")
 
 if __name__ == "__main__":
     main()
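The script now writes a single Arrow dataset to the output directory instead of one per CodeSearchNet split; downstream code reloads it with load_from_disk, as the training script's else branch below does. A minimal sketch, assuming the ./data/codeparrot-clean-parsed/ path from the config above:

# Sketch: reload the saved dataset for inspection (path taken from the config above).
from datasets import load_from_disk

processed = load_from_disk("./data/codeparrot-clean-parsed/")
print(len(processed), processed.column_names)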
@@ -14,7 +14,11 @@ from transformers import (
 
 from tree_model import TreeCodeBERTForPreTraining
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
 logger = logging.getLogger(__name__)
 
 def load_config(config_path: Path) -> dict:
@@ -29,7 +33,7 @@ def main():
     output_dir = Path(config['output_dir'])
 
     # Initialize W&B
-    wandb.init(project='easy_training', config=config, name=config['run_name'])
+    wandb.init(project='codeparrot', config=config, name=config['run_name'])
 
     # Upload the training files to W&B
     wandb.save(__file__)
@@ -37,11 +41,25 @@ def main():
     if config['extra_embeddings']:
         wandb.save(current_dir / 'tree_model.py')
 
-    dataset = DatasetDict({
-        'train': load_from_disk(data_dir / 'train'),
-        'valid': load_from_disk(data_dir / 'valid'),
-        'test': load_from_disk(data_dir / 'test')
-    })
+    if 'CodeSearchNet' in config['data_dir']:
+        dataset = DatasetDict({
+            'train': load_from_disk(data_dir / 'train'),
+            'valid': load_from_disk(data_dir / 'valid'),
+            'test': load_from_disk(data_dir / 'test')
+        })
+    else:
+        dataset = load_from_disk(data_dir)
+        if config['num_samples'] > 0:
+            dataset = dataset.select(range(config['num_samples']))
+        train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
+        test_valid = train_testvalid['test'].train_test_split(test_size=config['valid_size'] / (config['test_size'] + config['valid_size']))
+        dataset = DatasetDict({
+            'train': train_testvalid['train'],
+            'test': test_valid['test'],
+            'valid': test_valid['train'],
+        })
 
     # Continue with the rest of processing
     columns_to_remove = dataset['train'].column_names
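In the new else branch the split happens in two stages: test_size + valid_size is held out of the training data, then the holdout is split again so that valid_size / (test_size + valid_size) of it lands in one part. One hedged observation: because test_valid['test'] (which receives that fraction) is mapped to 'test' and test_valid['train'] to 'valid', the two named splits would swap sizes if test_size and valid_size ever differed; with the 0.05 / 0.05 values in the config above the result is identical either way. A toy sketch of the same two calls on a dummy dataset (the dummy data is illustrative only):

# Toy check of the two-stage split above; the dummy dataset is illustrative only.
from datasets import Dataset, DatasetDict

dummy = Dataset.from_dict({"content": [f"x = {i}" for i in range(1000)]})
test_size, valid_size = 0.05, 0.05

train_testvalid = dummy.train_test_split(test_size=test_size + valid_size)
test_valid = train_testvalid["test"].train_test_split(test_size=valid_size / (test_size + valid_size))
splits = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "valid": test_valid["train"],
})
print({name: len(part) for name, part in splits.items()})  # roughly 900 / 50 / 50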
@@ -107,8 +125,9 @@ def main():
 
     # Evaluate
     logger.info('Evaluating on test set...')
-    eval_results = trainer.evaluate(eval_dataset=dataset['test'])
-    logger.info(eval_results)
+    eval_results = trainer.evaluate(dataset['test'])
+    wandb.log({'test_loss': eval_results['eval_loss']})
+    logger.info(f'Test loss: {eval_results["eval_loss"]}')
 
     # Save final model
     logger.info('Saving final model...')
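Trainer.evaluate returns a plain metrics dict whose entries are prefixed with eval_, which is why the loss is read as eval_results['eval_loss'] before being re-logged to W&B under test_loss. A short sketch of inspecting that dict; trainer and dataset stand for the objects built earlier in the script:

# Sketch: the metrics dict returned by Trainer.evaluate; names are placeholders.
eval_results = trainer.evaluate(dataset["test"])
print(sorted(eval_results))                      # e.g. ['eval_loss', 'eval_runtime', ...]
wandb.log({"test_loss": eval_results["eval_loss"]})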