on this commit there is a weird loss spike (unresolved)

This commit is contained in:
Patryk Bartkowiak 2024-12-04 16:25:44 +00:00
parent 0cd04e6131
commit b2d65059da
11 changed files with 5211 additions and 723 deletions

View File

@ -7,15 +7,18 @@ authors = [
]
dependencies = [
"wandb==0.18.5",
"torch==2.5.0",
"torch==2.5.1",
"tqdm==4.66.5",
"tree-sitter==0.23.1",
"transformers==4.45.2",
"transformers[torch]>=4.46.3",
"datasets==3.0.1",
"huggingface-hub==0.26.0",
"matplotlib==3.9.2",
"scikit-learn==1.5.2",
"seaborn==0.13.2",
"tree-sitter-python==0.23.4",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
]
requires-python = "==3.11.*"
readme = "README.md"
@ -35,6 +38,5 @@ build-backend = "pdm.backend"
distribution = true
[tool.pdm.scripts]
run_training = {cmd = "src/train_codebert_mlm.py"}
run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
parse_dataset = {cmd = "src/parse_dataset.py"}
run_training = {cmd = "src/training.py"}

View File

@ -1,12 +1,13 @@
{
"seed": 42,
"seed": 420,
"mlm_probability": 0.15,
"batch": 16,
"epochs": 1,
"eval_every": 20000,
"batch_size": 32,
"epochs": 5,
"eval_every": 1000,
"learning_rate": 5e-4,
"weight_decay": 0.01,
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 20000,
"train_size": 0.95
"warmup_steps": 1000,
"train_size": 0.95,
"fp16": true
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,22 @@
import logging
from pathlib import Path
from typing import List
from huggingface_hub import list_repo_files, hf_hub_download
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def download_dataset(dataset_dir: Path) -> None:
if not dataset_dir.exists():
logger.info("Downloading the dataset...")
dataset_dir.mkdir(parents=True, exist_ok=True)
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
for file_name in files_to_download:
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
logger.info("Dataset downloaded successfully.")
if __name__ == '__main__':
dataset_dir = Path('data/the-stack-python')
download_dataset(dataset_dir)

View File

@ -1,128 +1,123 @@
import ast
from pathlib import Path
import logging
import multiprocessing
import tree_sitter_python as tspython
from tqdm import tqdm
from pathlib import Path
from pprint import pprint
from typing import Dict, Any, List
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import RobertaTokenizer
from dataclasses import dataclass
import numpy as np
import multiprocessing
from datasets import load_dataset, Dataset
from tree_sitter import Language, Parser
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
@dataclass
class ASTNodeInfo:
"""Stores structural information about an AST node."""
node_type: str
start_token_idx: int
end_token_idx: int
depth: int
sibling_pos: int
parent_idx: int
def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer):
"""
Map tokens to their original text and Tree-sitter AST position
"""
# Initialize Tree-sitter
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)
def to_dict(self) -> Dict[str, Any]:
return {
'node_type': self.node_type,
'start_token_idx': self.start_token_idx,
'end_token_idx': self.end_token_idx,
'depth': self.depth,
'sibling_pos': self.sibling_pos,
'parent_idx': self.parent_idx
}
# Parse with Tree-sitter
tree = parser.parse(bytes(code_snippet, "utf8"))
def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo':
return ASTNodeInfo(
node_type=data['node_type'],
start_token_idx=data['start_token_idx'],
end_token_idx=data['end_token_idx'],
depth=data['depth'],
sibling_pos=data['sibling_pos'],
parent_idx=data['parent_idx']
encoded = tokenizer(
code_snippet,
add_special_tokens=True,
return_offsets_mapping=True,
return_tensors='pt',
padding='max_length',
truncation=True,
max_length=512,
)
def safe_ast_parse(content: str, max_length: int = 50000) -> bool:
"""Safely attempt to parse Python code."""
if not content or not content.strip() or '\0' in content or len(content) > max_length:
return False
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
offset_mapping = encoded['offset_mapping'][0].tolist()
try:
tree = ast.parse(content)
return bool(list(ast.walk(tree)))
except:
return False
def get_node_position(node, depth=0, idx=0):
"""Get depth and sibling index for a node"""
return (depth, idx)
def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]:
"""Process AST and extract node information."""
try:
tree = ast.parse(content)
nodes_info = []
def find_node_at_position(node, start_byte, depth=0, sibling_idx=0):
"""Find the most specific node containing the given position"""
if start_byte >= node.start_byte and start_byte < node.end_byte:
# Check children first for more specific nodes
for idx, child in enumerate(node.children):
result = find_node_at_position(child, start_byte, depth + 1, idx)
if result:
return result
# Return current node's position if no child contains the position
return get_node_position(node, depth, sibling_idx)
return None
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
if len(nodes_info) >= max_ast_size:
return
depths = []
sibling_idxs = []
node_texts = []
for token, (start, end) in zip(tokens, offset_mapping):
if token not in ['<s>', '</s>', '<pad>', '<unk>']:
text = code_snippet[start:end] if start < len(code_snippet) else ""
# Tree-sitter works with bytes, so convert position
start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0
position = find_node_at_position(tree.root_node, start_byte) or (-1, -1)
current_idx = len(nodes_info)
if hasattr(node, 'lineno'):
node_info = ASTNodeInfo(
node_type=type(node).__name__,
start_token_idx=node.lineno,
end_token_idx=getattr(node, 'end_lineno', node.lineno),
depth=min(depth, 31),
sibling_pos=min(sibling_pos, 31),
parent_idx=parent_idx
)
nodes_info.append(node_info.to_dict())
depths.append(position[0])
sibling_idxs.append(position[1])
node_texts.append(text)
else:
depths.append(-1)
sibling_idxs.append(-1)
node_texts.append(None)
if len(nodes_info) < max_ast_size:
for i, child in enumerate(ast.iter_child_nodes(node)):
visit_node(child, depth + 1, current_idx, i)
return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
visit_node(tree, 0, -1, 0)
return nodes_info
except:
return []
def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
"""Process a batch of examples."""
contents = examples['content']
processed_contents = []
processed_ast_nodes = []
processed_input_ids = []
processed_attention_masks = []
processed_attention_mask = []
processed_tokens = []
processed_depths = []
processed_sibling_idxs = []
processed_node_texts = []
for content in contents:
if safe_ast_parse(content):
ast_nodes = process_and_extract_ast_info(content)
if ast_nodes:
try:
encoding = tokenizer(
content,
max_length=512,
truncation=True,
padding='max_length',
return_tensors='pt'
)
input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer)
processed_input_ids.append(input_ids[0])
processed_attention_mask.append(attention_mask[0])
processed_tokens.append(tokens)
processed_depths.append(depths)
processed_sibling_idxs.append(sibling_idxs)
processed_node_texts.append(node_texts)
processed_contents.append(content)
processed_ast_nodes.append(ast_nodes)
processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist())
processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist())
except:
continue
except Exception as e:
logger.error(f"Error processing example: {e}")
# Return empty lists that will be filtered out later
processed_input_ids.append([])
processed_attention_mask.append([])
processed_tokens.append([])
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
return {
'content': processed_contents,
'ast_nodes': processed_ast_nodes,
'input_ids': processed_input_ids,
'attention_mask': processed_attention_masks,
'attention_mask': processed_attention_mask,
'tokens': processed_tokens,
'depths': processed_depths,
'sibling_idxs': processed_sibling_idxs,
'node_texts': processed_node_texts,
}
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
@ -131,14 +126,9 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
# Create directory for chunks
output_dir = Path(output_path).parent
chunks_dir = output_dir / 'chunks'
chunks_dir = output_dir / 'data'
chunks_dir.mkdir(exist_ok=True)
stats_accumulator = {
'total_ast_nodes': 0,
'total_samples': 0
}
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
for i in pbar:
start_idx = i * chunk_size
@ -147,28 +137,10 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
# Select chunk from dataset
chunk_dataset = dataset.select(range(start_idx, end_idx))
# Update statistics
stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes'])
stats_accumulator['total_samples'] += len(chunk_dataset)
if len(chunk_dataset) > 0: # Only process if chunk has data
# Save chunk using datasets native method
chunk_path = chunks_dir / f'chunk_{i:04d}'
chunk_dataset.save_to_disk(str(chunk_path))
# Update progress bar postfix with current chunk info
pbar.set_postfix({
'samples': len(chunk_dataset),
'path': str(chunk_path.name)
})
return stats_accumulator
def load_all_chunks(chunks_dir: Path) -> Dataset:
"""Load and concatenate all dataset chunks."""
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*')):
chunks.append(load_from_disk(str(chunk_path)))
return concatenate_datasets(chunks)
chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
chunk_dataset.to_parquet(str(chunk_path))
def main():
current_dir = Path(__file__).parent
@ -176,14 +148,23 @@ def main():
output_dir = current_dir.parent / 'data' / 'processed-python'
output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['train']
original_dataset_size = len(dataset)
# dataset = dataset.select(range(200_000)) # Limit dataset size for testing
logger.info("Dataset:")
pprint(dataset)
original_dataset_size = len(dataset)
logger.info(f"Original dataset size: {original_dataset_size}")
columns_to_remove = dataset.column_names
columns_to_remove.remove('content')
columns_to_remove.remove('size')
columns_to_remove.remove('avg_line_length')
columns_to_remove.remove('max_line_length')
columns_to_remove.remove('alphanum_fraction')
columns_to_remove.remove('max_stars_count')
logging.info(f"Columns to remove: {columns_to_remove}")
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
@ -193,35 +174,38 @@ def main():
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
remove_columns=dataset.column_names,
remove_columns=columns_to_remove,
desc="Processing dataset",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']],
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
reduced_dataset_size = len(processed_dataset)
logger.info(f"Processed dataset size: {reduced_dataset_size}")
logger.info("Processed dataset:")
pprint(processed_dataset)
if reduced_dataset_size == 0:
logger.error("No valid examples found in dataset!")
return
logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...")
stats_accumulator = save_dataset_in_chunks(
logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
save_dataset_in_chunks(
processed_dataset,
str(output_dir / 'processed_dataset'),
chunk_size=100_000
)
# Add derived statistics
stats = {
'original_dataset_size': original_dataset_size,
'reduced_dataset_size': reduced_dataset_size,
'% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100,
'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
}
# Log stats with pprint

View File

@ -0,0 +1,50 @@
import logging
import multiprocessing
from pathlib import Path
from datasets import load_dataset
from parse_dataset import save_dataset_in_chunks
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def main():
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'processed-python'
output_dir = current_dir.parent / 'data' / 'filtered-python'
output_dir.mkdir(parents=True, exist_ok=True)
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['data']
logging.info(f"Dataset:\n{dataset}")
logger.info("Filtering dataset by max_stars_count > 3...")
filtered_dataset = dataset.filter(
lambda batch: [example and example > 3 for example in batch['max_stars_count']],
num_proc=num_proc,
batched=True,
desc="Filtering dataset"
)
filtered_dataset_size = len(filtered_dataset)
logger.info(f"Filtered dataset size: {filtered_dataset_size}")
if filtered_dataset_size == 0:
logger.error("No examples found with max_stars_count > 3!")
return
logger.info(f"Saving {filtered_dataset_size} filtered examples...")
save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)
logger.info("Filtering and saving completed!")
if __name__ == "__main__":
main()

