added tree enhanced embeddings into training

This commit is contained in:
Patryk Bartkowiak 2024-11-05 14:53:32 +00:00
parent d4ccfaac44
commit 9299c56bb1
4 changed files with 523 additions and 1415 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,12 @@
{
"seed": 42,
"mlm_probability": 0.15,
"batch": 32,
"batch": 16,
"epochs": 1,
"eval_every": 10000,
"eval_every": 20000,
"learning_rate": 5e-4,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"warmup_steps": 10000
"warmup_steps": 20000,
"pause_on_instability": false
}

View File

@ -13,7 +13,7 @@ import torch
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, disable_caching, DatasetDict
from datasets import load_dataset, DatasetDict
from huggingface_hub import list_repo_files, hf_hub_download
from transformers import (
RobertaForMaskedLM,
@ -61,14 +61,14 @@ def set_seed(seed: int) -> None:
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def setup_wandb(config: Dict[str, Any]) -> None:
def setup_wandb(config: Dict[str, Any], exec_file: str = 'train_codebert_mlm.py') -> None:
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
wandb.init(project='codebert-training', name=curr_time, config=config)
wandb.init(project='codebert-training', name=f'tree_{curr_time}', config=config)
wandb.save('train_codebert_mlm.py')
def setup_directories(current_dir: Path) -> Path:
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time
output_dir: Path = current_dir.parent.parent / 'outputs' / f'tree_{curr_time}'
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir
@ -130,10 +130,11 @@ def create_dataloaders(
def setup_model_and_optimizer(
config: Dict[str, Any],
current_dir: Path
models_dir: Path
) -> Tuple[PreTrainedModel, AdamW]:
os.environ['HF_HOME'] = str(current_dir.parent / 'models')
model_config = RobertaConfig.from_pretrained('roberta-base')
if not models_dir.exists():
logger.info("Downloading the model...")
model_config = RobertaConfig.from_pretrained('roberta-base', cache_dir=models_dir)
model: PreTrainedModel = RobertaForMaskedLM(model_config)
model = torch.compile(model)
wandb.watch(model)
@ -212,8 +213,6 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo
return avg_loss, avg_acc
def main() -> None:
disable_caching()
current_dir: Path = Path(__file__).parent
output_dir: Path = setup_directories(current_dir)
config: Dict[str, Any] = load_config(current_dir / 'config.json')
@ -228,10 +227,16 @@ def main() -> None:
tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
logger.info(f'Tokenizer loaded: {tokenizer}')
######################## Reproducing last training here ########################
# Remove first 186513 batches
dataset['train'] = dataset['train'].select(range(186_513 * 32, len(dataset['train'])))
################################################################################
train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
model, optimizer = setup_model_and_optimizer(config, current_dir)
models_dir: Path = current_dir.parent / 'models' / 'roberta-base'
model, optimizer = setup_model_and_optimizer(config, models_dir)
num_training_steps: int = config['epochs'] * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
@ -239,6 +244,16 @@ def main() -> None:
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
)
######################## Reproducing last training here ########################
# Change opitmizer learning rate to 0.00021575814536340852
optimizer = AdamW(model.parameters(), lr=0.00021575814536340852, weight_decay=config['weight_decay'])
# Set warmup_steps to 0
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
# Load the best model weights
state_dict = torch.load('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt', weights_only=True, map_location=device)
model.load_state_dict(state_dict)
################################################################################
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)

View File

