on this commit there is a weird loss spike (unresolved)

This commit is contained in:
Patryk Bartkowiak 2024-12-04 16:25:44 +00:00
parent 0cd04e6131
commit b2d65059da
11 changed files with 5211 additions and 723 deletions

View File

@ -7,15 +7,18 @@ authors = [
] ]
dependencies = [ dependencies = [
"wandb==0.18.5", "wandb==0.18.5",
"torch==2.5.0", "torch==2.5.1",
"tqdm==4.66.5", "tqdm==4.66.5",
"tree-sitter==0.23.1", "tree-sitter==0.23.1",
"transformers==4.45.2", "transformers[torch]>=4.46.3",
"datasets==3.0.1", "datasets==3.0.1",
"huggingface-hub==0.26.0", "huggingface-hub==0.26.0",
"matplotlib==3.9.2", "matplotlib==3.9.2",
"scikit-learn==1.5.2", "scikit-learn==1.5.2",
"seaborn==0.13.2", "seaborn==0.13.2",
"tree-sitter-python==0.23.4",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
] ]
requires-python = "==3.11.*" requires-python = "==3.11.*"
readme = "README.md" readme = "README.md"
@ -35,6 +38,5 @@ build-backend = "pdm.backend"
distribution = true distribution = true
[tool.pdm.scripts] [tool.pdm.scripts]
run_training = {cmd = "src/train_codebert_mlm.py"}
run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
parse_dataset = {cmd = "src/parse_dataset.py"} parse_dataset = {cmd = "src/parse_dataset.py"}
run_training = {cmd = "src/training.py"}

View File