View File

@ -1,12 +0,0 @@
{
"seed": 42,
"mlm_probability": 0.15,
"batch": 16,
"epochs": 1,
"eval_every": 5000,
"learning_rate": 5e-4,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"warmup_steps": 5000,
"train_size": 0.95
}

View File

@ -1,50 +0,0 @@
from pathlib import Path
import torch
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, set_deterministic_mode
)
def main() -> None:
set_deterministic_mode()
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name='base')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config, model_name='base')
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create dataloaders
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device
)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = RobertaForMaskedLM(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=False, device=device
)
if __name__ == "__main__":
main()

View File

@ -1,227 +0,0 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
set_deterministic_mode
)
from parse_dataset import ASTNodeInfo
class TreePositionalEmbedding(nn.Module):
"""Generates tree-aware positional embeddings for code tokens."""
def __init__(self, d_model: int = 768, max_depth: int = 32):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.combine = nn.Linear(d_model * 2, d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
position = torch.arange(self.max_depth).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
(-math.log(10000.0) / self.d_model))
pe = torch.zeros(self.max_depth, self.d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
with torch.no_grad():
self.depth_embedding.weight.copy_(pe)
self.sibling_embedding.weight.copy_(pe)
def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
"""Process batched input with corresponding AST nodes."""
batch_size, seq_len = input_ids.shape
device = input_ids.device
embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)
# Process each item in the batch
for batch_idx in range(batch_size):
ast_nodes = ast_nodes_batch[batch_idx]
# Process each position in the sequence
for i in range(seq_len):
containing_nodes = [
node for node in ast_nodes
if node.start_token_idx <= i < node.end_token_idx
]
if containing_nodes:
node = max(containing_nodes, key=lambda n: n.depth)
depth = min(node.depth, self.max_depth - 1)
sibling_pos = min(node.sibling_pos, self.max_depth - 1)
depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
return embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize weights with small random values around 0
self.alpha = nn.Parameter(torch.randn(1) * 0.02)
self.beta = nn.Parameter(torch.randn(1) * 0.02)
self.gamma = nn.Parameter(torch.randn(1) * 0.02)
self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
# Add dropout for regularization
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
def get_normalized_weights(self) -> torch.Tensor:
"""
Compute softmax-normalized weights for embedding combination.
Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
"""
weights = torch.stack([self.alpha, self.beta, self.gamma])
return torch.softmax(weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
# Move tensors to device
device = input_ids.device
# Get embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get normalized weights
norm_weights = self.get_normalized_weights()
# Combine embeddings based on presence of AST nodes
if ast_nodes is not None:
tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
combined_embeddings = (
norm_weights[0] * token_embeddings +
norm_weights[1] * tree_embeddings +
norm_weights[2] * seq_embeddings
)
else:
# Redistribute tree weight to other components when no AST available
token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
combined_embeddings = (
token_seq_weights[0] * token_embeddings +
token_seq_weights[1] * seq_embeddings
)
# Apply layer normalization and dropout
combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
combined_embeddings = self.embedding_dropout(combined_embeddings)
combined_embeddings = self.final_layer_norm(combined_embeddings)
# Forward pass through transformer
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate loss if labels provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Get normalized weights for logging
norm_weights_cpu = norm_weights.detach().cpu()
return {
"loss": masked_lm_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions,
"embedding_weights": {
"token": norm_weights_cpu[0].item(),
"tree": norm_weights_cpu[1].item(),
"sequential": norm_weights_cpu[2].item()
}
}
def main() -> None:
set_deterministic_mode()
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name='tree')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config, model_name='tree')
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create dataloaders with tree dataset
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device,
dataset_class=PreprocessedTreeDataset
)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = TreeCodeBERTForPreTraining(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=True, device=device
)
if __name__ == "__main__":
main()

