Version that works on CodeSearchNet

System Administrator 2024-12-05 11:02:23 +01:00
parent b2d65059da
commit 1d8ce0c72f
9 changed files with 377 additions and 5313 deletions

View File

@@ -39,4 +39,4 @@ distribution = true
[tool.pdm.scripts]
parse_dataset = {cmd = "src/parse_dataset.py"}
run_training = {cmd = "src/training.py"}
train = {cmd = "src/training.py"}

View File

@@ -1,4 +1,8 @@
{
"extra_embeddings": true,
"run_name": "tree",
"data_dir": "./data/CodeSearchNet-parsed/python/",
"output_dir": "./outputs/tree",
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 32,
@@ -8,6 +12,6 @@
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 1000,
"train_size": 0.95,
"fp16": true
"fp16": true,
"logging_steps": 100
}

File diff suppressed because it is too large

View File

@@ -1,22 +0,0 @@
import logging
from pathlib import Path
from typing import List
from huggingface_hub import list_repo_files, hf_hub_download
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def download_dataset(dataset_dir: Path) -> None:
if not dataset_dir.exists():
logger.info("Downloading the dataset...")
dataset_dir.mkdir(parents=True, exist_ok=True)
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
for file_name in files_to_download:
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
logger.info("Dataset downloaded successfully.")
if __name__ == '__main__':
dataset_dir = Path('data/the-stack-python')
download_dataset(dataset_dir)

View File

@@ -1,11 +1,9 @@
import logging
import multiprocessing
import tree_sitter_python as tspython
from tqdm import tqdm
from pathlib import Path
from pprint import pprint
from typing import Dict, Any, List
from datasets import load_dataset, Dataset
from datasets import load_dataset, load_from_disk
from tree_sitter import Language, Parser
from transformers import AutoTokenizer
@@ -82,7 +80,7 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
"""Process a batch of examples."""
contents = examples['content']
contents = examples['code']
processed_input_ids = []
processed_attention_mask = []
@@ -120,97 +118,70 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
'node_texts': processed_node_texts,
}
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
"""Save dataset to disk in chunks to manage memory usage."""
num_chunks = (len(dataset) + chunk_size - 1) // chunk_size
# Create directory for chunks
output_dir = Path(output_path).parent
chunks_dir = output_dir / 'data'
chunks_dir.mkdir(exist_ok=True)
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
for i in pbar:
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(dataset))
# Select chunk from dataset
chunk_dataset = dataset.select(range(start_idx, end_idx))
if len(chunk_dataset) > 0: # Only process if chunk has data
# Save chunk using datasets native method
chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
chunk_dataset.to_parquet(str(chunk_path))
def main():
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'the-stack-python'
output_dir = current_dir.parent / 'data' / 'processed-python'
input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python'
output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python'
output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['train']
# dataset = dataset.select(range(200_000)) # Limit dataset size for testing
original_dataset_size = len(dataset)
logger.info(f"Original dataset size: {original_dataset_size}")
# Initialize tokenizer and model from scratch
logger.info("Initializing tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
logger.info("Loaded CodeBERT tokenizer")
columns_to_remove = dataset.column_names
columns_to_remove.remove('content')
columns_to_remove.remove('size')
columns_to_remove.remove('avg_line_length')
columns_to_remove.remove('max_line_length')
columns_to_remove.remove('alphanum_fraction')
columns_to_remove.remove('max_stars_count')
logging.info(f"Columns to remove: {columns_to_remove}")
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info("Processing dataset...")
processed_dataset = dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
remove_columns=columns_to_remove,
desc="Processing dataset",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
reduced_dataset_size = len(processed_dataset)
logger.info(f"Processed dataset size: {reduced_dataset_size}")
if reduced_dataset_size == 0:
logger.error("No valid examples found in dataset!")
return
logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
save_dataset_in_chunks(
processed_dataset,
str(output_dir / 'processed_dataset'),
chunk_size=100_000
)
# Add derived statistics
stats = {
'original_dataset_size': original_dataset_size,
'reduced_dataset_size': reduced_dataset_size,
'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
}
# Log stats with pprint
logger.info("Processing completed! Stats:")
pprint(stats)
for dataset_name in ['valid', 'test', 'train']:
logger.info(f"Processing dataset: {dataset_name}")
input_file = input_dir / f'{dataset_name}.jsonl'
dataset = load_dataset('json', data_files=str(input_file))['train']
original_dataset_size = len(dataset)
logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples")
def tokenize_function(examples):
return tokenizer(
examples['code'],
padding='max_length',
truncation=True,
max_length=512,
return_special_tokens_mask=True
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
desc="Tokenizing",
num_proc=num_proc,
load_from_cache_file=False
)
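# process_batch runs tree-sitter over each snippet and adds the depth / sibling-index features the tree-aware model consumes.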
processed_dataset = tokenized_dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
desc="Processing",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
if len(processed_dataset) == 0:
logger.error("No valid examples found in dataset!")
return
output_file = output_dir / f'{dataset_name}'
processed_dataset.save_to_disk(str(output_file))
logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")
logger.info("All datasets processed")
if __name__ == "__main__":
main()

View File

@@ -1,50 +0,0 @@
import logging
import multiprocessing
from pathlib import Path
from datasets import load_dataset
from parse_dataset import save_dataset_in_chunks
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def main():
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'processed-python'
output_dir = current_dir.parent / 'data' / 'filtered-python'
output_dir.mkdir(parents=True, exist_ok=True)
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['data']
logging.info(f"Dataset:\n{dataset}")
logger.info("Filtering dataset by max_stars_count > 3...")
filtered_dataset = dataset.filter(
lambda batch: [example and example > 3 for example in batch['max_stars_count']],
num_proc=num_proc,
batched=True,
desc="Filtering dataset"
)
filtered_dataset_size = len(filtered_dataset)
logger.info(f"Filtered dataset size: {filtered_dataset_size}")
if filtered_dataset_size == 0:
logger.error("No examples found with max_stars_count > 3!")
return
logger.info(f"Saving {filtered_dataset_size} filtered examples...")
save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)
logger.info("Filtering and saving completed!")
if __name__ == "__main__":
main()

View File

@@ -1,306 +1,121 @@
import math
import torch
import torch.nn as nn
import wandb
import json
import logging
from pathlib import Path
from typing import Dict, Optional
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM
from datasets import load_from_disk, DatasetDict
from transformers import (
RobertaConfig,
RobertaForMaskedLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from training_utils import *
from tree_model import TreeCodeBERTForPreTraining
MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom'
DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python'
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
super().__init__(config)
self.base_custom = base_custom
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
if base_custom:
initial_weight = math.log(1/2)
self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
else:
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def load_config(config_path: Path) -> dict:
with open(config_path, 'r') as f:
return json.load(f)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None and not self.base_custom:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
if self.base_custom:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * seq_embeddings
)
else:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
if self.base_custom:
embedding_weights = {
"token": weights_cpu[0].item(),
"sequential": weights_cpu[1].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
else:
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
"embedding_weights": embedding_weights,
"regularization_metrics": reg_metrics
}
def main() -> None:
def main():
# Setup paths
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name=MODEL)
config = load_config(current_dir / 'config.json')
data_dir = Path(config['data_dir'])
output_dir = Path(config['output_dir'])
set_deterministic_mode(config['seed'])
setup_wandb(config, model_name=MODEL, script_path=__file__)
set_seed(config['seed'])
device = setup_device()
# Initialize W&B
wandb.init(project='easy_training', config=config, name=config['run_name'])
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
# Upload the training files to W&B
wandb.save(__file__)
wandb.save(Path(__file__).parent / 'config.json')
if config['extra_embeddings']:
wandb.save(current_dir / 'tree_model.py')
# Create dataloaders with tree dataset
logger.info("Creating dataloaders...")
processed_data_dir = current_dir.parent / 'data' / DATASET
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
base_training=True if MODEL == 'base' else False
)
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
# Initialize model config
logger.info("Initializing model...")
# Drop every column except the tensors the model consumes (input_ids, attention_mask, and the tree features when enabled)
columns_to_remove = dataset['train'].column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if config['extra_embeddings']:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove)
logger.info(f'Loaded dataset:\n{dataset}')
# Initialize model from scratch
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
# Update W&B
wandb.config.update({'model_config': model_config.__dict__})
# Initialize model
if MODEL == 'tree':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
elif MODEL == 'base':
model = RobertaForMaskedLM(model_config)
elif MODEL == 'base-custom':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
if config['extra_embeddings']:
model = TreeCodeBERTForPreTraining(model_config)
else:
raise ValueError(f"Invalid model type: {MODEL}")
model = model.to(device)
logging.info(f"Model initialized: {MODEL}")
# Model precision
if device.type == 'cuda' and config['fp16']:
model = model.half()
logging.info("Model set to half precision.")
for param in model.parameters():
logging.info(f"Parameter dtype: {param.dtype}")
break
model = RobertaForMaskedLM(model_config)
logger.info(f'Loaded model: {model.__class__.__name__}')
# Training setup
logger.info("Setting up optimizer and scheduler...")
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
# Train and evaluate model
logger.info("Starting training...")
train_and_evaluate(
model=model,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
optimizer=optimizer,
scheduler=scheduler,
config=config,
device=device,
output_dir=output_dir
# Setup training arguments
training_args = TrainingArguments(
output_dir=str(output_dir),
per_device_train_batch_size=config['batch_size'],
per_device_eval_batch_size=config['batch_size'],
learning_rate=config['learning_rate'],
weight_decay=config['weight_decay'],
num_train_epochs=config['epochs'],
warmup_steps=config['warmup_steps'],
max_grad_norm=config['max_grad_norm'],
logging_steps=config['logging_steps'],
eval_steps=config['eval_every'],
save_steps=config['eval_every'],
eval_strategy='steps',
save_strategy='steps',
load_best_model_at_end=True,
report_to='wandb',
run_name=config['run_name'],
seed=config['seed'],
fp16=config['fp16'],
dataloader_num_workers=8,
gradient_checkpointing=True,
metric_for_best_model='eval_loss',
greater_is_better=False,
)
# Create the trainer; the MLM collator dynamically masks tokens at mlm_probability each batch
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset['train'],
eval_dataset=dataset['valid'],
data_collator=DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['mlm_probability']
),
)
# Train
logger.info('Starting training...')
trainer.train()
logger.info("Training completed!")
final_output_dir = output_dir / 'final-model'
model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
logger.info(f"Model saved to {final_output_dir}")
# Evaluate
logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(eval_dataset=dataset['test'])
logger.info(eval_results)
# Save final model
logger.info('Saving final model...')
trainer.save_model(output_dir / 'final-model')
tokenizer.save_pretrained(output_dir / 'final-model')
logger.info('Training completed!')
if __name__ == "__main__":
if __name__ == '__main__':
main()

View File

@@ -1,355 +0,0 @@
import os
import json
import random
import datetime
import platform
import logging
import wandb
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any, Tuple, Optional
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
PreTrainedModel,
get_constant_schedule_with_warmup,
DataCollatorForLanguageModeling,
PreTrainedTokenizer,
RobertaForMaskedLM
)
from tqdm import tqdm
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def set_deterministic_mode(seed: int) -> None:
"""Enable deterministic mode for reproducibility."""
# Set Python random seed
random.seed(seed)
# Set NumPy random seed
np.random.seed(seed)
# Set PyTorch random seed
torch.manual_seed(seed)
# Set CUDA random seed
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Enable deterministic operations for PyTorch 2.x
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set environment variables for reproducibility
os.environ['PYTHONHASHSEED'] = f'{seed}'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
def get_device_info() -> Dict[str, Any]:
"""Get detailed information about the computing environment."""
info = {
'python_version': platform.python_version(),
'torch_version': torch.__version__,
'cuda_available': torch.cuda.is_available(),
'deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
'cudnn_deterministic': torch.backends.cudnn.deterministic,
'cudnn_benchmark': torch.backends.cudnn.benchmark,
}
if torch.cuda.is_available():
info.update({
'cuda_version': torch.version.cuda,
'gpu_name': torch.cuda.get_device_name(0),
'gpu_count': torch.cuda.device_count(),
})
return info
def set_seed(seed: int) -> None:
"""Set random seeds for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def setup_wandb(
config: Dict[str, Any],
model_name: str = 'base',
script_path: Path = None,
device_info: Optional[Dict[str, Any]] = None
) -> None:
"""Initialize W&B logging with reproducibility information."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
# Add reproducibility settings to config
full_config = {
**config,
'deterministic_mode': True,
'device_info': device_info or get_device_info(),
}
wandb.init(
project='new-codebert-tree-training',
name=f'{model_name}_{curr_time}',
config=full_config
)
# Upload the training script
wandb.save(script_path)
wandb.save(Path(script_path).parent / 'training_utils.py')
logger.info(f'Saving script {script_path} to W&B')
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
"""Create output directories for model artifacts."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
output_dir: Path = current_dir.parent.parent / 'outputs' / f'{model_name}_{curr_time}'
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir
def load_config(config_file: Path) -> Dict[str, Any]:
"""Load training configuration from JSON file."""
with open(config_file, 'r') as f:
return json.load(f)
def setup_device() -> torch.device:
"""Setup and configure training device."""
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
logger.info(f'Using device: {device}')
if device.type == 'cuda':
logger.info(f'Device name: {torch.cuda.get_device_name()}')
torch.set_float32_matmul_precision('high')
return device
def create_optimizer_and_scheduler(
model: PreTrainedModel,
config: Dict[str, Any],
) -> Tuple[AdamW, Any]:
"""Create optimizer and learning rate scheduler."""
optimizer = AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
scheduler = get_constant_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
)
return optimizer, scheduler
def evaluate(
model: RobertaForMaskedLM,
dataloader: DataLoader,
device: torch.device
) -> Tuple[float, float]:
"""Evaluate the model on the validation set.
Returns:
Tuple of (average loss, accuracy on masked tokens)
"""
model.eval()
total_loss = 0
total_correct = 0
total_predictions = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc='Evaluating'):
batch = {
k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
outputs = model(**batch)
# Get loss
total_loss += outputs['loss'].item()
# Calculate accuracy only on masked tokens
predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length]
labels = batch['labels'] # [batch_size, seq_length]
# Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
predict_mask = labels != -100 # -100 is the ignore index
# Calculate accuracy
correct = predictions[predict_mask] == labels[predict_mask]
total_correct += correct.sum().item()
total_predictions += predict_mask.sum().item()
avg_loss = total_loss / len(dataloader)
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
return avg_loss, accuracy
def train_and_evaluate(
model: RobertaForMaskedLM,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
device: torch.device,
output_dir: Path,
) -> None:
"""Train and evaluate the model."""
model.train()
best_valid_loss = float('inf')
for epoch in range(config['epochs']):
total_loss = 0
# Training loop
with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
for step, batch in enumerate(pbar):
# Move batch to device
batch = {
k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
# Forward pass
outputs = model(**batch)
loss = outputs['loss']
# Backward pass
loss.backward()
# Clip gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Update metrics
total_loss += loss.item()
# Update progress bar
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
# Log to wandb if configured
log_values = {
'train_loss': loss.item(),
'grad_norm': grad_norm.item(),
'learning_rate': scheduler.get_last_lr()[0]
}
if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'tree_weight': outputs['embedding_weights']['tree'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
elif 'embedding_weights' in outputs:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
if 'regularization_metrics' in outputs:
log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
wandb.log(log_values)
# Evaluate periodically
if step % config['eval_every'] == 0 and step > 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader, device)
# Log validation metrics
wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})
# Print validation metrics
print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
# Save best model
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')
model.train() # Resume training mode
def create_base_dataloaders(
processed_data_dir: Path,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
base_training=False,
) -> Tuple[DataLoader, DataLoader]:
"""Create dataloaders from pre-processed parquet data."""
# Load chunks using datasets library
chunks_dir = processed_data_dir / 'data'
dataset = load_dataset(
"parquet",
data_files=str(chunks_dir / "*.parquet"),
split="train"
)
contents = dataset['content'][:config['batch_size']]
with open("contents.txt", "w") as f:
for content in contents:
f.write(content)
f.write("\n\n\n")
f.write("#" * 80)
f.write("\n\n\n")
# Remove columns that are not needed
columns_to_remove = dataset.column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if not base_training:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove)
logging.info(f"Loaded dataset:\n{dataset}")
# Calculate split sizes
dataset_size = len(dataset)
train_size = int(config['train_size'] * dataset_size)
val_size = dataset_size - train_size
# Create splits
splits = dataset.train_test_split(
train_size=train_size,
test_size=val_size,
seed=config['seed']
)
train_dataset = splits['train']
valid_dataset = splits['test']
# Create data collator for MLM
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['mlm_probability']
)
train_dataloader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=False, # We'll let the dataset handle shuffling
collate_fn=data_collator,
num_workers=0,
drop_last=True,
)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=config['batch_size'],
shuffle=False,
collate_fn=data_collator,
num_workers=0,
drop_last=True,
)
return train_dataloader, valid_dataloader

code/src/tree_model.py (new file, 208 lines)
View File

@@ -0,0 +1,208 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
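# The projection fuses the concatenated depth and sibling embeddings (2*d_model) back down to d_model.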
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
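# Sibling positions reuse the max_depth-sized embedding table, so they are clamped to the same range as depths.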
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
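# i.e. pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model)), as in the original Transformer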
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
wandb.log(
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
step=kwargs.get("global_step", None)
)
wandb.log(
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
step=kwargs.get("global_step", None)
)
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
}