@ -1,12 +1,13 @@
{ {
"seed": 42, "seed": 420,
"mlm_probability": 0.15, "mlm_probability": 0.15,
"batch": 16, "batch_size": 32,
"epochs": 1, "epochs": 5,
"eval_every": 20000, "eval_every": 1000,
"learning_rate": 5e-4, "learning_rate": 5e-4,
"weight_decay": 0.01, "weight_decay": 0.1,
"max_grad_norm": 1.0, "max_grad_norm": 1.0,
"warmup_steps": 20000, "warmup_steps": 1000,
"train_size": 0.95 "train_size": 0.95,
"fp16": true
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,22 @@
import logging
from pathlib import Path
from typing import List
from huggingface_hub import list_repo_files, hf_hub_download
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def download_dataset(dataset_dir: Path) -> None:
if not dataset_dir.exists():
logger.info("Downloading the dataset...")
dataset_dir.mkdir(parents=True, exist_ok=True)
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
for file_name in files_to_download:
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
logger.info("Dataset downloaded successfully.")
if __name__ == '__main__':
dataset_dir = Path('data/the-stack-python')
download_dataset(dataset_dir)

View File

@ -1,128 +1,123 @@
import ast
from pathlib import Path
import logging import logging
import multiprocessing
import tree_sitter_python as tspython
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path
from pprint import pprint from pprint import pprint
from typing import Dict, Any, List from typing import Dict, Any, List
from datasets import load_dataset, Dataset, concatenate_datasets from datasets import load_dataset, Dataset
from transformers import RobertaTokenizer from tree_sitter import Language, Parser
from dataclasses import dataclass from transformers import AutoTokenizer
import numpy as np
import multiprocessing
import warnings import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning) warnings.filterwarnings("ignore", category=SyntaxWarning)
# Setup logging # Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer):
class ASTNodeInfo: """
"""Stores structural information about an AST node.""" Map tokens to their original text and Tree-sitter AST position
node_type: str """
start_token_idx: int # Initialize Tree-sitter
end_token_idx: int PY_LANGUAGE = Language(tspython.language())
depth: int parser = Parser(PY_LANGUAGE)
sibling_pos: int
parent_idx: int
def to_dict(self) -> Dict[str, Any]: # Parse with Tree-sitter
return { tree = parser.parse(bytes(code_snippet, "utf8"))
'node_type': self.node_type,
'start_token_idx': self.start_token_idx,
'end_token_idx': self.end_token_idx,
'depth': self.depth,
'sibling_pos': self.sibling_pos,
'parent_idx': self.parent_idx
}
def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo': encoded = tokenizer(
return ASTNodeInfo( code_snippet,
node_type=data['node_type'], add_special_tokens=True,
start_token_idx=data['start_token_idx'], return_offsets_mapping=True,
end_token_idx=data['end_token_idx'], return_tensors='pt',
depth=data['depth'], padding='max_length',
sibling_pos=data['sibling_pos'], truncation=True,
parent_idx=data['parent_idx'] max_length=512,
) )
def safe_ast_parse(content: str, max_length: int = 50000) -> bool: tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
"""Safely attempt to parse Python code.""" offset_mapping = encoded['offset_mapping'][0].tolist()
if not content or not content.strip() or '\0' in content or len(content) > max_length:
return False
try: def get_node_position(node, depth=0, idx=0):
tree = ast.parse(content) """Get depth and sibling index for a node"""
return bool(list(ast.walk(tree))) return (depth, idx)
except:
return False
def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]: def find_node_at_position(node, start_byte, depth=0, sibling_idx=0):
"""Process AST and extract node information.""" """Find the most specific node containing the given position"""
try: if start_byte >= node.start_byte and start_byte < node.end_byte:
tree = ast.parse(content) # Check children first for more specific nodes
nodes_info = [] for idx, child in enumerate(node.children):
result = find_node_at_position(child, start_byte, depth + 1, idx)
if result:
return result
# Return current node's position if no child contains the position
return get_node_position(node, depth, sibling_idx)
return None
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int): depths = []
if len(nodes_info) >= max_ast_size: sibling_idxs = []
return node_texts = []
for token, (start, end) in zip(tokens, offset_mapping):
if token not in ['<s>', '</s>', '<pad>', '<unk>']:
text = code_snippet[start:end] if start < len(code_snippet) else ""
# Tree-sitter works with bytes, so convert position
start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0
position = find_node_at_position(tree.root_node, start_byte) or (-1, -1)
current_idx = len(nodes_info) depths.append(position[0])
if hasattr(node, 'lineno'): sibling_idxs.append(position[1])
node_info = ASTNodeInfo( node_texts.append(text)
node_type=type(node).__name__, else:
start_token_idx=node.lineno, depths.append(-1)
end_token_idx=getattr(node, 'end_lineno', node.lineno), sibling_idxs.append(-1)
depth=min(depth, 31), node_texts.append(None)
sibling_pos=min(sibling_pos, 31),
parent_idx=parent_idx
)
nodes_info.append(node_info.to_dict())
if len(nodes_info) < max_ast_size: return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
for i, child in enumerate(ast.iter_child_nodes(node)):
visit_node(child, depth + 1, current_idx, i)
visit_node(tree, 0, -1, 0) def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
return nodes_info
except:
return []
def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
"""Process a batch of examples.""" """Process a batch of examples."""
contents = examples['content'] contents = examples['content']
processed_contents = []
processed_ast_nodes = []
processed_input_ids = [] processed_input_ids = []
processed_attention_masks = [] processed_attention_mask = []
processed_tokens = []
processed_depths = []
processed_sibling_idxs = []
processed_node_texts = []
for content in contents: for content in contents:
if safe_ast_parse(content): try:
ast_nodes = process_and_extract_ast_info(content) input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer)
if ast_nodes: processed_input_ids.append(input_ids[0])
try: processed_attention_mask.append(attention_mask[0])
encoding = tokenizer( processed_tokens.append(tokens)
content, processed_depths.append(depths)
max_length=512, processed_sibling_idxs.append(sibling_idxs)
truncation=True, processed_node_texts.append(node_texts)
padding='max_length',
return_tensors='pt'
)
processed_contents.append(content) except Exception as e:
processed_ast_nodes.append(ast_nodes) logger.error(f"Error processing example: {e}")
processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist()) # Return empty lists that will be filtered out later
processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist()) processed_input_ids.append([])
except: processed_attention_mask.append([])
continue processed_tokens.append([])
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
return { return {
'content': processed_contents,
'ast_nodes': processed_ast_nodes,
'input_ids': processed_input_ids, 'input_ids': processed_input_ids,
'attention_mask': processed_attention_masks, 'attention_mask': processed_attention_mask,
'tokens': processed_tokens,
'depths': processed_depths,
'sibling_idxs': processed_sibling_idxs,
'node_texts': processed_node_texts,
} }
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000): def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
@ -131,14 +126,9 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
# Create directory for chunks # Create directory for chunks
output_dir = Path(output_path).parent output_dir = Path(output_path).parent
chunks_dir = output_dir / 'chunks' chunks_dir = output_dir / 'data'
chunks_dir.mkdir(exist_ok=True) chunks_dir.mkdir(exist_ok=True)
stats_accumulator = {
'total_ast_nodes': 0,
'total_samples': 0
}
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk") pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
for i in pbar: for i in pbar:
start_idx = i * chunk_size start_idx = i * chunk_size
@ -147,28 +137,10 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
# Select chunk from dataset # Select chunk from dataset
chunk_dataset = dataset.select(range(start_idx, end_idx)) chunk_dataset = dataset.select(range(start_idx, end_idx))
# Update statistics if len(chunk_dataset) > 0: # Only process if chunk has data
stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes']) # Save chunk using datasets native method
stats_accumulator['total_samples'] += len(chunk_dataset) chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
chunk_dataset.to_parquet(str(chunk_path))
# Save chunk using datasets native method
chunk_path = chunks_dir / f'chunk_{i:04d}'
chunk_dataset.save_to_disk(str(chunk_path))
# Update progress bar postfix with current chunk info
pbar.set_postfix({
'samples': len(chunk_dataset),
'path': str(chunk_path.name)
})
return stats_accumulator
def load_all_chunks(chunks_dir: Path) -> Dataset:
"""Load and concatenate all dataset chunks."""
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*')):
chunks.append(load_from_disk(str(chunk_path)))
return concatenate_datasets(chunks)
def main(): def main():
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
@ -176,14 +148,23 @@ def main():
output_dir = current_dir.parent / 'data' / 'processed-python' output_dir = current_dir.parent / 'data' / 'processed-python'
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base') tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
logger.info("Loading dataset...") logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['train'] dataset = load_dataset(str(input_dir))['train']
original_dataset_size = len(dataset) # dataset = dataset.select(range(200_000)) # Limit dataset size for testing
logger.info("Dataset:") original_dataset_size = len(dataset)
pprint(dataset) logger.info(f"Original dataset size: {original_dataset_size}")
columns_to_remove = dataset.column_names
columns_to_remove.remove('content')
columns_to_remove.remove('size')
columns_to_remove.remove('avg_line_length')
columns_to_remove.remove('max_line_length')
columns_to_remove.remove('alphanum_fraction')
columns_to_remove.remove('max_stars_count')
logging.info(f"Columns to remove: {columns_to_remove}")
num_proc = min(multiprocessing.cpu_count() - 1, 32) num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing") logger.info(f"Using {num_proc} processes for dataset processing")
@ -193,35 +174,38 @@ def main():
process_batch, process_batch,
fn_kwargs={'tokenizer': tokenizer}, fn_kwargs={'tokenizer': tokenizer},
batched=True, batched=True,
remove_columns=dataset.column_names, remove_columns=columns_to_remove,
desc="Processing dataset", desc="Processing dataset",
num_proc=num_proc, num_proc=num_proc,
load_from_cache_file=False load_from_cache_file=False
) )
processed_dataset = processed_dataset.filter( processed_dataset = processed_dataset.filter(
lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']], lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True, batched=True,
num_proc=num_proc, num_proc=num_proc,
desc="Filtering invalid examples" desc="Filtering invalid examples"
) )
reduced_dataset_size = len(processed_dataset) reduced_dataset_size = len(processed_dataset)
logger.info(f"Processed dataset size: {reduced_dataset_size}")
logger.info("Processed dataset:") if reduced_dataset_size == 0:
pprint(processed_dataset) logger.error("No valid examples found in dataset!")
return
logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...") logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
stats_accumulator = save_dataset_in_chunks( save_dataset_in_chunks(
processed_dataset, processed_dataset,
str(output_dir / 'processed_dataset'), str(output_dir / 'processed_dataset'),
chunk_size=100_000 chunk_size=100_000
) )
# Add derived statistics
stats = { stats = {
'original_dataset_size': original_dataset_size, 'original_dataset_size': original_dataset_size,
'reduced_dataset_size': reduced_dataset_size, 'reduced_dataset_size': reduced_dataset_size,
'% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100, 'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
} }
# Log stats with pprint # Log stats with pprint