306
code/src/training.py Normal file
View File

@ -0,0 +1,306 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Optional
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM
from training_utils import *
MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom'
DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python'
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
super().__init__(config)
self.base_custom = base_custom
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
if base_custom:
initial_weight = math.log(1/2)
self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
else:
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None and not self.base_custom:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
if self.base_custom:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * seq_embeddings
)
else:
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
if self.base_custom:
embedding_weights = {
"token": weights_cpu[0].item(),
"sequential": weights_cpu[1].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
else:
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
"embedding_weights": embedding_weights,
"regularization_metrics": reg_metrics
}
def main() -> None:
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name=MODEL)
config = load_config(current_dir / 'config.json')
set_deterministic_mode(config['seed'])
setup_wandb(config, model_name=MODEL, script_path=__file__)
set_seed(config['seed'])
device = setup_device()
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
# Create dataloaders with tree dataset
logger.info("Creating dataloaders...")
processed_data_dir = current_dir.parent / 'data' / DATASET
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
base_training=True if MODEL == 'base' else False
)
# Initialize model config
logger.info("Initializing model...")
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
# Update W&B
wandb.config.update({'model_config': model_config.__dict__})
# Initialize model
if MODEL == 'tree':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
elif MODEL == 'base':
model = RobertaForMaskedLM(model_config)
elif MODEL == 'base-custom':
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
else:
raise ValueError(f"Invalid model type: {MODEL}")
model = model.to(device)
logging.info(f"Model initialized: {MODEL}")
# Model precision
if device.type == 'cuda' and config['fp16']:
model = model.half()
logging.info("Model set to half precision.")
for param in model.parameters():
logging.info(f"Parameter dtype: {param.dtype}")
break
# Training setup
logger.info("Setting up optimizer and scheduler...")
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
# Train and evaluate model
logger.info("Starting training...")
train_and_evaluate(
model=model,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
optimizer=optimizer,
scheduler=scheduler,
config=config,
device=device,
output_dir=output_dir
)
logger.info("Training completed!")
final_output_dir = output_dir / 'final-model'
model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
logger.info(f"Model saved to {final_output_dir}")
if __name__ == "__main__":
main()