@ -0,0 +1,494 @@
import wandb
import sys
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import AdamW
from dataclasses import dataclass
from tqdm import tqdm
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import ast
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import DatasetDict
from transformers import (
RobertaForMaskedLM,
RobertaConfig,
RobertaTokenizer,
get_linear_schedule_with_warmup,
PreTrainedTokenizer,
PreTrainedModel,
)
sys.setrecursionlimit(3000) # Increase recursion limit
# Import existing training functionality
from train_codebert_mlm import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, download_dataset, load_and_prepare_dataset, logger
)
@dataclass
class ASTNodeInfo:
"""Stores structural information about an AST node."""
node_type: str
start_token_idx: int
end_token_idx: int
depth: int
sibling_pos: int
parent_idx: int
def ast_collate_fn(batch):
"""Custom collate function with improved error handling."""
# Remove failed parses (where attention_mask is all zeros)
valid_batch = [
item for item in batch
if item['attention_mask'].sum() > 0 and len(item['ast_nodes']) > 0
]
if not valid_batch:
# Return minimal batch if no valid items
return {
'input_ids': torch.zeros((1, 512), dtype=torch.long),
'attention_mask': torch.zeros((1, 512), dtype=torch.long),
'labels': torch.zeros((1, 512), dtype=torch.long),
'ast_nodes': [[]]
}
# Stack tensors
input_ids = torch.stack([item['input_ids'] for item in valid_batch])
attention_mask = torch.stack([item['attention_mask'] for item in valid_batch])
labels = torch.stack([item['labels'] for item in valid_batch])
# Collect AST nodes
ast_nodes = [item['ast_nodes'] for item in valid_batch]
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'ast_nodes': ast_nodes
}
class TreePositionalEmbedding(nn.Module):
"""Generates tree-aware positional embeddings for code tokens."""
def __init__(self, d_model: int = 768, max_depth: int = 32):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.combine = nn.Linear(d_model * 2, d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
position = torch.arange(self.max_depth).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
(-math.log(10000.0) / self.d_model))
pe = torch.zeros(self.max_depth, self.d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
with torch.no_grad():
self.depth_embedding.weight.copy_(pe)
self.sibling_embedding.weight.copy_(pe)
def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
"""Process batched input with corresponding AST nodes."""
batch_size, seq_len = input_ids.shape
device = input_ids.device
embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)
# Process each item in the batch
for batch_idx in range(batch_size):
ast_nodes = ast_nodes_batch[batch_idx]
# Process each position in the sequence
for i in range(seq_len):
containing_nodes = [
node for node in ast_nodes
if node.start_token_idx <= i < node.end_token_idx
]
if containing_nodes:
node = max(containing_nodes, key=lambda n: n.depth)
depth = min(node.depth, self.max_depth - 1)
sibling_pos = min(node.sibling_pos, self.max_depth - 1)
depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
return embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize weights with small random values around 0
self.alpha = nn.Parameter(torch.randn(1) * 0.02)
self.beta = nn.Parameter(torch.randn(1) * 0.02)
self.gamma = nn.Parameter(torch.randn(1) * 0.02)
self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
# Add dropout for regularization
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
def get_normalized_weights(self) -> torch.Tensor:
"""
Compute softmax-normalized weights for embedding combination.
Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
"""
weights = torch.stack([self.alpha, self.beta, self.gamma])
return torch.softmax(weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
# Move tensors to device
device = input_ids.device
# Get embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get normalized weights
norm_weights = self.get_normalized_weights()
# Combine embeddings based on presence of AST nodes
if ast_nodes is not None:
tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
combined_embeddings = (
norm_weights[0] * token_embeddings +
norm_weights[1] * tree_embeddings +
norm_weights[2] * seq_embeddings
)
else:
# Redistribute tree weight to other components when no AST available
token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
combined_embeddings = (
token_seq_weights[0] * token_embeddings +
token_seq_weights[1] * seq_embeddings
)
# Apply layer normalization and dropout
combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
combined_embeddings = self.embedding_dropout(combined_embeddings)
combined_embeddings = self.final_layer_norm(combined_embeddings)
# Forward pass through transformer
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate loss if labels provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Get normalized weights for logging
norm_weights_cpu = norm_weights.detach().cpu()
return {
"loss": masked_lm_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions,
"embedding_weights": {
"token": norm_weights_cpu[0].item(),
"tree": norm_weights_cpu[1].item(),
"sequential": norm_weights_cpu[2].item()
}
}
class TreeEnhancedDataset(Dataset):
"""Dataset that processes code into tokens and AST nodes with improved error handling."""
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
self.dataset = dataset
self.tokenizer = tokenizer
self.max_length = max_length
self.max_ast_size = 1000 # Limit maximum AST size to prevent memory issues
def __len__(self) -> int:
return len(self.dataset)
def create_dummy_data(self) -> Dict[str, Any]:
"""Create dummy data for invalid/problematic code samples."""
pad_token_id = self.tokenizer.pad_token_id or 0
dummy_tensor = torch.full((self.max_length,), pad_token_id, dtype=torch.long)
return {
'input_ids': dummy_tensor,
'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
'labels': dummy_tensor.clone(),
'ast_nodes': []
}
def __getitem__(self, idx: int) -> Dict[str, Any]:
try:
content: str = self.dataset[idx]['content']
# Skip extremely long files
if len(content) > 50000: # ~1000 lines
return self.create_dummy_data()
# Basic code validation
if not content.strip() or '\0' in content:
return self.create_dummy_data()
# Tokenize first
encoding = self.tokenizer(
content,
max_length=self.max_length,
truncation=True,
padding='max_length',
return_tensors='pt'
)
try:
# Parse AST with timeout
tree = ast.parse(content)
# Get AST nodes info
nodes_info = []
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
if len(nodes_info) >= self.max_ast_size:
return
current_idx = len(nodes_info)
if hasattr(node, 'lineno'):
nodes_info.append(ASTNodeInfo(
node_type=type(node).__name__,
start_token_idx=node.lineno,
end_token_idx=getattr(node, 'end_lineno', node.lineno),
depth=min(depth, 31), # Limit depth to prevent issues
sibling_pos=min(sibling_pos, 31),
parent_idx=parent_idx
))
# Only process children if we haven't hit the limit
if len(nodes_info) < self.max_ast_size:
for i, child in enumerate(ast.iter_child_nodes(node)):
visit_node(child, depth + 1, current_idx, i)
visit_node(tree, 0, -1, 0)
# If we hit the AST size limit, use dummy data
if len(nodes_info) >= self.max_ast_size:
return self.create_dummy_data()
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'labels': encoding['input_ids'].squeeze(0).clone(),
'ast_nodes': nodes_info
}
except (SyntaxError, ValueError, RecursionError, TimeoutError, MemoryError):
return self.create_dummy_data()
except Exception as e:
return self.create_dummy_data()
def create_tree_dataloaders(
dataset: DatasetDict,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
device: torch.device
) -> Tuple[DataLoader, DataLoader]:
"""Create dataloaders with tree-enhanced datasets."""
train_dataset = TreeEnhancedDataset(dataset['train'], tokenizer, max_length=512)
valid_dataset = TreeEnhancedDataset(dataset['test'], tokenizer, max_length=512)
train_dataloader = DataLoader(
train_dataset,
batch_size=config['batch'],
shuffle=True,
collate_fn=ast_collate_fn,
generator=torch.Generator(device=device)
)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=config['batch'],
shuffle=False,
collate_fn=ast_collate_fn,
generator=torch.Generator(device=device)
)
return train_dataloader, valid_dataloader
def train_and_evaluate(
model: PreTrainedModel,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
output_dir: Path
) -> None:
"""Training loop with explicit tracking of alpha, beta, gamma weights."""
num_training_steps: int = config['epochs'] * len(train_dataloader)
best_valid_loss: float = float('inf')
with tqdm(total=num_training_steps, desc='Training') as pbar:
for epoch_idx in range(config['epochs']):
model.train()
for train_idx, train_batch in enumerate(train_dataloader):
outputs = model(**train_batch)
train_loss = outputs["loss"]
if train_loss is not None:
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Get current metrics
current_loss = train_loss.item()
weights = outputs["embedding_weights"]
# Update progress bar with all three weights
pbar.update(1)
pbar.set_postfix({
'loss': f"{current_loss:.3f}",
'α': f"{weights['token']:.2f}",
'β': f"{weights['tree']:.2f}",
'γ': f"{weights['sequential']:.2f}"
})
# Log all three weights separately
step = train_idx + len(train_dataloader) * epoch_idx
wandb.log({
'loss': current_loss,
'token_weight': weights['token'],
'tree_weight': weights['tree'],
'sequential_weight': weights['sequential'],
'step': step,
})
# Periodic evaluation
if train_idx != 0 and train_idx % config['eval_every'] == 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader)
wandb.log({
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': step,
})
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), output_dir / 'best_model.pt')
else:
pbar.update(1)
pbar.set_postfix({'loss': 'N/A'})
logger.info(f'Best validation loss: {best_valid_loss}')
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
"""Evaluation function with simple metrics."""
model.eval()
total_loss: float = 0.0
total_acc: float = 0.0
num_batches: int = 0
with torch.no_grad():
for batch in dataloader:
outputs = model(**batch)
if outputs["loss"] is not None:
total_loss += outputs["loss"].item()
logits = outputs["logits"]
labels = batch['labels']
predictions = logits.argmax(dim=-1)
total_acc += (predictions == labels).float().mean().item()
num_batches += 1
avg_loss: float = total_loss / max(num_batches, 1)
avg_acc: float = total_acc / max(num_batches, 1)
model.train()
return avg_loss, avg_acc
def main():
# Setup identical to original training script
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir)
config = load_config(current_dir / 'config.json')
setup_wandb(config, exec_file='src/train_tree_codebert_mlm.py')
set_seed(config['seed'])
device = setup_device()
dataset_dir = current_dir.parent / 'data' / 'the-stack-python'
download_dataset(dataset_dir)
dataset = load_and_prepare_dataset(dataset_dir, config['seed'])
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create tree-enhanced dataloaders
train_dataloader, valid_dataloader = create_tree_dataloaders(dataset, tokenizer, config, device)
# Initialize tree-enhanced model
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = TreeCodeBERTForPreTraining(model_config)
model = model.to(device) # Just move to device without compilation
# model = torch.compile(model)
# Optimizer and scheduler setup identical to original
optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
num_training_steps = config['epochs'] * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
)
# Training loop (using original train_and_evaluate function)
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
if __name__ == "__main__":
main()