View File

@ -0,0 +1,50 @@
import logging
import multiprocessing
from pathlib import Path
from datasets import load_dataset
from parse_dataset import save_dataset_in_chunks
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def main():
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'processed-python'
output_dir = current_dir.parent / 'data' / 'filtered-python'
output_dir.mkdir(parents=True, exist_ok=True)
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['data']
logging.info(f"Dataset:\n{dataset}")
logger.info("Filtering dataset by max_stars_count > 3...")
filtered_dataset = dataset.filter(
lambda batch: [example and example > 3 for example in batch['max_stars_count']],
num_proc=num_proc,
batched=True,
desc="Filtering dataset"
)
filtered_dataset_size = len(filtered_dataset)
logger.info(f"Filtered dataset size: {filtered_dataset_size}")
if filtered_dataset_size == 0:
logger.error("No examples found with max_stars_count > 3!")
return
logger.info(f"Saving {filtered_dataset_size} filtered examples...")
save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)
logger.info("Filtering and saving completed!")
if __name__ == "__main__":
main()

View File

@ -1,12 +0,0 @@
{
"seed": 42,
"mlm_probability": 0.15,
"batch": 16,
"epochs": 1,
"eval_every": 5000,
"learning_rate": 5e-4,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"warmup_steps": 5000,
"train_size": 0.95
}