View File

@ -8,38 +8,41 @@ import wandb
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any, Tuple, Type, Optional
from datasets import load_from_disk, concatenate_datasets
from typing import Dict, Any, Tuple, Optional
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader
from transformers import (
PreTrainedModel,
get_linear_schedule_with_warmup,
get_constant_schedule_with_warmup,
DataCollatorForLanguageModeling,
PreTrainedTokenizer
PreTrainedTokenizer,
RobertaForMaskedLM
)
from tqdm import tqdm
from parse_dataset import ASTNodeInfo
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def set_deterministic_mode() -> None:
def set_deterministic_mode(seed: int) -> None:
"""Enable deterministic mode for reproducibility."""
# Set Python random seed
random.seed(42)
random.seed(seed)
# Set NumPy random seed
np.random.seed(42)
np.random.seed(seed)
# Set PyTorch random seed
torch.manual_seed(42)
torch.manual_seed(seed)
# Set CUDA random seed
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
torch.cuda.manual_seed_all(seed)
# Enable deterministic operations for PyTorch 2.x
torch.use_deterministic_algorithms(True)
@ -47,7 +50,7 @@ def set_deterministic_mode() -> None:
torch.backends.cudnn.benchmark = False
# Set environment variables for reproducibility
os.environ['PYTHONHASHSEED'] = '42'
os.environ['PYTHONHASHSEED'] = f'{seed}'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
def get_device_info() -> Dict[str, Any]:
@ -70,53 +73,6 @@ def get_device_info() -> Dict[str, Any]:
return info
class TreeDataCollator(DataCollatorForLanguageModeling):
"""Custom data collator that handles both MLM and optional AST node information."""
def torch_call(self, examples):
# Check if we have AST nodes
has_ast = 'ast_nodes' in examples[0]
# Extract AST nodes before MLM processing if they exist
ast_nodes = None
if has_ast:
ast_nodes = [e.pop('ast_nodes') for e in examples]
# Process normal MLM features
batch = super().torch_call(examples)
# Add AST nodes back to batch if they existed
if has_ast:
batch['ast_nodes'] = ast_nodes
return batch
def __call__(self, examples):
return self.torch_call(examples)
class PreprocessedBaseDataset(Dataset):
"""Dataset that uses pre-processed data without AST information."""
def __init__(self, dataset: Any):
self.dataset = dataset
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = self.dataset[idx]
return {
'input_ids': torch.tensor(item['input_ids']),
'attention_mask': torch.tensor(item['attention_mask']),
'labels': torch.tensor(item['input_ids']).clone()
}
class PreprocessedTreeDataset(PreprocessedBaseDataset):
"""Dataset that includes AST information."""
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
return item
def set_seed(seed: int) -> None:
"""Set random seeds for reproducibility."""
random.seed(seed)
@ -128,6 +84,7 @@ def set_seed(seed: int) -> None:
def setup_wandb(
config: Dict[str, Any],
model_name: str = 'base',
script_path: Path = None,
device_info: Optional[Dict[str, Any]] = None
) -> None:
"""Initialize W&B logging with reproducibility information."""
@ -136,17 +93,21 @@ def setup_wandb(
# Add reproducibility settings to config
full_config = {
**config,
'random_seed': 42,
'deterministic_mode': True,
'device_info': device_info or get_device_info(),
}
wandb.init(
project='codebert-training-test',
project='new-codebert-tree-training',
name=f'{model_name}_{curr_time}',
config=full_config
)
# Upload the training script
wandb.save(script_path)
wandb.save(Path(script_path).parent / 'training_utils.py')
logger.info(f'Saving script {script_path} to W&B')
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
"""Create output directories for model artifacts."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
@ -172,7 +133,6 @@ def setup_device() -> torch.device:
def create_optimizer_and_scheduler(
model: PreTrainedModel,
config: Dict[str, Any],
num_training_steps: int
) -> Tuple[AdamW, Any]:
"""Create optimizer and learning rate scheduler."""
optimizer = AdamW(
@ -181,270 +141,215 @@ def create_optimizer_and_scheduler(
weight_decay=config['weight_decay']
)
scheduler = get_linear_schedule_with_warmup(
scheduler = get_constant_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
)
return optimizer, scheduler
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
"""Evaluate model on validation set."""
def evaluate(
model: RobertaForMaskedLM,
dataloader: DataLoader,
device: torch.device
) -> Tuple[float, float]:
"""Evaluate the model on the validation set.
Returns:
Tuple of (average loss, accuracy on masked tokens)
"""
model.eval()
total_loss: float = 0.0
total_acc: float = 0.0
total_loss = 0
total_correct = 0
total_predictions = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc='Validation'):
for batch in tqdm(dataloader, desc='Evaluating'):
batch = {
k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
outputs = model(**batch)
total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item()
avg_loss: float = total_loss / len(dataloader)
avg_acc: float = total_acc / len(dataloader.dataset)
model.train()
return avg_loss, avg_acc
# Get loss
total_loss += outputs['loss'].item()
# Calculate accuracy only on masked tokens
predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length]
labels = batch['labels'] # [batch_size, seq_length]
# Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
predict_mask = labels != -100 # -100 is the ignore index
# Calculate accuracy
correct = predictions[predict_mask] == labels[predict_mask]
total_correct += correct.sum().item()
total_predictions += predict_mask.sum().item()
avg_loss = total_loss / len(dataloader)
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
return avg_loss, accuracy
def train_and_evaluate(
model: PreTrainedModel,
model: RobertaForMaskedLM,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
device: torch.device,
output_dir: Path,
log_weights: bool = False,
device: torch.device = torch.device('cpu')
) -> None:
"""Train and evaluate model with deterministic behavior and comprehensive logging."""
# Enable deterministic algorithms for PyTorch 2.5
torch.use_deterministic_algorithms(True)
"""Train and evaluate the model."""
num_training_steps: int = config['epochs'] * len(train_dataloader)
best_valid_loss: float = float('inf')
# Save initial model state for reproducibility
torch.save({
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
}, output_dir / 'initial_state.pt')
with tqdm(total=num_training_steps, desc='Training') as pbar:
for epoch_idx in range(config['epochs']):
model.train()
epoch_loss = 0.0
epoch_steps = 0
best_valid_loss = float('inf')
for train_idx, train_batch in enumerate(train_dataloader):
for epoch in range(config['epochs']):
total_loss = 0
# Training loop
with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
for step, batch in enumerate(pbar):
# Move batch to device
train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in train_batch.items()}
batch = {
k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
# Forward pass
outputs = model(**train_batch)
train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss']
outputs = model(**batch)
loss = outputs['loss']
# Backward pass with deterministic behavior
train_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
config['max_grad_norm']
)
# Backward pass
loss.backward()
# Clip gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=False) # Use False for determinism
optimizer.zero_grad()
# Update metrics
current_loss = train_loss.item()
epoch_loss += current_loss
epoch_steps += 1
# Calculate global step
step = train_idx + len(train_dataloader) * epoch_idx
# Prepare logging dictionary
log_dict = {
'step': step,
'epoch': epoch_idx,
'train_loss': current_loss,
'gradient_norm': grad_norm.item(),
'learning_rate': scheduler.get_last_lr()[0],
}
# Add embedding weights if using tree model
if log_weights and 'embedding_weights' in outputs:
weights = outputs['embedding_weights']
log_dict.update({
'token_weight': weights['token'],
'tree_weight': weights['tree'],
'sequential_weight': weights['sequential']
})
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}",
'α': f"{weights['token']:.2f}",
'β': f"{weights['tree']:.2f}",
'γ': f"{weights['sequential']:.2f}"
}
else:
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}"
}
# Log to wandb
wandb.log(log_dict)
total_loss += loss.item()
# Update progress bar
pbar.update(1)
pbar.set_postfix(pbar_dict)
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
# Periodic evaluation
if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1:
model.eval()
valid_loss, valid_acc = evaluate(model, valid_dataloader)
model.train()
# Log to wandb if configured
log_values = {
'train_loss': loss.item(),
'grad_norm': grad_norm.item(),
'learning_rate': scheduler.get_last_lr()[0]
}
if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'tree_weight': outputs['embedding_weights']['tree'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
elif 'embedding_weights' in outputs:
log_values.update({
'token_weight': outputs['embedding_weights']['token'],
'seq_weight': outputs['embedding_weights']['sequential'],
})
if 'regularization_metrics' in outputs:
log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
wandb.log(log_values)
# Evaluate periodically
if step % config['eval_every'] == 0 and step > 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader, device)
# Log validation metrics
wandb.log({
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': step,
'epoch': epoch_idx,
})
wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})
# Save checkpoint if best model
# Print validation metrics
print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
# Save best model
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')
# Save complete state
torch.save({
'epoch': epoch_idx,
'step': step,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'best_valid_loss': best_valid_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / 'best_model.pt')
# Log best metrics
wandb.run.summary['best_valid_loss'] = best_valid_loss
wandb.run.summary['best_valid_acc'] = valid_acc
wandb.run.summary['best_epoch'] = epoch_idx
wandb.run.summary['best_step'] = step
# End of epoch logging
avg_epoch_loss = epoch_loss / epoch_steps
wandb.log({
'epoch': epoch_idx,
'epoch_avg_loss': avg_epoch_loss,
})
# Save end of epoch checkpoint
torch.save({
'epoch': epoch_idx,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'train_loss': avg_epoch_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')
# End of training logging
wandb.run.summary['final_epoch'] = config['epochs'] - 1
wandb.run.summary['total_steps'] = num_training_steps
logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
logger.info(f'Model checkpoints saved in {output_dir}')
def load_all_chunks(chunks_dir: Path) -> Any:
"""Load and concatenate all dataset chunks."""
logger.info(f"Loading dataset chunks from {chunks_dir}")
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
chunks.append(load_from_disk(str(chunk_path)))
dataset = concatenate_datasets(chunks)
logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
return dataset
model.train() # Resume training mode
def create_base_dataloaders(
processed_data_dir: Path,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
device: torch.device,
dataset_class: Type[Dataset] = PreprocessedBaseDataset,
base_training=False,
) -> Tuple[DataLoader, DataLoader]:
"""Create reproducible dataloaders from pre-processed data."""
# Load chunks
chunks_dir = processed_data_dir / 'chunks'
full_dataset = load_all_chunks(chunks_dir)
"""Create dataloaders from pre-processed parquet data."""
# Load chunks using datasets library
chunks_dir = processed_data_dir / 'data'
dataset = load_dataset(
"parquet",
data_files=str(chunks_dir / "*.parquet"),
split="train"
)
contents = dataset['content'][:config['batch_size']]
with open("contents.txt", "w") as f:
for content in contents:
f.write(content)
f.write("\n\n\n")
f.write("#" * 80)
f.write("\n\n\n")
# Remove columns that are not needed
columns_to_remove = dataset.column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if not base_training:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove)
logging.info(f"Loaded dataset:\n{dataset}")
# Calculate split sizes
dataset_size = len(full_dataset)
dataset_size = len(dataset)
train_size = int(config['train_size'] * dataset_size)
val_size = dataset_size - train_size
# Create splits using indices to avoid device issues
indices = torch.arange(dataset_size, device=device)
# Create splits
splits = dataset.train_test_split(
train_size=train_size,
test_size=val_size,
seed=config['seed']
)
train_dataset = splits['train']
valid_dataset = splits['test']
# Create a deterministic generator on the correct device
generator = torch.Generator(device=device)
generator.manual_seed(42)
# Shuffle indices deterministically
shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
train_indices = shuffled_indices[:train_size].cpu()
valid_indices = shuffled_indices[train_size:].cpu()
# Create train/validation splits using the shuffled indices
train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())
# Create dataset wrappers
train_dataset = dataset_class(train_dataset)
valid_dataset = dataset_class(valid_dataset)
logger.info(f"Created train dataset with {len(train_dataset)} samples")
logger.info(f"Created validation dataset with {len(valid_dataset)} samples")
# Use TreeDataCollator for both base and tree models
data_collator = TreeDataCollator(
# Create data collator for MLM
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['mlm_probability']
)
# Create train dataloader without generator
train_dataloader = DataLoader(
train_dataset,
batch_size=config['batch'],
shuffle=False, # We already shuffled the data
batch_size=config['batch_size'],
shuffle=False, # We'll let the dataset handle shuffling
collate_fn=data_collator,
num_workers=0, # Use single worker for reproducibility
drop_last=True # Ensure consistent batch sizes
num_workers=0,
drop_last=True,
)
# Create validation dataloader
valid_dataloader = DataLoader(
valid_dataset,
batch_size=config['batch'],
batch_size=config['batch_size'],
shuffle=False,
collate_fn=data_collator,
num_workers=0,
drop_last=True
drop_last=True,
)
return train_dataloader, valid_dataloader