added tree-enhanced embeddings to training
parent d4ccfaac44
commit 9299c56bb1
code/pdm.lock: 1402 changes (file diff suppressed because it is too large)
code/src/config.json
@@ -1,11 +1,12 @@
 {
     "seed": 42,
     "mlm_probability": 0.15,
-    "batch": 32,
+    "batch": 16,
     "epochs": 1,
-    "eval_every": 10000,
+    "eval_every": 20000,
     "learning_rate": 5e-4,
     "weight_decay": 0.01,
     "max_grad_norm": 1.0,
-    "warmup_steps": 10000
+    "warmup_steps": 20000,
+    "pause_on_instability": false
 }
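For reference, these keys are consumed by load_config in code/src/train_codebert_mlm.py. That function is not part of this diff, so the following is only a minimal sketch assuming it is a thin wrapper around json.load:

import json
from pathlib import Path
from typing import Any, Dict

def load_config(config_path: Path) -> Dict[str, Any]:
    # Minimal sketch (assumption): the real loader lives in train_codebert_mlm.py.
    with config_path.open() as f:
        config: Dict[str, Any] = json.load(f)
    return config

# Hypothetical usage: after this commit, config['batch'] == 16 and config['warmup_steps'] == 20000.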
code/src/train_codebert_mlm.py
@@ -13,7 +13,7 @@ import torch
 from torch import Tensor
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset
-from datasets import load_dataset, disable_caching, DatasetDict
+from datasets import load_dataset, DatasetDict
 from huggingface_hub import list_repo_files, hf_hub_download
 from transformers import (
     RobertaForMaskedLM,
@@ -61,14 +61,14 @@ def set_seed(seed: int) -> None:
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)

-def setup_wandb(config: Dict[str, Any]) -> None:
+def setup_wandb(config: Dict[str, Any], exec_file: str = 'train_codebert_mlm.py') -> None:
     curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
-    wandb.init(project='codebert-training', name=curr_time, config=config)
+    wandb.init(project='codebert-training', name=f'tree_{curr_time}', config=config)
     wandb.save('train_codebert_mlm.py')

 def setup_directories(current_dir: Path) -> Path:
     curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
-    output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time
+    output_dir: Path = current_dir.parent.parent / 'outputs' / f'tree_{curr_time}'
     output_dir.mkdir(parents=True, exist_ok=True)
     return output_dir

@@ -130,10 +130,11 @@ def create_dataloaders(

 def setup_model_and_optimizer(
     config: Dict[str, Any],
-    current_dir: Path
+    models_dir: Path
 ) -> Tuple[PreTrainedModel, AdamW]:
-    os.environ['HF_HOME'] = str(current_dir.parent / 'models')
-    model_config = RobertaConfig.from_pretrained('roberta-base')
+    if not models_dir.exists():
+        logger.info("Downloading the model...")
+    model_config = RobertaConfig.from_pretrained('roberta-base', cache_dir=models_dir)
     model: PreTrainedModel = RobertaForMaskedLM(model_config)
     model = torch.compile(model)
     wandb.watch(model)
@@ -212,8 +213,6 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
     return avg_loss, avg_acc

 def main() -> None:
-    disable_caching()
-
     current_dir: Path = Path(__file__).parent
     output_dir: Path = setup_directories(current_dir)
     config: Dict[str, Any] = load_config(current_dir / 'config.json')
@@ -229,9 +228,15 @@ def main() -> None:
     tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
     logger.info(f'Tokenizer loaded: {tokenizer}')

+    ######################## Reproducing last training here ########################
+    # Remove first 186513 batches
+    dataset['train'] = dataset['train'].select(range(186_513 * 32, len(dataset['train'])))
+    ################################################################################
+
     train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)

-    model, optimizer = setup_model_and_optimizer(config, current_dir)
+    models_dir: Path = current_dir.parent / 'models' / 'roberta-base'
+    model, optimizer = setup_model_and_optimizer(config, models_dir)

     num_training_steps: int = config['epochs'] * len(train_dataloader)
     scheduler = get_linear_schedule_with_warmup(
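The select() call above drops the examples already consumed by the interrupted run: 186,513 batches at the previous batch size of 32 is 5,968,416 samples. A small bookkeeping sketch, with the variable names here being illustrative only:

# Sketch: how many training examples the previous run already saw.
batches_done = 186_513        # batches completed before the restart
old_batch_size = 32           # batch size used by that run
samples_to_skip = batches_done * old_batch_size
assert samples_to_skip == 5_968_416
# dataset['train'] = dataset['train'].select(range(samples_to_skip, len(dataset['train'])))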
@@ -240,6 +245,16 @@ def main() -> None:
         num_training_steps=num_training_steps
     )

+    ######################## Reproducing last training here ########################
+    # Change optimizer learning rate to 0.00021575814536340852
+    optimizer = AdamW(model.parameters(), lr=0.00021575814536340852, weight_decay=config['weight_decay'])
+    # Set warmup_steps to 0
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
+    # Load the best model weights
+    state_dict = torch.load('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt', weights_only=True, map_location=device)
+    model.load_state_dict(state_dict)
+    ################################################################################
+
     train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)

 if __name__ == "__main__":
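This block restores the state of the interrupted run: the optimizer is rebuilt with the learning rate that run had presumably reached, warmup is set to zero so the schedule only decays from here, and the best checkpoint is loaded. For orientation, the multiplier produced by a linear warmup/decay schedule such as get_linear_schedule_with_warmup has this shape (a sketch, not the library source):

# Sketch: learning rate of a linear warmup/decay schedule at a given step.
def linear_schedule_lr(base_lr: float, step: int, warmup: int, total: int) -> float:
    if step < warmup:
        return base_lr * step / max(1, warmup)                           # linear warmup
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))   # linear decay

# Assumption: the hard-coded 0.00021575814536340852 above is the value such a
# schedule had reached when the previous run stopped.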
code/src/train_tree_codebert_mlm.py (new file, 494 lines)
@@ -0,0 +1,494 @@
+import wandb
+
+import sys
+import math
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import AdamW
+from dataclasses import dataclass
+from tqdm import tqdm
+from pathlib import Path
+from typing import List, Dict, Any, Optional, Tuple
+import ast
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from datasets import DatasetDict
+from transformers import (
+    RobertaForMaskedLM,
+    RobertaConfig,
+    RobertaTokenizer,
+    get_linear_schedule_with_warmup,
+    PreTrainedTokenizer,
+    PreTrainedModel,
+)
+
+sys.setrecursionlimit(3000)  # Increase recursion limit
+
+# Import existing training functionality
+from train_codebert_mlm import (
+    set_seed, setup_wandb, setup_directories, load_config,
+    setup_device, download_dataset, load_and_prepare_dataset, logger
+)
+
+@dataclass
+class ASTNodeInfo:
+    """Stores structural information about an AST node."""
+    node_type: str
+    start_token_idx: int
+    end_token_idx: int
+    depth: int
+    sibling_pos: int
+    parent_idx: int
+
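A toy illustration (not part of the commit) of the ASTNodeInfo records that the dataset class later in this file builds, mirroring its visit_node traversal on a two-line snippet; it assumes ASTNodeInfo from this file is in scope:

import ast as _ast

snippet = "def f(x):\n    return x + 1\n"
nodes = []

def _visit(node, depth, parent_idx, sibling_pos):
    current_idx = len(nodes)
    if hasattr(node, 'lineno'):
        nodes.append(ASTNodeInfo(
            node_type=type(node).__name__,
            start_token_idx=node.lineno,
            end_token_idx=getattr(node, 'end_lineno', node.lineno),
            depth=depth,
            sibling_pos=sibling_pos,
            parent_idx=parent_idx,
        ))
    for i, child in enumerate(_ast.iter_child_nodes(node)):
        _visit(child, depth + 1, current_idx, i)

_visit(_ast.parse(snippet), 0, -1, 0)
# nodes now holds FunctionDef, arg, Return, BinOp, Name and Constant entries;
# Module and arguments carry no lineno and are skipped.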
+def ast_collate_fn(batch):
+    """Custom collate function with improved error handling."""
+    # Remove failed parses (where attention_mask is all zeros)
+    valid_batch = [
+        item for item in batch
+        if item['attention_mask'].sum() > 0 and len(item['ast_nodes']) > 0
+    ]
+
+    if not valid_batch:
+        # Return minimal batch if no valid items
+        return {
+            'input_ids': torch.zeros((1, 512), dtype=torch.long),
+            'attention_mask': torch.zeros((1, 512), dtype=torch.long),
+            'labels': torch.zeros((1, 512), dtype=torch.long),
+            'ast_nodes': [[]]
+        }
+
+    # Stack tensors
+    input_ids = torch.stack([item['input_ids'] for item in valid_batch])
+    attention_mask = torch.stack([item['attention_mask'] for item in valid_batch])
+    labels = torch.stack([item['labels'] for item in valid_batch])
+
+    # Collect AST nodes
+    ast_nodes = [item['ast_nodes'] for item in valid_batch]
+
+    return {
+        'input_ids': input_ids,
+        'attention_mask': attention_mask,
+        'labels': labels,
+        'ast_nodes': ast_nodes
+    }
+
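A quick sketch of the filtering behaviour above, assuming torch, ASTNodeInfo, and ast_collate_fn from this file are in scope; an item whose attention_mask is all zeros (the dummy fallback produced by the dataset class below) is dropped before stacking:

# Sketch: one valid item and one dummy item; only the valid one survives collation.
valid_item = {
    'input_ids': torch.ones(512, dtype=torch.long),
    'attention_mask': torch.ones(512, dtype=torch.long),
    'labels': torch.ones(512, dtype=torch.long),
    'ast_nodes': [ASTNodeInfo('Module', 0, 512, 0, 0, -1)],
}
dummy_item = {
    'input_ids': torch.zeros(512, dtype=torch.long),
    'attention_mask': torch.zeros(512, dtype=torch.long),  # all zeros -> filtered out
    'labels': torch.zeros(512, dtype=torch.long),
    'ast_nodes': [],
}
batch = ast_collate_fn([valid_item, dummy_item])
assert batch['input_ids'].shape == (1, 512)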
+class TreePositionalEmbedding(nn.Module):
+    """Generates tree-aware positional embeddings for code tokens."""
+
+    def __init__(self, d_model: int = 768, max_depth: int = 32):
+        super().__init__()
+        self.d_model = d_model
+        self.max_depth = max_depth
+
+        self.depth_embedding = nn.Embedding(max_depth, d_model)
+        self.sibling_embedding = nn.Embedding(max_depth, d_model)
+        self.combine = nn.Linear(d_model * 2, d_model)
+
+        self._initialize_embeddings()
+
+    def _initialize_embeddings(self):
+        position = torch.arange(self.max_depth).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
+                             (-math.log(10000.0) / self.d_model))
+
+        pe = torch.zeros(self.max_depth, self.d_model)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        with torch.no_grad():
+            self.depth_embedding.weight.copy_(pe)
+            self.sibling_embedding.weight.copy_(pe)
+
+    def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
+        """Process batched input with corresponding AST nodes."""
+        batch_size, seq_len = input_ids.shape
+        device = input_ids.device
+        embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)
+
+        # Process each item in the batch
+        for batch_idx in range(batch_size):
+            ast_nodes = ast_nodes_batch[batch_idx]
+            # Process each position in the sequence
+            for i in range(seq_len):
+                containing_nodes = [
+                    node for node in ast_nodes
+                    if node.start_token_idx <= i < node.end_token_idx
+                ]
+
+                if containing_nodes:
+                    node = max(containing_nodes, key=lambda n: n.depth)
+                    depth = min(node.depth, self.max_depth - 1)
+                    sibling_pos = min(node.sibling_pos, self.max_depth - 1)
+
+                    depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
+                    sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
+                    embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
+
+        return embeddings
+
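A minimal usage sketch for the module above (dummy shapes only, assuming torch, TreePositionalEmbedding, and ASTNodeInfo are in scope):

# Sketch: tree positional embeddings for a batch of 2 sequences of length 8.
tree_pos = TreePositionalEmbedding(d_model=768, max_depth=32)
input_ids = torch.zeros((2, 8), dtype=torch.long)
ast_nodes_batch = [
    [ASTNodeInfo('FunctionDef', 0, 8, depth=1, sibling_pos=0, parent_idx=-1)],
    [],  # no AST info for the second item, so its rows stay zero
]
emb = tree_pos(input_ids, ast_nodes_batch)
assert emb.shape == (2, 8, 768)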
+class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
+    """CodeBERT model enhanced with normalized embedding weights for stable training."""
+
+    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
+        super().__init__(config)
+
+        self.tree_pos_embeddings = TreePositionalEmbedding(
+            d_model=config.hidden_size,
+            max_depth=max_depth
+        )
+
+        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
+
+        # Initialize sequential position embeddings with sinusoidal pattern
+        position = torch.arange(max_seq_length).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
+        pe = torch.zeros(max_seq_length, config.hidden_size)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.seq_pos_embeddings.weight.data.copy_(pe)
+
+        # Initialize weights with small random values around 0
+        self.alpha = nn.Parameter(torch.randn(1) * 0.02)
+        self.beta = nn.Parameter(torch.randn(1) * 0.02)
+        self.gamma = nn.Parameter(torch.randn(1) * 0.02)
+
+        self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+
+        # Add dropout for regularization
+        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def get_normalized_weights(self) -> torch.Tensor:
+        """
+        Compute softmax-normalized weights for embedding combination.
+        Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
+        """
+        weights = torch.stack([self.alpha, self.beta, self.gamma])
+        return torch.softmax(weights, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
+        output_attentions: bool = False,
+        **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        # Move tensors to device
+        device = input_ids.device
+
+        # Get embeddings
+        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
+
+        seq_positions = torch.arange(input_ids.size(1), device=device)
+        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
+
+        # Get normalized weights
+        norm_weights = self.get_normalized_weights()
+
+        # Combine embeddings based on presence of AST nodes
+        if ast_nodes is not None:
+            tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
+            combined_embeddings = (
+                norm_weights[0] * token_embeddings +
+                norm_weights[1] * tree_embeddings +
+                norm_weights[2] * seq_embeddings
+            )
+        else:
+            # Redistribute tree weight to other components when no AST available
+            token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
+            combined_embeddings = (
+                token_seq_weights[0] * token_embeddings +
+                token_seq_weights[1] * seq_embeddings
+            )
+
+        # Apply layer normalization and dropout
+        combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
+        combined_embeddings = self.embedding_dropout(combined_embeddings)
+        combined_embeddings = self.final_layer_norm(combined_embeddings)
+
+        # Forward pass through transformer
+        outputs = self.roberta(
+            inputs_embeds=combined_embeddings,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        # Calculate loss if labels provided
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1)
+            )
+
+        # Get normalized weights for logging
+        norm_weights_cpu = norm_weights.detach().cpu()
+
+        return {
+            "loss": masked_lm_loss,
+            "logits": prediction_scores,
+            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
+            "attentions": outputs.attentions,
+            "embedding_weights": {
+                "token": norm_weights_cpu[0].item(),
+                "tree": norm_weights_cpu[1].item(),
+                "sequential": norm_weights_cpu[2].item()
+            }
+        }
+
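A hedged end-to-end sketch of the model above on dummy data, using a deliberately tiny RobertaConfig so the forward pass is cheap to try; the real script builds it from the full microsoft/codebert-base configuration:

# Sketch: tiny config, random inputs, one forward pass.
cfg = RobertaConfig(vocab_size=100, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128)
model = TreeCodeBERTForPreTraining(cfg, max_depth=32, max_seq_length=512)

input_ids = torch.randint(0, 100, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = input_ids.clone()
ast_nodes = [[ASTNodeInfo('Module', 0, 16, 0, 0, -1)], []]

out = model(input_ids=input_ids, attention_mask=attention_mask,
            labels=labels, ast_nodes=ast_nodes)
# out['loss'] is the MLM loss; out['embedding_weights'] holds the softmaxed
# token/tree/sequential weights, which sum to 1.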
+class TreeEnhancedDataset(Dataset):
+    """Dataset that processes code into tokens and AST nodes with improved error handling."""
+    def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.max_ast_size = 1000  # Limit maximum AST size to prevent memory issues
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def create_dummy_data(self) -> Dict[str, Any]:
+        """Create dummy data for invalid/problematic code samples."""
+        pad_token_id = self.tokenizer.pad_token_id or 0
+        dummy_tensor = torch.full((self.max_length,), pad_token_id, dtype=torch.long)
+        return {
+            'input_ids': dummy_tensor,
+            'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
+            'labels': dummy_tensor.clone(),
+            'ast_nodes': []
+        }
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        try:
+            content: str = self.dataset[idx]['content']
+
+            # Skip extremely long files
+            if len(content) > 50000:  # ~1000 lines
+                return self.create_dummy_data()
+
+            # Basic code validation
+            if not content.strip() or '\0' in content:
+                return self.create_dummy_data()
+
+            # Tokenize first
+            encoding = self.tokenizer(
+                content,
+                max_length=self.max_length,
+                truncation=True,
+                padding='max_length',
+                return_tensors='pt'
+            )
+
+            try:
+                # Parse AST with timeout
+                tree = ast.parse(content)
+
+                # Get AST nodes info
+                nodes_info = []
+                def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
+                    if len(nodes_info) >= self.max_ast_size:
+                        return
+
+                    current_idx = len(nodes_info)
+                    if hasattr(node, 'lineno'):
+                        nodes_info.append(ASTNodeInfo(
+                            node_type=type(node).__name__,
+                            start_token_idx=node.lineno,
+                            end_token_idx=getattr(node, 'end_lineno', node.lineno),
+                            depth=min(depth, 31),  # Limit depth to prevent issues
+                            sibling_pos=min(sibling_pos, 31),
+                            parent_idx=parent_idx
+                        ))
+
+                    # Only process children if we haven't hit the limit
+                    if len(nodes_info) < self.max_ast_size:
+                        for i, child in enumerate(ast.iter_child_nodes(node)):
+                            visit_node(child, depth + 1, current_idx, i)
+
+                visit_node(tree, 0, -1, 0)
+
+                # If we hit the AST size limit, use dummy data
+                if len(nodes_info) >= self.max_ast_size:
+                    return self.create_dummy_data()
+
+                return {
+                    'input_ids': encoding['input_ids'].squeeze(0),
+                    'attention_mask': encoding['attention_mask'].squeeze(0),
+                    'labels': encoding['input_ids'].squeeze(0).clone(),
+                    'ast_nodes': nodes_info
+                }
+
+            except (SyntaxError, ValueError, RecursionError, TimeoutError, MemoryError):
+                return self.create_dummy_data()
+
+        except Exception as e:
+            return self.create_dummy_data()
+
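A small sketch of the dataset wrapper above on an in-memory list; the real run wraps the the-stack-python splits, and the tokenizer download is assumed to succeed:

# Sketch: two toy samples, one valid Python file and one syntactically broken.
toy_data = [
    {'content': "def add(a, b):\n    return a + b\n"},
    {'content': "def broken(:\n"},  # SyntaxError -> falls back to dummy data
]
tok = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
ds = TreeEnhancedDataset(toy_data, tok, max_length=512)

good = ds[0]
bad = ds[1]
assert len(good['ast_nodes']) > 0
assert bad['attention_mask'].sum() == 0  # dummy item, later filtered by ast_collate_fn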
+def create_tree_dataloaders(
+    dataset: DatasetDict,
+    tokenizer: PreTrainedTokenizer,
+    config: Dict[str, Any],
+    device: torch.device
+) -> Tuple[DataLoader, DataLoader]:
+    """Create dataloaders with tree-enhanced datasets."""
+    train_dataset = TreeEnhancedDataset(dataset['train'], tokenizer, max_length=512)
+    valid_dataset = TreeEnhancedDataset(dataset['test'], tokenizer, max_length=512)
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=config['batch'],
+        shuffle=True,
+        collate_fn=ast_collate_fn,
+        generator=torch.Generator(device=device)
+    )
+    valid_dataloader = DataLoader(
+        valid_dataset,
+        batch_size=config['batch'],
+        shuffle=False,
+        collate_fn=ast_collate_fn,
+        generator=torch.Generator(device=device)
+    )
+
+    return train_dataloader, valid_dataloader
+
+def train_and_evaluate(
+    model: PreTrainedModel,
+    train_dataloader: DataLoader,
+    valid_dataloader: DataLoader,
+    optimizer: AdamW,
+    scheduler: Any,
+    config: Dict[str, Any],
+    output_dir: Path
+) -> None:
+    """Training loop with explicit tracking of alpha, beta, gamma weights."""
+    num_training_steps: int = config['epochs'] * len(train_dataloader)
+    best_valid_loss: float = float('inf')
+
+    with tqdm(total=num_training_steps, desc='Training') as pbar:
+        for epoch_idx in range(config['epochs']):
+            model.train()
+
+            for train_idx, train_batch in enumerate(train_dataloader):
+                outputs = model(**train_batch)
+                train_loss = outputs["loss"]
+
+                if train_loss is not None:
+                    train_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
+                    optimizer.step()
+                    scheduler.step()
+                    optimizer.zero_grad()
+
+                    # Get current metrics
+                    current_loss = train_loss.item()
+                    weights = outputs["embedding_weights"]
+
+                    # Update progress bar with all three weights
+                    pbar.update(1)
+                    pbar.set_postfix({
+                        'loss': f"{current_loss:.3f}",
+                        'α': f"{weights['token']:.2f}",
+                        'β': f"{weights['tree']:.2f}",
+                        'γ': f"{weights['sequential']:.2f}"
+                    })
+
+                    # Log all three weights separately
+                    step = train_idx + len(train_dataloader) * epoch_idx
+                    wandb.log({
+                        'loss': current_loss,
+                        'token_weight': weights['token'],
+                        'tree_weight': weights['tree'],
+                        'sequential_weight': weights['sequential'],
+                        'step': step,
+                    })
+
+                    # Periodic evaluation
+                    if train_idx != 0 and train_idx % config['eval_every'] == 0:
+                        valid_loss, valid_acc = evaluate(model, valid_dataloader)
+
+                        wandb.log({
+                            'valid_loss': valid_loss,
+                            'valid_acc': valid_acc,
+                            'step': step,
+                        })
+
+                        if valid_loss < best_valid_loss:
+                            best_valid_loss = valid_loss
+                            torch.save(model.state_dict(), output_dir / 'best_model.pt')
+                else:
+                    pbar.update(1)
+                    pbar.set_postfix({'loss': 'N/A'})
+
+    logger.info(f'Best validation loss: {best_valid_loss}')
+
+def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
+    """Evaluation function with simple metrics."""
+    model.eval()
+    total_loss: float = 0.0
+    total_acc: float = 0.0
+    num_batches: int = 0
+
+    with torch.no_grad():
+        for batch in dataloader:
+            outputs = model(**batch)
+            if outputs["loss"] is not None:
+                total_loss += outputs["loss"].item()
+                logits = outputs["logits"]
+                labels = batch['labels']
+                predictions = logits.argmax(dim=-1)
+                total_acc += (predictions == labels).float().mean().item()
+                num_batches += 1
+
+    avg_loss: float = total_loss / max(num_batches, 1)
+    avg_acc: float = total_acc / max(num_batches, 1)
+
+    model.train()
+    return avg_loss, avg_acc
+
+def main():
+    # Setup identical to original training script
+    current_dir = Path(__file__).parent
+    output_dir = setup_directories(current_dir)
+    config = load_config(current_dir / 'config.json')
+
+    setup_wandb(config, exec_file='src/train_tree_codebert_mlm.py')
+    set_seed(config['seed'])
+    device = setup_device()
+
+    dataset_dir = current_dir.parent / 'data' / 'the-stack-python'
+    download_dataset(dataset_dir)
+    dataset = load_and_prepare_dataset(dataset_dir, config['seed'])
+
+    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
+
+    # Create tree-enhanced dataloaders
+    train_dataloader, valid_dataloader = create_tree_dataloaders(dataset, tokenizer, config, device)
+
+    # Initialize tree-enhanced model
+    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
+    model = TreeCodeBERTForPreTraining(model_config)
+    model = model.to(device)  # Just move to device without compilation
+    # model = torch.compile(model)
+
+    # Optimizer and scheduler setup identical to original
+    optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
+    num_training_steps = config['epochs'] * len(train_dataloader)
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=config['warmup_steps'],
+        num_training_steps=num_training_steps
+    )
+
+    # Training loop (using original train_and_evaluate function)
+    train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
+
+if __name__ == "__main__":
+    main()