View File

@ -1,50 +0,0 @@
from pathlib import Path
import torch
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, set_deterministic_mode
)
def main() -> None:
set_deterministic_mode()
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name='base')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config, model_name='base')
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create dataloaders
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device
)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = RobertaForMaskedLM(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=False, device=device
)
if __name__ == "__main__":
main()

View File

@ -1,227 +0,0 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
set_deterministic_mode
)
from parse_dataset import ASTNodeInfo
class TreePositionalEmbedding(nn.Module):
"""Generates tree-aware positional embeddings for code tokens."""
def __init__(self, d_model: int = 768, max_depth: int = 32):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.combine = nn.Linear(d_model * 2, d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
position = torch.arange(self.max_depth).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
(-math.log(10000.0) / self.d_model))
pe = torch.zeros(self.max_depth, self.d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
with torch.no_grad():
self.depth_embedding.weight.copy_(pe)
self.sibling_embedding.weight.copy_(pe)
def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
"""Process batched input with corresponding AST nodes."""
batch_size, seq_len = input_ids.shape
device = input_ids.device
embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)
# Process each item in the batch
for batch_idx in range(batch_size):
ast_nodes = ast_nodes_batch[batch_idx]
# Process each position in the sequence
for i in range(seq_len):
containing_nodes = [
node for node in ast_nodes
if node.start_token_idx <= i < node.end_token_idx
]
if containing_nodes:
node = max(containing_nodes, key=lambda n: n.depth)
depth = min(node.depth, self.max_depth - 1)
sibling_pos = min(node.sibling_pos, self.max_depth - 1)
depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
return embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize weights with small random values around 0
self.alpha = nn.Parameter(torch.randn(1) * 0.02)
self.beta = nn.Parameter(torch.randn(1) * 0.02)
self.gamma = nn.Parameter(torch.randn(1) * 0.02)
self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
# Add dropout for regularization
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
def get_normalized_weights(self) -> torch.Tensor:
"""
Compute softmax-normalized weights for embedding combination.
Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
"""
weights = torch.stack([self.alpha, self.beta, self.gamma])
return torch.softmax(weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
# Move tensors to device
device = input_ids.device
# Get embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get normalized weights
norm_weights = self.get_normalized_weights()
# Combine embeddings based on presence of AST nodes
if ast_nodes is not None:
tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
combined_embeddings = (
norm_weights[0] * token_embeddings +
norm_weights[1] * tree_embeddings +
norm_weights[2] * seq_embeddings
)
else:
# Redistribute tree weight to other components when no AST available
token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
combined_embeddings = (
token_seq_weights[0] * token_embeddings +
token_seq_weights[1] * seq_embeddings
)
# Apply layer normalization and dropout
combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
combined_embeddings = self.embedding_dropout(combined_embeddings)
combined_embeddings = self.final_layer_norm(combined_embeddings)
# Forward pass through transformer
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate loss if labels provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Get normalized weights for logging
norm_weights_cpu = norm_weights.detach().cpu()
return {
"loss": masked_lm_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions,
"embedding_weights": {
"token": norm_weights_cpu[0].item(),
"tree": norm_weights_cpu[1].item(),
"sequential": norm_weights_cpu[2].item()
}
}
def main() -> None:
set_deterministic_mode()
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name='tree')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config, model_name='tree')
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create dataloaders with tree dataset
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device,
dataset_class=PreprocessedTreeDataset
)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = TreeCodeBERTForPreTraining(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=True, device=device
)
if __name__ == "__main__":
main()

