version that works on CodeSearchNet
This commit is contained in:
parent
b2d65059da
commit
1d8ce0c72f
@ -39,4 +39,4 @@ distribution = true
|
||||
|
||||
[tool.pdm.scripts]
|
||||
parse_dataset = {cmd = "src/parse_dataset.py"}
|
||||
run_training = {cmd = "src/training.py"}
|
||||
train = {cmd = "src/training.py"}
|
||||
|
@ -1,4 +1,8 @@
|
||||
{
|
||||
"extra_embeddings": true,
|
||||
"run_name": "tree",
|
||||
"data_dir": "./data/CodeSearchNet-parsed/python/",
|
||||
"output_dir": "./outputs/tree",
|
||||
"seed": 420,
|
||||
"mlm_probability": 0.15,
|
||||
"batch_size": 32,
|
||||
@ -8,6 +12,6 @@
|
||||
"weight_decay": 0.1,
|
||||
"max_grad_norm": 1.0,
|
||||
"warmup_steps": 1000,
|
||||
"train_size": 0.95,
|
||||
"fp16": true
|
||||
"fp16": true,
|
||||
"logging_steps": 100
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,22 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from huggingface_hub import list_repo_files, hf_hub_download
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def download_dataset(dataset_dir: Path) -> None:
|
||||
if not dataset_dir.exists():
|
||||
logger.info("Downloading the dataset...")
|
||||
dataset_dir.mkdir(parents=True, exist_ok=True)
|
||||
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
|
||||
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
|
||||
for file_name in files_to_download:
|
||||
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
|
||||
logger.info("Dataset downloaded successfully.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset_dir = Path('data/the-stack-python')
|
||||
download_dataset(dataset_dir)
|
@ -1,11 +1,9 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import tree_sitter_python as tspython
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Dict, Any, List
|
||||
from datasets import load_dataset, Dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from tree_sitter import Language, Parser
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@ -82,7 +80,7 @@ def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: Auto
|
||||
|
||||
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
|
||||
"""Process a batch of examples."""
|
||||
contents = examples['content']
|
||||
contents = examples['code']
|
||||
|
||||
processed_input_ids = []
|
||||
processed_attention_mask = []
|
||||
@ -120,97 +118,70 @@ def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
|
||||
'node_texts': processed_node_texts,
|
||||
}
|
||||
|
||||
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
|
||||
"""Save dataset to disk in chunks to manage memory usage."""
|
||||
num_chunks = (len(dataset) + chunk_size - 1) // chunk_size
|
||||
|
||||
# Create directory for chunks
|
||||
output_dir = Path(output_path).parent
|
||||
chunks_dir = output_dir / 'data'
|
||||
chunks_dir.mkdir(exist_ok=True)
|
||||
|
||||
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
|
||||
for i in pbar:
|
||||
start_idx = i * chunk_size
|
||||
end_idx = min((i + 1) * chunk_size, len(dataset))
|
||||
|
||||
# Select chunk from dataset
|
||||
chunk_dataset = dataset.select(range(start_idx, end_idx))
|
||||
|
||||
if len(chunk_dataset) > 0: # Only process if chunk has data
|
||||
# Save chunk using datasets native method
|
||||
chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
|
||||
chunk_dataset.to_parquet(str(chunk_path))
|
||||
|
||||
def main():
|
||||
current_dir = Path(__file__).parent
|
||||
input_dir = current_dir.parent / 'data' / 'the-stack-python'
|
||||
output_dir = current_dir.parent / 'data' / 'processed-python'
|
||||
input_dir = current_dir.parent / 'data' / 'CodeSearchNet' / 'python'
|
||||
output_dir = current_dir.parent / 'data' / 'CodeSearchNet-parsed' / 'python'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
|
||||
|
||||
logger.info("Loading dataset...")
|
||||
dataset = load_dataset(str(input_dir))['train']
|
||||
# dataset = dataset.select(range(200_000)) # Limit dataset size for testing
|
||||
|
||||
original_dataset_size = len(dataset)
|
||||
logger.info(f"Original dataset size: {original_dataset_size}")
|
||||
# Initialize tokenizer and model from scratch
|
||||
logger.info("Initializing tokenizer and model...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
|
||||
logger.info("Loaded CodeBERT tokenizer")
|
||||
|
||||
columns_to_remove = dataset.column_names
|
||||
columns_to_remove.remove('content')
|
||||
columns_to_remove.remove('size')
|
||||
columns_to_remove.remove('avg_line_length')
|
||||
columns_to_remove.remove('max_line_length')
|
||||
columns_to_remove.remove('alphanum_fraction')
|
||||
columns_to_remove.remove('max_stars_count')
|
||||
logging.info(f"Columns to remove: {columns_to_remove}")
|
||||
|
||||
num_proc = min(multiprocessing.cpu_count() - 1, 32)
|
||||
logger.info(f"Using {num_proc} processes for dataset processing")
|
||||
logger.info(f"Using {num_proc} processes for dataset processing")
|
||||
|
||||
logger.info("Processing dataset...")
|
||||
processed_dataset = dataset.map(
|
||||
process_batch,
|
||||
fn_kwargs={'tokenizer': tokenizer},
|
||||
batched=True,
|
||||
remove_columns=columns_to_remove,
|
||||
desc="Processing dataset",
|
||||
num_proc=num_proc,
|
||||
load_from_cache_file=False
|
||||
)
|
||||
|
||||
processed_dataset = processed_dataset.filter(
|
||||
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
desc="Filtering invalid examples"
|
||||
)
|
||||
|
||||
reduced_dataset_size = len(processed_dataset)
|
||||
logger.info(f"Processed dataset size: {reduced_dataset_size}")
|
||||
|
||||
if reduced_dataset_size == 0:
|
||||
logger.error("No valid examples found in dataset!")
|
||||
return
|
||||
|
||||
logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
|
||||
save_dataset_in_chunks(
|
||||
processed_dataset,
|
||||
str(output_dir / 'processed_dataset'),
|
||||
chunk_size=100_000
|
||||
)
|
||||
|
||||
# Add derived statistics
|
||||
stats = {
|
||||
'original_dataset_size': original_dataset_size,
|
||||
'reduced_dataset_size': reduced_dataset_size,
|
||||
'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
|
||||
}
|
||||
|
||||
# Log stats with pprint
|
||||
logger.info("Processing completed! Stats:")
|
||||
pprint(stats)
|
||||
for dataset_name in ['valid', 'test', 'train']:
|
||||
logger.info(f"Processing dataset: {dataset_name}")
|
||||
|
||||
input_file = input_dir / f'{dataset_name}.jsonl'
|
||||
dataset = load_dataset('json', data_files=str(input_file))['train']
|
||||
original_dataset_size = len(dataset)
|
||||
logger.info(f"Loaded dataset from {input_file} with {original_dataset_size} examples")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples['code'],
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_special_tokens_mask=True
|
||||
)
|
||||
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
desc="Tokenizing",
|
||||
num_proc=num_proc,
|
||||
load_from_cache_file=False
|
||||
)
|
||||
|
||||
processed_dataset = tokenized_dataset.map(
|
||||
process_batch,
|
||||
fn_kwargs={'tokenizer': tokenizer},
|
||||
batched=True,
|
||||
desc="Processing",
|
||||
num_proc=num_proc,
|
||||
load_from_cache_file=False
|
||||
)
|
||||
|
||||
processed_dataset = processed_dataset.filter(
|
||||
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
desc="Filtering invalid examples"
|
||||
)
|
||||
|
||||
if len(processed_dataset) == 0:
|
||||
logger.error("No valid examples found in dataset!")
|
||||
return
|
||||
|
||||
output_file = output_dir / f'{dataset_name}'
|
||||
processed_dataset.save_to_disk(str(output_file))
|
||||
logger.info(f"Saved processed dataset to {output_file} with {len(processed_dataset)} examples")
|
||||
|
||||
logger.info("All datasets processed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,50 +0,0 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from datasets import load_dataset
|
||||
|
||||
from parse_dataset import save_dataset_in_chunks
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
current_dir = Path(__file__).parent
|
||||
input_dir = current_dir.parent / 'data' / 'processed-python'
|
||||
output_dir = current_dir.parent / 'data' / 'filtered-python'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
num_proc = min(multiprocessing.cpu_count() - 1, 32)
|
||||
logger.info(f"Using {num_proc} processes for dataset processing")
|
||||
|
||||
logger.info("Loading dataset...")
|
||||
dataset = load_dataset(str(input_dir))['data']
|
||||
logging.info(f"Dataset:\n{dataset}")
|
||||
|
||||
logger.info("Filtering dataset by max_stars_count > 3...")
|
||||
filtered_dataset = dataset.filter(
|
||||
lambda batch: [example and example > 3 for example in batch['max_stars_count']],
|
||||
num_proc=num_proc,
|
||||
batched=True,
|
||||
desc="Filtering dataset"
|
||||
)
|
||||
|
||||
filtered_dataset_size = len(filtered_dataset)
|
||||
logger.info(f"Filtered dataset size: {filtered_dataset_size}")
|
||||
|
||||
if filtered_dataset_size == 0:
|
||||
logger.error("No examples found with max_stars_count > 3!")
|
||||
return
|
||||
|
||||
logger.info(f"Saving {filtered_dataset_size} filtered examples...")
|
||||
save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)
|
||||
|
||||
logger.info("Filtering and saving completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,306 +1,121 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM
|
||||
from datasets import load_from_disk, DatasetDict
|
||||
from transformers import (
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
AutoTokenizer,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForLanguageModeling
|
||||
)
|
||||
|
||||
from training_utils import *
|
||||
from tree_model import TreeCodeBERTForPreTraining
|
||||
|
||||
MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom'
|
||||
DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python'
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
|
||||
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
|
||||
|
||||
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.max_depth = max_depth
|
||||
|
||||
# Separate embeddings for different features
|
||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||
|
||||
# Improved projection layers
|
||||
self.node_projection = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# Layer norm for stability
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
self._initialize_embeddings()
|
||||
|
||||
def _initialize_embeddings(self):
|
||||
std = 0.02
|
||||
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||
|
||||
# Initialize projection layers
|
||||
for layer in self.node_projection:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||
nn.init.zeros_(layer.bias)
|
||||
|
||||
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
depths: Tensor of shape [batch_size, seq_len] containing depth values
|
||||
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
|
||||
Returns:
|
||||
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
|
||||
"""
|
||||
# Clamp values to max_depth
|
||||
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||
|
||||
# Get embeddings for each feature
|
||||
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
|
||||
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
|
||||
|
||||
# Combine features
|
||||
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||
embeddings = self.node_projection(combined)
|
||||
|
||||
# Apply layer norm
|
||||
normalized_embeddings = self.layer_norm(embeddings)
|
||||
|
||||
return normalized_embeddings
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||
"""CodeBERT model enhanced with tree-structural information."""
|
||||
|
||||
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
|
||||
super().__init__(config)
|
||||
self.base_custom = base_custom
|
||||
|
||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||
d_model=config.hidden_size,
|
||||
max_depth=max_depth,
|
||||
dropout=config.hidden_dropout_prob
|
||||
)
|
||||
|
||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||
|
||||
# Initialize sequential position embeddings with sinusoidal pattern
|
||||
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||
|
||||
# Initialize embedding weights equally
|
||||
if base_custom:
|
||||
initial_weight = math.log(1/2)
|
||||
self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
|
||||
else:
|
||||
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
|
||||
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
|
||||
|
||||
# Layer norms for embedding combination
|
||||
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
def load_config(config_path: Path) -> dict:
|
||||
with open(config_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_normalized_weights(self):
|
||||
"""Get softmaxed weights for embedding combination."""
|
||||
return torch.softmax(self.embedding_weights, dim=0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
depths: Optional[torch.Tensor] = None,
|
||||
sibling_idxs: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
device = input_ids.device
|
||||
|
||||
# Get normalized weights for embedding combination and calculate regularization
|
||||
weights = self.get_normalized_weights()
|
||||
|
||||
# Calculate weight variance regularization
|
||||
# We want weights to remain somewhat balanced, so penalize high variance
|
||||
weight_variance = torch.var(weights)
|
||||
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
|
||||
|
||||
# Add L2 regularization to prevent any weight from getting too close to 1
|
||||
# This helps maintain a more balanced contribution from each embedding type
|
||||
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
|
||||
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
|
||||
|
||||
# Get token embeddings
|
||||
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
||||
token_embeddings = self.pre_combination_norm(token_embeddings)
|
||||
|
||||
# Get sequential position embeddings
|
||||
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||
|
||||
# Get tree positional embeddings if tree information is provided
|
||||
if depths is not None and sibling_idxs is not None and not self.base_custom:
|
||||
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||
else:
|
||||
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||
|
||||
# Combine all embeddings using learned weights
|
||||
if self.base_custom:
|
||||
combined_embeddings = (
|
||||
weights[0] * token_embeddings +
|
||||
weights[1] * seq_embeddings
|
||||
)
|
||||
else:
|
||||
combined_embeddings = (
|
||||
weights[0] * token_embeddings +
|
||||
weights[1] * tree_embeddings +
|
||||
weights[2] * seq_embeddings
|
||||
)
|
||||
|
||||
combined_embeddings = self.post_combination_norm(combined_embeddings)
|
||||
|
||||
# Forward pass through base model
|
||||
outputs = self.roberta(
|
||||
inputs_embeds=combined_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
# Calculate MLM loss if labels are provided
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
masked_lm_loss = loss_fct(
|
||||
prediction_scores.view(-1, self.config.vocab_size),
|
||||
labels.view(-1)
|
||||
)
|
||||
|
||||
# Add regularization losses to final loss
|
||||
if masked_lm_loss is not None:
|
||||
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = None
|
||||
|
||||
# Prepare embedding weights for logging
|
||||
weights_cpu = weights.detach().cpu()
|
||||
if self.base_custom:
|
||||
embedding_weights = {
|
||||
"token": weights_cpu[0].item(),
|
||||
"sequential": weights_cpu[1].item()
|
||||
}
|
||||
reg_metrics = {
|
||||
"weight_variance": weight_variance.item(),
|
||||
"max_weight_penalty": max_weight_penalty.item(),
|
||||
"weight_reg_loss": weight_reg_loss.item(),
|
||||
"l2_reg_loss": l2_reg_loss.item()
|
||||
}
|
||||
else:
|
||||
embedding_weights = {
|
||||
"token": weights_cpu[0].item(),
|
||||
"tree": weights_cpu[1].item(),
|
||||
"sequential": weights_cpu[2].item()
|
||||
}
|
||||
reg_metrics = {
|
||||
"weight_variance": weight_variance.item(),
|
||||
"max_weight_penalty": max_weight_penalty.item(),
|
||||
"weight_reg_loss": weight_reg_loss.item(),
|
||||
"l2_reg_loss": l2_reg_loss.item()
|
||||
}
|
||||
|
||||
return {
|
||||
"loss": final_loss,
|
||||
"logits": prediction_scores,
|
||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||
"attentions": outputs.attentions if output_attentions else None,
|
||||
"embedding_weights": embedding_weights,
|
||||
"regularization_metrics": reg_metrics
|
||||
}
|
||||
|
||||
def main() -> None:
|
||||
def main():
|
||||
# Setup paths
|
||||
current_dir = Path(__file__).parent
|
||||
output_dir = setup_directories(current_dir, model_name=MODEL)
|
||||
config = load_config(current_dir / 'config.json')
|
||||
data_dir = Path(config['data_dir'])
|
||||
output_dir = Path(config['output_dir'])
|
||||
|
||||
set_deterministic_mode(config['seed'])
|
||||
|
||||
setup_wandb(config, model_name=MODEL, script_path=__file__)
|
||||
set_seed(config['seed'])
|
||||
device = setup_device()
|
||||
# Initialize W&B
|
||||
wandb.init(project='easy_training', config=config, name=config['run_name'])
|
||||
|
||||
# Load tokenizer
|
||||
logger.info("Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
|
||||
# Upload the training files to W&B
|
||||
wandb.save(__file__)
|
||||
wandb.save(Path(__file__).parent / 'config.json')
|
||||
if config['extra_embeddings']:
|
||||
wandb.save(current_dir / 'tree_model.py')
|
||||
|
||||
# Create dataloaders with tree dataset
|
||||
logger.info("Creating dataloaders...")
|
||||
processed_data_dir = current_dir.parent / 'data' / DATASET
|
||||
train_dataloader, valid_dataloader = create_base_dataloaders(
|
||||
processed_data_dir,
|
||||
tokenizer,
|
||||
config,
|
||||
base_training=True if MODEL == 'base' else False
|
||||
)
|
||||
dataset = DatasetDict({
|
||||
'train': load_from_disk(data_dir / 'train'),
|
||||
'valid': load_from_disk(data_dir / 'valid'),
|
||||
'test': load_from_disk(data_dir / 'test')
|
||||
})
|
||||
|
||||
# Initialize model config
|
||||
logger.info("Initializing model...")
|
||||
# Continue with the rest of processing
|
||||
columns_to_remove = dataset['train'].column_names
|
||||
columns_to_remove.remove('input_ids')
|
||||
columns_to_remove.remove('attention_mask')
|
||||
if config['extra_embeddings']:
|
||||
columns_to_remove.remove('depths')
|
||||
columns_to_remove.remove('sibling_idxs')
|
||||
dataset = dataset.remove_columns(columns_to_remove)
|
||||
logger.info(f'Loaded dataset:\n{dataset}')
|
||||
|
||||
# Initialize model from scratch
|
||||
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
|
||||
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
||||
|
||||
# Update W&B
|
||||
wandb.config.update({'model_config': model_config.__dict__})
|
||||
|
||||
# Initialize model
|
||||
if MODEL == 'tree':
|
||||
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
|
||||
elif MODEL == 'base':
|
||||
model = RobertaForMaskedLM(model_config)
|
||||
elif MODEL == 'base-custom':
|
||||
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
|
||||
if config['extra_embeddings']:
|
||||
model = TreeCodeBERTForPreTraining(model_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid model type: {MODEL}")
|
||||
model = model.to(device)
|
||||
logging.info(f"Model initialized: {MODEL}")
|
||||
|
||||
# Model precision
|
||||
if device.type == 'cuda' and config['fp16']:
|
||||
model = model.half()
|
||||
logging.info("Model set to half precision.")
|
||||
for param in model.parameters():
|
||||
logging.info(f"Parameter dtype: {param.dtype}")
|
||||
break
|
||||
model = RobertaForMaskedLM(model_config)
|
||||
logger.info(f'Loaded model: {model.__class__.__name__}')
|
||||
|
||||
# Training setup
|
||||
logger.info("Setting up optimizer and scheduler...")
|
||||
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
|
||||
|
||||
# Train and evaluate model
|
||||
logger.info("Starting training...")
|
||||
train_and_evaluate(
|
||||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
config=config,
|
||||
device=device,
|
||||
output_dir=output_dir
|
||||
# Setup training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=str(output_dir),
|
||||
per_device_train_batch_size=config['batch_size'],
|
||||
per_device_eval_batch_size=config['batch_size'],
|
||||
learning_rate=config['learning_rate'],
|
||||
weight_decay=config['weight_decay'],
|
||||
num_train_epochs=config['epochs'],
|
||||
warmup_steps=config['warmup_steps'],
|
||||
max_grad_norm=config['max_grad_norm'],
|
||||
logging_steps=config['logging_steps'],
|
||||
eval_steps=config['eval_every'],
|
||||
save_steps=config['eval_every'],
|
||||
eval_strategy='steps',
|
||||
save_strategy='steps',
|
||||
load_best_model_at_end=True,
|
||||
report_to='wandb',
|
||||
run_name=config['run_name'],
|
||||
seed=config['seed'],
|
||||
fp16=config['fp16'],
|
||||
dataloader_num_workers=8,
|
||||
gradient_checkpointing=True,
|
||||
metric_for_best_model='eval_loss',
|
||||
greater_is_better=False,
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset['train'],
|
||||
eval_dataset=dataset['valid'],
|
||||
data_collator=DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True,
|
||||
mlm_probability=config['mlm_probability']
|
||||
),
|
||||
)
|
||||
|
||||
# Train
|
||||
logger.info('Starting training...')
|
||||
trainer.train()
|
||||
|
||||
logger.info("Training completed!")
|
||||
final_output_dir = output_dir / 'final-model'
|
||||
model.save_pretrained(final_output_dir)
|
||||
tokenizer.save_pretrained(final_output_dir)
|
||||
logger.info(f"Model saved to {final_output_dir}")
|
||||
# Evaluate
|
||||
logger.info('Evaluating on test set...')
|
||||
eval_results = trainer.evaluate(eval_dataset=dataset['test'])
|
||||
logger.info(eval_results)
|
||||
|
||||
# Save final model
|
||||
logger.info('Saving final model...')
|
||||
trainer.save_model(output_dir / 'final-model')
|
||||
tokenizer.save_pretrained(output_dir / 'final-model')
|
||||
|
||||
logger.info('Training completed!')
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,355 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import datetime
|
||||
import platform
|
||||
import logging
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from datasets import load_dataset
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
get_constant_schedule_with_warmup,
|
||||
DataCollatorForLanguageModeling,
|
||||
PreTrainedTokenizer,
|
||||
RobertaForMaskedLM
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_deterministic_mode(seed: int) -> None:
|
||||
"""Enable deterministic mode for reproducibility."""
|
||||
# Set Python random seed
|
||||
random.seed(seed)
|
||||
|
||||
# Set NumPy random seed
|
||||
np.random.seed(seed)
|
||||
|
||||
# Set PyTorch random seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Set CUDA random seed
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Enable deterministic operations for PyTorch 2.x
|
||||
torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Set environment variables for reproducibility
|
||||
os.environ['PYTHONHASHSEED'] = f'{seed}'
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
|
||||
|
||||
def get_device_info() -> Dict[str, Any]:
|
||||
"""Get detailed information about the computing environment."""
|
||||
info = {
|
||||
'python_version': platform.python_version(),
|
||||
'torch_version': torch.__version__,
|
||||
'cuda_available': torch.cuda.is_available(),
|
||||
'deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
|
||||
'cudnn_deterministic': torch.backends.cudnn.deterministic,
|
||||
'cudnn_benchmark': torch.backends.cudnn.benchmark,
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
info.update({
|
||||
'cuda_version': torch.version.cuda,
|
||||
'gpu_name': torch.cuda.get_device_name(0),
|
||||
'gpu_count': torch.cuda.device_count(),
|
||||
})
|
||||
|
||||
return info
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set random seeds for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def setup_wandb(
|
||||
config: Dict[str, Any],
|
||||
model_name: str = 'base',
|
||||
script_path: Path = None,
|
||||
device_info: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Initialize W&B logging with reproducibility information."""
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
||||
|
||||
# Add reproducibility settings to config
|
||||
full_config = {
|
||||
**config,
|
||||
'deterministic_mode': True,
|
||||
'device_info': device_info or get_device_info(),
|
||||
}
|
||||
|
||||
wandb.init(
|
||||
project='new-codebert-tree-training',
|
||||
name=f'{model_name}_{curr_time}',
|
||||
config=full_config
|
||||
)
|
||||
|
||||
# Upload the training script
|
||||
wandb.save(script_path)
|
||||
wandb.save(Path(script_path).parent / 'training_utils.py')
|
||||
logger.info(f'Saving script {script_path} to W&B')
|
||||
|
||||
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
|
||||
"""Create output directories for model artifacts."""
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / f'{model_name}_{curr_time}'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
def load_config(config_file: Path) -> Dict[str, Any]:
|
||||
"""Load training configuration from JSON file."""
|
||||
with open(config_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def setup_device() -> torch.device:
|
||||
"""Setup and configure training device."""
|
||||
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
torch.set_default_device(device)
|
||||
logger.info(f'Using device: {device}')
|
||||
if device.type == 'cuda':
|
||||
logger.info(f'Device name: {torch.cuda.get_device_name()}')
|
||||
torch.set_float32_matmul_precision('high')
|
||||
return device
|
||||
|
||||
def create_optimizer_and_scheduler(
|
||||
model: PreTrainedModel,
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[AdamW, Any]:
|
||||
"""Create optimizer and learning rate scheduler."""
|
||||
optimizer = AdamW(
|
||||
model.parameters(),
|
||||
lr=config['learning_rate'],
|
||||
weight_decay=config['weight_decay']
|
||||
)
|
||||
|
||||
scheduler = get_constant_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=config['warmup_steps'],
|
||||
)
|
||||
|
||||
return optimizer, scheduler
|
||||
|
||||
def evaluate(
|
||||
model: RobertaForMaskedLM,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device
|
||||
) -> Tuple[float, float]:
|
||||
"""Evaluate the model on the validation set.
|
||||
|
||||
Returns:
|
||||
Tuple of (average loss, accuracy on masked tokens)
|
||||
"""
|
||||
model.eval()
|
||||
total_loss = 0
|
||||
total_correct = 0
|
||||
total_predictions = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc='Evaluating'):
|
||||
batch = {
|
||||
k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()
|
||||
}
|
||||
outputs = model(**batch)
|
||||
|
||||
# Get loss
|
||||
total_loss += outputs['loss'].item()
|
||||
|
||||
# Calculate accuracy only on masked tokens
|
||||
predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length]
|
||||
labels = batch['labels'] # [batch_size, seq_length]
|
||||
|
||||
# Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
|
||||
predict_mask = labels != -100 # -100 is the ignore index
|
||||
|
||||
# Calculate accuracy
|
||||
correct = predictions[predict_mask] == labels[predict_mask]
|
||||
total_correct += correct.sum().item()
|
||||
total_predictions += predict_mask.sum().item()
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
|
||||
|
||||
return avg_loss, accuracy
|
||||
|
||||
def train_and_evaluate(
|
||||
model: RobertaForMaskedLM,
|
||||
train_dataloader: DataLoader,
|
||||
valid_dataloader: DataLoader,
|
||||
optimizer: AdamW,
|
||||
scheduler: Any,
|
||||
config: Dict[str, Any],
|
||||
device: torch.device,
|
||||
output_dir: Path,
|
||||
) -> None:
|
||||
"""Train and evaluate the model."""
|
||||
|
||||
model.train()
|
||||
best_valid_loss = float('inf')
|
||||
|
||||
for epoch in range(config['epochs']):
|
||||
total_loss = 0
|
||||
|
||||
# Training loop
|
||||
with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
|
||||
for step, batch in enumerate(pbar):
|
||||
# Move batch to device
|
||||
batch = {
|
||||
k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()
|
||||
}
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**batch)
|
||||
loss = outputs['loss']
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Clip gradients
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Update metrics
|
||||
total_loss += loss.item()
|
||||
|
||||
# Update progress bar
|
||||
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
|
||||
|
||||
# Log to wandb if configured
|
||||
log_values = {
|
||||
'train_loss': loss.item(),
|
||||
'grad_norm': grad_norm.item(),
|
||||
'learning_rate': scheduler.get_last_lr()[0]
|
||||
}
|
||||
if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
|
||||
log_values.update({
|
||||
'token_weight': outputs['embedding_weights']['token'],
|
||||
'tree_weight': outputs['embedding_weights']['tree'],
|
||||
'seq_weight': outputs['embedding_weights']['sequential'],
|
||||
})
|
||||
elif 'embedding_weights' in outputs:
|
||||
log_values.update({
|
||||
'token_weight': outputs['embedding_weights']['token'],
|
||||
'seq_weight': outputs['embedding_weights']['sequential'],
|
||||
})
|
||||
if 'regularization_metrics' in outputs:
|
||||
log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
|
||||
wandb.log(log_values)
|
||||
|
||||
# Evaluate periodically
|
||||
if step % config['eval_every'] == 0 and step > 0:
|
||||
valid_loss, valid_acc = evaluate(model, valid_dataloader, device)
|
||||
|
||||
# Log validation metrics
|
||||
wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})
|
||||
|
||||
# Print validation metrics
|
||||
print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
|
||||
|
||||
# Save best model
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')
|
||||
|
||||
model.train() # Resume training mode
|
||||
|
||||
def create_base_dataloaders(
|
||||
processed_data_dir: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: Dict[str, Any],
|
||||
base_training=False,
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
"""Create dataloaders from pre-processed parquet data."""
|
||||
|
||||
# Load chunks using datasets library
|
||||
chunks_dir = processed_data_dir / 'data'
|
||||
dataset = load_dataset(
|
||||
"parquet",
|
||||
data_files=str(chunks_dir / "*.parquet"),
|
||||
split="train"
|
||||
)
|
||||
|
||||
contents = dataset['content'][:config['batch_size']]
|
||||
with open("contents.txt", "w") as f:
|
||||
for content in contents:
|
||||
f.write(content)
|
||||
f.write("\n\n\n")
|
||||
f.write("#" * 80)
|
||||
f.write("\n\n\n")
|
||||
|
||||
# Remove columns that are not needed
|
||||
columns_to_remove = dataset.column_names
|
||||
columns_to_remove.remove('input_ids')
|
||||
columns_to_remove.remove('attention_mask')
|
||||
|
||||
if not base_training:
|
||||
columns_to_remove.remove('depths')
|
||||
columns_to_remove.remove('sibling_idxs')
|
||||
|
||||
dataset = dataset.remove_columns(columns_to_remove)
|
||||
|
||||
logging.info(f"Loaded dataset:\n{dataset}")
|
||||
|
||||
# Calculate split sizes
|
||||
dataset_size = len(dataset)
|
||||
train_size = int(config['train_size'] * dataset_size)
|
||||
val_size = dataset_size - train_size
|
||||
|
||||
# Create splits
|
||||
splits = dataset.train_test_split(
|
||||
train_size=train_size,
|
||||
test_size=val_size,
|
||||
seed=config['seed']
|
||||
)
|
||||
train_dataset = splits['train']
|
||||
valid_dataset = splits['test']
|
||||
|
||||
# Create data collator for MLM
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True,
|
||||
mlm_probability=config['mlm_probability']
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config['batch_size'],
|
||||
shuffle=False, # We'll let the dataset handle shuffling
|
||||
collate_fn=data_collator,
|
||||
num_workers=0,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset,
|
||||
batch_size=config['batch_size'],
|
||||
shuffle=False,
|
||||
collate_fn=data_collator,
|
||||
num_workers=0,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
return train_dataloader, valid_dataloader
|
208
code/src/tree_model.py
Normal file
208
code/src/tree_model.py
Normal file
@ -0,0 +1,208 @@
|
||||
import wandb
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from transformers import RobertaConfig, RobertaForMaskedLM
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
|
||||
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
|
||||
|
||||
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.max_depth = max_depth
|
||||
|
||||
# Separate embeddings for different features
|
||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||
|
||||
# Improved projection layers
|
||||
self.node_projection = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# Layer norm for stability
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
self._initialize_embeddings()
|
||||
|
||||
def _initialize_embeddings(self):
|
||||
std = 0.02
|
||||
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||
|
||||
# Initialize projection layers
|
||||
for layer in self.node_projection:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||
nn.init.zeros_(layer.bias)
|
||||
|
||||
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
depths: Tensor of shape [batch_size, seq_len] containing depth values
|
||||
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
|
||||
Returns:
|
||||
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
|
||||
"""
|
||||
# Clamp values to max_depth
|
||||
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||
|
||||
# Get embeddings for each feature
|
||||
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
|
||||
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
|
||||
|
||||
# Combine features
|
||||
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||
embeddings = self.node_projection(combined)
|
||||
|
||||
# Apply layer norm
|
||||
normalized_embeddings = self.layer_norm(embeddings)
|
||||
|
||||
return normalized_embeddings
|
||||
|
||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||
"""CodeBERT model enhanced with tree-structural information."""
|
||||
|
||||
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||
super().__init__(config)
|
||||
|
||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||
d_model=config.hidden_size,
|
||||
max_depth=max_depth,
|
||||
dropout=config.hidden_dropout_prob
|
||||
)
|
||||
|
||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||
|
||||
# Initialize sequential position embeddings with sinusoidal pattern
|
||||
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||
|
||||
# Initialize embedding weights equally
|
||||
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
|
||||
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
|
||||
|
||||
# Layer norms for embedding combination
|
||||
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def get_normalized_weights(self):
|
||||
"""Get softmaxed weights for embedding combination."""
|
||||
return torch.softmax(self.embedding_weights, dim=0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
depths: Optional[torch.Tensor] = None,
|
||||
sibling_idxs: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
device = input_ids.device
|
||||
|
||||
# Get normalized weights for embedding combination and calculate regularization
|
||||
weights = self.get_normalized_weights()
|
||||
|
||||
# Calculate weight variance regularization
|
||||
# We want weights to remain somewhat balanced, so penalize high variance
|
||||
weight_variance = torch.var(weights)
|
||||
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
|
||||
|
||||
# Add L2 regularization to prevent any weight from getting too close to 1
|
||||
# This helps maintain a more balanced contribution from each embedding type
|
||||
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
|
||||
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
|
||||
|
||||
# Get token embeddings
|
||||
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
||||
token_embeddings = self.pre_combination_norm(token_embeddings)
|
||||
|
||||
# Get sequential position embeddings
|
||||
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||
|
||||
# Get tree positional embeddings if tree information is provided
|
||||
if depths is not None and sibling_idxs is not None:
|
||||
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||
else:
|
||||
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||
|
||||
# Combine all embeddings using learned weights
|
||||
combined_embeddings = (
|
||||
weights[0] * token_embeddings +
|
||||
weights[1] * tree_embeddings +
|
||||
weights[2] * seq_embeddings
|
||||
)
|
||||
|
||||
combined_embeddings = self.post_combination_norm(combined_embeddings)
|
||||
|
||||
# Forward pass through base model
|
||||
outputs = self.roberta(
|
||||
inputs_embeds=combined_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
# Calculate MLM loss if labels are provided
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
masked_lm_loss = loss_fct(
|
||||
prediction_scores.view(-1, self.config.vocab_size),
|
||||
labels.view(-1)
|
||||
)
|
||||
|
||||
# Add regularization losses to final loss
|
||||
if masked_lm_loss is not None:
|
||||
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = None
|
||||
|
||||
# Prepare embedding weights for logging
|
||||
weights_cpu = weights.detach().cpu()
|
||||
embedding_weights = {
|
||||
"token": weights_cpu[0].item(),
|
||||
"tree": weights_cpu[1].item(),
|
||||
"sequential": weights_cpu[2].item()
|
||||
}
|
||||
reg_metrics = {
|
||||
"weight_variance": weight_variance.item(),
|
||||
"max_weight_penalty": max_weight_penalty.item(),
|
||||
"weight_reg_loss": weight_reg_loss.item(),
|
||||
"l2_reg_loss": l2_reg_loss.item()
|
||||
}
|
||||
|
||||
wandb.log(
|
||||
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
|
||||
step=kwargs.get("global_step", None)
|
||||
)
|
||||
wandb.log(
|
||||
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
|
||||
step=kwargs.get("global_step", None)
|
||||
)
|
||||
|
||||
return {
|
||||
"loss": final_loss,
|
||||
"logits": prediction_scores,
|
||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||
"attentions": outputs.attentions if output_attentions else None,
|
||||
}
|
Loading…
Reference in New Issue
Block a user