306
code/src/training.py Normal file
View File

@ -0,0 +1,306 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Optional
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM
from training_utils import *
MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom'
DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python'
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
super().__init__(config)
self.base_custom = base_custom
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
if base_custom:
initial_weight = math.log(1/2)
self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
else:
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None and not self.base_custom:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
if self.base_custom:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * seq_embeddings
)
else:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
if self.base_custom:
embedding_weights = {
"token": weights_cpu[0].item(),
"sequential": weights_cpu[1].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
else:
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
"embedding_weights": embedding_weights,
"regularization_metrics": reg_metrics
}
def main() -> None:
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name=MODEL)
config = load_config(current_dir / 'config.json')
set_deterministic_mode(config['seed'])
setup_wandb(config, model_name=MODEL, script_path=__file__)
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
# Create dataloaders with tree dataset
logger.info("Creating dataloaders...")
processed_data_dir = current_dir.parent / 'data' / DATASET
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
base_training=True if MODEL == 'base' else False
)
# Initialize model config
logger.info("Initializing model...")
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
# Update W&B
wandb.config.update({'model_config': model_config.__dict__})
# Initialize model
if MODEL == 'tree':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
elif MODEL == 'base':
model = RobertaForMaskedLM(model_config)
elif MODEL == 'base-custom':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
else:
raise ValueError(f"Invalid model type: {MODEL}")
model = model.to(device)
logging.info(f"Model initialized: {MODEL}")
# Model precision
if device.type == 'cuda' and config['fp16']:
model = model.half()
logging.info("Model set to half precision.")
for param in model.parameters():
logging.info(f"Parameter dtype: {param.dtype}")
break
# Training setup
logger.info("Setting up optimizer and scheduler...")
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
# Train and evaluate model
logger.info("Starting training...")
train_and_evaluate(
model=model,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
optimizer=optimizer,
scheduler=scheduler,
config=config,
device=device,
output_dir=output_dir
)
logger.info("Training completed!")
final_output_dir = output_dir / 'final-model'
model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
logger.info(f"Model saved to {final_output_dir}")
if __name__ == "__main__":
main()

View File

@ -8,38 +8,41 @@ import wandb
import numpy as np import numpy as np
import torch import torch
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Tuple, Type, Optional from typing import Dict, Any, Tuple, Optional
from datasets import load_from_disk, concatenate_datasets from datasets import load_dataset
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader
from transformers import ( from transformers import (
PreTrainedModel, PreTrainedModel,
get_linear_schedule_with_warmup, get_constant_schedule_with_warmup,
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
PreTrainedTokenizer PreTrainedTokenizer,
RobertaForMaskedLM
) )
from tqdm import tqdm from tqdm import tqdm
from parse_dataset import ASTNodeInfo
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def set_deterministic_mode() -> None: def set_deterministic_mode(seed: int) -> None:
"""Enable deterministic mode for reproducibility.""" """Enable deterministic mode for reproducibility."""
# Set Python random seed # Set Python random seed
random.seed(42) random.seed(seed)
# Set NumPy random seed # Set NumPy random seed
np.random.seed(42) np.random.seed(seed)
# Set PyTorch random seed # Set PyTorch random seed
torch.manual_seed(42) torch.manual_seed(seed)
# Set CUDA random seed # Set CUDA random seed
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(42) torch.cuda.manual_seed_all(seed)
# Enable deterministic operations for PyTorch 2.x # Enable deterministic operations for PyTorch 2.x
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
@ -47,7 +50,7 @@ def set_deterministic_mode() -> None:
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
# Set environment variables for reproducibility # Set environment variables for reproducibility
os.environ['PYTHONHASHSEED'] = '42' os.environ['PYTHONHASHSEED'] = f'{seed}'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2 os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
def get_device_info() -> Dict[str, Any]: def get_device_info() -> Dict[str, Any]:
@ -70,53 +73,6 @@ def get_device_info() -> Dict[str, Any]:
return info return info
class TreeDataCollator(DataCollatorForLanguageModeling):
"""Custom data collator that handles both MLM and optional AST node information."""
def torch_call(self, examples):
# Check if we have AST nodes
has_ast = 'ast_nodes' in examples[0]
# Extract AST nodes before MLM processing if they exist
ast_nodes = None
if has_ast:
ast_nodes = [e.pop('ast_nodes') for e in examples]
# Process normal MLM features
batch = super().torch_call(examples)
# Add AST nodes back to batch if they existed
if has_ast:
batch['ast_nodes'] = ast_nodes
return batch
def __call__(self, examples):
return self.torch_call(examples)
class PreprocessedBaseDataset(Dataset):
"""Dataset that uses pre-processed data without AST information."""
def __init__(self, dataset: Any):
self.dataset = dataset
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = self.dataset[idx]
return {
'input_ids': torch.tensor(item['input_ids']),
'attention_mask': torch.tensor(item['attention_mask']),
'labels': torch.tensor(item['input_ids']).clone()
}
class PreprocessedTreeDataset(PreprocessedBaseDataset):
"""Dataset that includes AST information."""
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
return item
def set_seed(seed: int) -> None: def set_seed(seed: int) -> None:
"""Set random seeds for reproducibility.""" """Set random seeds for reproducibility."""
random.seed(seed) random.seed(seed)
@ -128,6 +84,7 @@ def set_seed(seed: int) -> None:
def setup_wandb( def setup_wandb(
config: Dict[str, Any], config: Dict[str, Any],
model_name: str = 'base', model_name: str = 'base',
script_path: Path = None,
device_info: Optional[Dict[str, Any]] = None device_info: Optional[Dict[str, Any]] = None
) -> None: ) -> None:
"""Initialize W&B logging with reproducibility information.""" """Initialize W&B logging with reproducibility information."""
@ -136,17 +93,21 @@ def setup_wandb(
# Add reproducibility settings to config # Add reproducibility settings to config
full_config = { full_config = {
**config, **config,
'random_seed': 42,
'deterministic_mode': True, 'deterministic_mode': True,
'device_info': device_info or get_device_info(), 'device_info': device_info or get_device_info(),
} }
wandb.init( wandb.init(
project='codebert-training-test', project='new-codebert-tree-training',
name=f'{model_name}_{curr_time}', name=f'{model_name}_{curr_time}',
config=full_config config=full_config
) )
# Upload the training script
wandb.save(script_path)
wandb.save(Path(script_path).parent / 'training_utils.py')
logger.info(f'Saving script {script_path} to W&B')
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path: def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
"""Create output directories for model artifacts.""" """Create output directories for model artifacts."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
@ -172,7 +133,6 @@ def setup_device() -> torch.device:
def create_optimizer_and_scheduler( def create_optimizer_and_scheduler(
model: PreTrainedModel, model: PreTrainedModel,
config: Dict[str, Any], config: Dict[str, Any],
num_training_steps: int
) -> Tuple[AdamW, Any]: ) -> Tuple[AdamW, Any]:
"""Create optimizer and learning rate scheduler.""" """Create optimizer and learning rate scheduler."""
optimizer = AdamW( optimizer = AdamW(
@ -181,270 +141,215 @@ def create_optimizer_and_scheduler(
weight_decay=config['weight_decay'] weight_decay=config['weight_decay']
) )
scheduler = get_linear_schedule_with_warmup( scheduler = get_constant_schedule_with_warmup(
optimizer, optimizer,
num_warmup_steps=config['warmup_steps'], num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
) )
return optimizer, scheduler return optimizer, scheduler
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]: def evaluate(
"""Evaluate model on validation set.""" model: RobertaForMaskedLM,
dataloader: DataLoader,
device: torch.device
) -> Tuple[float, float]:
"""Evaluate the model on the validation set.
Returns:
Tuple of (average loss, accuracy on masked tokens)
"""
model.eval() model.eval()
total_loss: float = 0.0 total_loss = 0
total_acc: float = 0.0 total_correct = 0
total_predictions = 0
with torch.no_grad(): with torch.no_grad():
for batch in tqdm(dataloader, desc='Validation'): for batch in tqdm(dataloader, desc='Evaluating'):
batch = {
k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
outputs = model(**batch) outputs = model(**batch)
total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item()
avg_loss: float = total_loss / len(dataloader) # Get loss
avg_acc: float = total_acc / len(dataloader.dataset) total_loss += outputs['loss'].item()
model.train()
return avg_loss, avg_acc # Calculate accuracy only on masked tokens
predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length]
labels = batch['labels'] # [batch_size, seq_length]
# Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
predict_mask = labels != -100 # -100 is the ignore index
# Calculate accuracy
correct = predictions[predict_mask] == labels[predict_mask]
total_correct += correct.sum().item()
total_predictions += predict_mask.sum().item()
avg_loss = total_loss / len(dataloader)
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
return avg_loss, accuracy
def train_and_evaluate( def train_and_evaluate(
model: PreTrainedModel, model: RobertaForMaskedLM,
train_dataloader: DataLoader, train_dataloader: DataLoader,
valid_dataloader: DataLoader, valid_dataloader: DataLoader,
optimizer: AdamW, optimizer: AdamW,
scheduler: Any, scheduler: Any,
config: Dict[str, Any], config: Dict[str, Any],
device: torch.device,
output_dir: Path, output_dir: Path,
log_weights: bool = False,
device: torch.device = torch.device('cpu')
) -> None: ) -> None:
"""Train and evaluate model with deterministic behavior and comprehensive logging.""" """Train and evaluate the model."""
# Enable deterministic algorithms for PyTorch 2.5
torch.use_deterministic_algorithms(True)
num_training_steps: int = config['epochs'] * len(train_dataloader) model.train()
best_valid_loss: float = float('inf') best_valid_loss = float('inf')
# Save initial model state for reproducibility for epoch in range(config['epochs']):
torch.save({ total_loss = 0
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
}, output_dir / 'initial_state.pt')
with tqdm(total=num_training_steps, desc='Training') as pbar: # Training loop
for epoch_idx in range(config['epochs']): with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
model.train() for step, batch in enumerate(pbar):
epoch_loss = 0.0
epoch_steps = 0
for train_idx, train_batch in enumerate(train_dataloader):
# Move batch to device # Move batch to device
train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v batch = {
for k, v in train_batch.items()} k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
# Forward pass # Forward pass
outputs = model(**train_batch) outputs = model(**batch)
train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss'] loss = outputs['loss']
# Backward pass with deterministic behavior # Backward pass
train_loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), # Clip gradients
config['max_grad_norm'] grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
)
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
optimizer.zero_grad(set_to_none=False) # Use False for determinism optimizer.zero_grad()
# Update metrics # Update metrics
current_loss = train_loss.item() total_loss += loss.item()
epoch_loss += current_loss
epoch_steps += 1
# Calculate global step
step = train_idx + len(train_dataloader) * epoch_idx
# Prepare logging dictionary
log_dict = {
'step': step,
'epoch': epoch_idx,
'train_loss': current_loss,
'gradient_norm': grad_norm.item(),
'learning_rate': scheduler.get_last_lr()[0],
}
# Add embedding weights if using tree model
if log_weights and 'embedding_weights' in outputs:
weights = outputs['embedding_weights']
log_dict.update({
'token_weight': weights['token'],
'tree_weight': weights['tree'],
'sequential_weight': weights['sequential']
})
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}",
'α': f"{weights['token']:.2f}",
'β': f"{weights['tree']:.2f}",
'γ': f"{weights['sequential']:.2f}"
}
else:
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}"
}
# Log to wandb
wandb.log(log_dict)
# Update progress bar # Update progress bar
pbar.update(1) pbar.set_postfix({'loss': f'{loss.item():.4f}'})
pbar.set_postfix(pbar_dict)
# Periodic evaluation # Log to wandb if configured
if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1: log_values = {
model.eval() 'train_loss': loss.item(),
valid_loss, valid_acc = evaluate(model, valid_dataloader) 'grad_norm': grad_norm.item(),
model.train() 'learning_rate': scheduler.get_last_lr()[0]
}
if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'tree_weight': outputs['embedding_weights']['tree'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
elif 'embedding_weights' in outputs:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
if 'regularization_metrics' in outputs:
log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
wandb.log(log_values)
# Evaluate periodically
if step % config['eval_every'] == 0 and step > 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader, device)
# Log validation metrics # Log validation metrics
wandb.log({ wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': step,
'epoch': epoch_idx,
})
# Save checkpoint if best model # Print validation metrics
print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
# Save best model
if valid_loss < best_valid_loss: if valid_loss < best_valid_loss:
best_valid_loss = valid_loss best_valid_loss = valid_loss
model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')
# Save complete state model.train() # Resume training mode
torch.save({
'epoch': epoch_idx,
'step': step,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'best_valid_loss': best_valid_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / 'best_model.pt')
# Log best metrics
wandb.run.summary['best_valid_loss'] = best_valid_loss
wandb.run.summary['best_valid_acc'] = valid_acc
wandb.run.summary['best_epoch'] = epoch_idx
wandb.run.summary['best_step'] = step
# End of epoch logging
avg_epoch_loss = epoch_loss / epoch_steps
wandb.log({
'epoch': epoch_idx,
'epoch_avg_loss': avg_epoch_loss,
})
# Save end of epoch checkpoint
torch.save({
'epoch': epoch_idx,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'train_loss': avg_epoch_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')
# End of training logging
wandb.run.summary['final_epoch'] = config['epochs'] - 1
wandb.run.summary['total_steps'] = num_training_steps
logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
logger.info(f'Model checkpoints saved in {output_dir}')
def load_all_chunks(chunks_dir: Path) -> Any:
"""Load and concatenate all dataset chunks."""
logger.info(f"Loading dataset chunks from {chunks_dir}")
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
chunks.append(load_from_disk(str(chunk_path)))
dataset = concatenate_datasets(chunks)
logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
return dataset
def create_base_dataloaders( def create_base_dataloaders(
processed_data_dir: Path, processed_data_dir: Path,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
config: Dict[str, Any], config: Dict[str, Any],
device: torch.device, base_training=False,
dataset_class: Type[Dataset] = PreprocessedBaseDataset,
) -> Tuple[DataLoader, DataLoader]: ) -> Tuple[DataLoader, DataLoader]:
"""Create reproducible dataloaders from pre-processed data.""" """Create dataloaders from pre-processed parquet data."""
# Load chunks
chunks_dir = processed_data_dir / 'chunks' # Load chunks using datasets library
full_dataset = load_all_chunks(chunks_dir) chunks_dir = processed_data_dir / 'data'
dataset = load_dataset(
"parquet",
data_files=str(chunks_dir / "*.parquet"),
split="train"
)
contents = dataset['content'][:config['batch_size']]
with open("contents.txt", "w") as f:
for content in contents:
f.write(content)
f.write("\n\n\n")
f.write("#" * 80)
f.write("\n\n\n")
# Remove columns that are not needed
columns_to_remove = dataset.column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if not base_training:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove)
logging.info(f"Loaded dataset:\n{dataset}")
# Calculate split sizes # Calculate split sizes
dataset_size = len(full_dataset) dataset_size = len(dataset)
train_size = int(config['train_size'] * dataset_size) train_size = int(config['train_size'] * dataset_size)
val_size = dataset_size - train_size val_size = dataset_size - train_size
# Create splits using indices to avoid device issues # Create splits
indices = torch.arange(dataset_size, device=device) splits = dataset.train_test_split(
train_size=train_size,
test_size=val_size,
seed=config['seed']
)
train_dataset = splits['train']
valid_dataset = splits['test']
# Create a deterministic generator on the correct device # Create data collator for MLM
generator = torch.Generator(device=device) data_collator = DataCollatorForLanguageModeling(
generator.manual_seed(42)
# Shuffle indices deterministically
shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
train_indices = shuffled_indices[:train_size].cpu()
valid_indices = shuffled_indices[train_size:].cpu()
# Create train/validation splits using the shuffled indices
train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())
# Create dataset wrappers
train_dataset = dataset_class(train_dataset)
valid_dataset = dataset_class(valid_dataset)
logger.info(f"Created train dataset with {len(train_dataset)} samples")
logger.info(f"Created validation dataset with {len(valid_dataset)} samples")
# Use TreeDataCollator for both base and tree models
data_collator = TreeDataCollator(
tokenizer=tokenizer, tokenizer=tokenizer,
mlm=True, mlm=True,
mlm_probability=config['mlm_probability'] mlm_probability=config['mlm_probability']
) )
# Create train dataloader without generator
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset,
batch_size=config['batch'], batch_size=config['batch_size'],
shuffle=False, # We already shuffled the data shuffle=False, # We'll let the dataset handle shuffling
collate_fn=data_collator, collate_fn=data_collator,
num_workers=0, # Use single worker for reproducibility num_workers=0,
drop_last=True # Ensure consistent batch sizes drop_last=True,
) )
# Create validation dataloader
valid_dataloader = DataLoader( valid_dataloader = DataLoader(
valid_dataset, valid_dataset,
batch_size=config['batch'], batch_size=config['batch_size'],
shuffle=False, shuffle=False,
collate_fn=data_collator, collate_fn=data_collator,
num_workers=0, num_workers=0,
drop_last=True drop_last=True,
) )
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader