Compare commits

..

3 Commits

Author             SHA1        Message                          Date
Patryk Bartkowiak  35e5d3e8fa  original                         2025-01-03 10:16:18 +00:00
Patryk Bartkowiak  3d6826f058  using hf to load online dataset  2025-01-03 06:12:17 +00:00
Patryk Bartkowiak  dfb1e669bd  ready for runpod                 2025-01-02 20:36:05 +00:00
22 changed files with 4143 additions and 401 deletions

1
code/.gitignore vendored
View File

@@ -164,4 +164,3 @@ cython_debug/
# Weights & Biases
wandb/
outputs/
cache/

View File

@@ -20,9 +20,9 @@ pdm install
```
### 4. Run training code
```bash
pdm train --config {CONFIG FILE}
pdm run_training
```
Example:
```bash
pdm train --config ./configs/original.yaml
or
```
pdm run src/train_codebert_mlm.py
```

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: false
wandb_project: "prof-gralinski"
training:
seed: 42
batch_size: 32
epochs: 1
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 0
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/1-fold-clone-detection-600k-5fold"
num_samples: 1000 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
path: "outputs/tmp/final-model"
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: false # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512
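
For reference, the deleted YAML variants above were selected at runtime by the previous entrypoint. A minimal sketch of that removed flow, reconstructed from the deleted lines in the training-script diff further down (it relies on the pyyaml dependency that this change also drops from pyproject.toml):

```python
# Sketch of the removed YAML-driven entrypoint (reconstructed from the deleted
# lines in the training-script diff below).
import argparse
from pathlib import Path

import yaml


def load_config(config_path: Path) -> dict:
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


parser = argparse.ArgumentParser(description="Training script for TreeStarEncoder")
parser.add_argument("--config", type=str, required=True, help="Path to the configuration file")
args = parser.parse_args()

config = load_config(Path(args.config))
# The run name and output directory were derived from the config file's stem,
# e.g. ./configs/original.yaml -> ./outputs/original
output_dir = Path("./outputs") / Path(args.config).stem
print(config["model"]["extra_embeddings"], output_dir)
```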

2
code/data/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

2
code/models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -19,7 +19,6 @@ dependencies = [
"tree-sitter-python==0.23.4",
"ipykernel==6.29.5",
"ipywidgets==8.1.5",
"pyyaml==6.0.2",
]
requires-python = "==3.11.*"
readme = "README.md"
@@ -42,4 +41,3 @@ distribution = true
parse_dataset = {cmd = "src/parse_dataset.py"}
train = {cmd = "src/training.py"}
eval = {cmd = "src/eval_model.py"}
finetune = {cmd = "src_finetune/finetune.py"}

21
code/src/config.json Normal file
View File

@@ -0,0 +1,21 @@
{
"project": "runpod",
"run_name": "original",
"dataset": "patrykbart/codeparrot-clean-no-comments-starencoder-small",
"output_dir": "./outputs/long-no-comments-starencoder-original",
"extra_embeddings": false,
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 192,
"epochs": 3,
"eval_every": 2500,
"learning_rate": 5e-4,
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 500,
"bf16": true,
"logging_steps": 100,
"valid_size": 0.05,
"test_size": 0.05,
"num_samples": -1
}
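
This single flat config.json replaces the nested YAML files deleted above. A minimal sketch of how it is consumed, mirroring the updated load_config in the training-script diff below:

```python
# Sketch of the new JSON config loading (mirrors the updated load_config in the
# training-script diff below).
import json
from pathlib import Path


def load_config(config_path: Path) -> dict:
    with open(config_path, "r") as f:
        return json.load(f)


config = load_config(Path(__file__).parent / "config.json")
# Keys are now flat: config["batch_size"] instead of config["training"]["batch_size"].
print(config["run_name"], config["batch_size"], config["learning_rate"])
```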

View File

@@ -120,7 +120,7 @@ def main():
# Setup paths
current_dir = Path(__file__).parent
config = load_config(current_dir / 'eval_config.json')
model_dir = Path(config['model_dir']) / 'final-model'
model_dir = Path(config['model_dir'])
data_dir = Path(config['data_dir'])
results_dir = Path(config['model_dir']) / 'evaluation_results'
results_dir.mkdir(exist_ok=True)

3754
code/src/node_types.json Normal file

File diff suppressed because it is too large

View File

@@ -84,6 +84,7 @@ def process_example(code, tokenizer):
depths = [-1] * len(input_ids)
sibling_idxs = [-1] * len(input_ids)
node_types = [None] * len(input_ids)
node_texts = [''] * len(input_ids)
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
@@ -95,6 +96,7 @@ def process_example(code, tokenizer):
if node.start_byte <= start < node.end_byte:
depths[i] = depth
sibling_idxs[i] = sibling_idx
node_types[i] = node.type
node_texts[i] = code[node.start_byte:node.end_byte]
for i, child in enumerate(node.children):
traverse(child, depth + 1, i)
@@ -107,6 +109,7 @@ def process_example(code, tokenizer):
'attention_mask': attention_mask,
'depths': depths,
'sibling_idxs': sibling_idxs,
'node_types': node_types,
'node_texts': node_texts
}
@@ -118,6 +121,7 @@ def process_batch(batch, tokenizer):
processed_depths = []
processed_sibling_idxs = []
processed_node_texts = []
processed_node_types = []
for content in contents:
try:
@@ -130,6 +134,7 @@ def process_batch(batch, tokenizer):
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
processed_node_types.append([])
else:
processed_input_ids.append(result['input_ids'])
processed_attention_mask.append(result['attention_mask'])
@@ -137,6 +142,7 @@ def process_batch(batch, tokenizer):
processed_depths.append(result['depths'])
processed_sibling_idxs.append(result['sibling_idxs'])
processed_node_texts.append(result['node_texts'])
processed_node_types.append(result['node_types'])
except Exception:
# If something unexpected happens
processed_input_ids.append([])
@@ -145,6 +151,7 @@ def process_batch(batch, tokenizer):
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
processed_node_types.append([])
return {
'input_ids': processed_input_ids,
@@ -152,6 +159,7 @@ def process_batch(batch, tokenizer):
'tokens': processed_tokens,
'depths': processed_depths,
'sibling_idxs': processed_sibling_idxs,
'node_types': processed_node_types,
'node_texts': processed_node_texts
}
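
The new node_types and node_texts fields come from the same tree-sitter traversal that already produced depths and sibling_idxs. A standalone sketch of that annotation step, assuming the py-tree-sitter API that ships alongside the pinned tree-sitter-python==0.23.4; the parser setup and the offsets argument are illustrative additions (the repo derives token offsets from the tokenizer inside process_example):

```python
# Standalone sketch of the per-token AST annotation added in this diff.
from tree_sitter import Language, Parser  # py-tree-sitter >= 0.22 API assumed
import tree_sitter_python as tspython

parser = Parser(Language(tspython.language()))


def annotate_tokens(code: str, offsets: list[tuple[int, int]]):
    tree = parser.parse(code.encode("utf-8"))
    depths = [-1] * len(offsets)
    sibling_idxs = [-1] * len(offsets)
    node_types = [None] * len(offsets)
    node_texts = [""] * len(offsets)

    def traverse(node, depth=0, sibling_idx=0):
        for i, (start, _end) in enumerate(offsets):
            # Deeper nodes are visited later and overwrite shallower ones, so each
            # token ends up labelled by the deepest node covering its start byte.
            if node.start_byte <= start < node.end_byte:
                depths[i] = depth
                sibling_idxs[i] = sibling_idx
                node_types[i] = node.type
                node_texts[i] = code[node.start_byte:node.end_byte]
        for i, child in enumerate(node.children):
            traverse(child, depth + 1, i)

    traverse(tree.root_node)
    return depths, sibling_idxs, node_types, node_texts


print(annotate_tokens("def f(x):\n    return x + 1\n", [(0, 3), (4, 5), (21, 22)]))
```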

View File

@@ -1,13 +1,9 @@
import os
import wandb
import argparse
import yaml
import torch
import random
import json
import logging
import numpy as np
import zipfile
from pathlib import Path
from datasets import load_dataset, DatasetDict
from datasets import load_from_disk, DatasetDict, load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -16,29 +12,20 @@ from transformers import (
DataCollatorForLanguageModeling,
AutoModelForMaskedLM
)
import random
import numpy as np
import torch
from tree_codebert import TreeCodeBERTForPreTraining
from tree_starencoder import TreeStarEncoderForPreTraining
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
class MonitoringTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
# Perform the regular training step
outputs = super().training_step(model, inputs, num_items_in_batch)
# Log gradient norms
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm(2).item()
wandb.log({f'grad_norm/{name}': grad_norm})
# Log weight norms
for name, param in model.named_parameters():
weight_norm = param.data.norm(2).item()
wandb.log({f'weight_norm/{name}': weight_norm})
return outputs
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
@@ -50,100 +37,90 @@ def set_seed(seed: int):
def load_config(config_path: Path) -> dict:
with open(config_path, 'r') as f:
return yaml.safe_load(f)
return json.load(f)
def initialize_wandb(config, name, files_to_save):
wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
for file in files_to_save:
wandb.save(file)
def main():
# Setup paths
current_dir = Path(__file__).parent
config = load_config(current_dir / 'config.json')
output_dir = Path(config['output_dir'])
def prepare_dataset(config, cache_dir):
dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
if config['data']['num_samples'] > 0:
dataset = dataset.select(range(config['data']['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
# Set seed
set_seed(config['seed'])
# Initialize W&B and save files
wandb.init(project=config['project'], config=config, name=config['run_name'])
for file in [__file__, 'config.json', 'tree_starencoder.py']:
if config['extra_embeddings'] or file != 'tree_starencoder.py':
wandb.save(current_dir / file)
# Simplified dataset splitting
dataset = load_dataset(config['dataset'], split='train')
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(
test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
seed=config['training']['seed']
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
seed=config['seed']
)
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
if config['model']['extra_embeddings']:
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
# Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if config['extra_embeddings']:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove)
return dataset
def main():
parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
args = parser.parse_args()
current_dir = Path(__file__).parent
config_path = Path(args.config)
config = load_config(config_path)
cache_dir = Path(config['experiment']['cache_dir'])
output_dir = Path('./outputs') / config_path.stem
if config['experiment']['use_wandb']:
os.environ['WANDB_MODE'] = 'online'
initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
else:
os.environ['WANDB_MODE'] = 'offline'
logger.info('Wandb is not used.')
set_seed(config['training']['seed'])
dataset = prepare_dataset(config, cache_dir)
logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
logger.info(f'Dataset columns: {dataset["train"].column_names}')
logger.info(f'Loaded dataset:\n{dataset}')
# Simplify tokenizer setup
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
if tokenizer.mask_token is None:
tokenizer.add_special_tokens({'mask_token': '<mask>'})
tokenizer.mask_token = '<mask>'
logger.info("Added '<mask>' as the mask token.")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Set padding token to be the same as the EOS token.")
tokenizer.add_special_tokens({'mask_token': '<mask>'}) if tokenizer.mask_token is None else None
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
if config['model']['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
if config['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config)
else:
model = AutoModelForMaskedLM.from_config(model_config)
logger.info(f'Loaded model: {model.__class__.__name__}')
# Setup training arguments
training_args = TrainingArguments(
output_dir=str(output_dir),
per_device_train_batch_size=config['training']['batch_size'],
per_device_eval_batch_size=config['training']['batch_size'],
learning_rate=config['training']['learning_rate'],
weight_decay=config['training']['weight_decay'],
num_train_epochs=config['training']['epochs'],
warmup_steps=config['training']['warmup_steps'],
max_grad_norm=config['training']['max_grad_norm'],
logging_steps=config['evaluation']['logging_steps'],
eval_steps=config['evaluation']['eval_every'],
save_steps=config['evaluation']['eval_every'],
per_device_train_batch_size=config['batch_size'],
per_device_eval_batch_size=config['batch_size'],
learning_rate=config['learning_rate'],
weight_decay=config['weight_decay'],
num_train_epochs=config['epochs'],
warmup_steps=config['warmup_steps'],
max_grad_norm=config['max_grad_norm'],
logging_steps=config['logging_steps'],
eval_steps=config['eval_every'],
save_steps=config['eval_every'],
eval_strategy='steps',
save_strategy='steps',
save_total_limit=5,
load_best_model_at_end=True,
report_to='wandb' if config['experiment']['use_wandb'] else None,
run_name=config_path.stem,
seed=config['training']['seed'],
fp16=config['training']['fp16'],
report_to='wandb',
run_name=config['run_name'],
seed=config['seed'],
bf16=config['bf16'],
dataloader_num_workers=8,
gradient_checkpointing=True,
metric_for_best_model='eval_loss',
greater_is_better=False,
save_total_limit=3,
)
trainer = MonitoringTrainer(
# Create trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset['train'],
@@ -151,25 +128,32 @@ def main():
data_collator=DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['data']['mlm_probability']
mlm_probability=config['mlm_probability']
),
)
# Train
logger.info('Starting training...')
trainer.train()
logger.info('Training completed.')
# Evaluate
logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(dataset['test'])
logger.info(f'Evaluation results: {eval_results}')
wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')
# Save final model
logger.info('Saving final model...')
trainer.save_model(output_dir / 'final-model')
tokenizer.save_pretrained(output_dir / 'final-model')
# Upload to W&B
wandb.save(output_dir / 'final-model/*')
# Zip and upload the final model to W&B
with zipfile.ZipFile(output_dir / 'final-model.zip', 'w') as zipf:
for file in (output_dir / 'final-model').glob('**/*'):
zipf.write(file, arcname=file.name)
wandb.save(output_dir / 'final-model.zip')
logger.info('Training completed!')
if __name__ == '__main__':
main()
main()
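
The dataset split above first carves off test_size + valid_size together, then splits that held-out slice so each part gets its intended share. A self-contained sketch of the arithmetic with toy data (sizes and seed follow config.json; the dataset contents are made up):

```python
# Self-contained sketch of the two-stage train/valid/test split used above.
from datasets import Dataset, DatasetDict

test_size, valid_size, seed = 0.05, 0.05, 420
dataset = Dataset.from_dict({"input_ids": [[i] for i in range(1000)]})

train_testvalid = dataset.train_test_split(test_size=test_size + valid_size)  # 10% held out
test_valid = train_testvalid["test"].train_test_split(
    test_size=valid_size / (test_size + valid_size),  # 0.05 / 0.10 = 0.5 of the held-out slice
    seed=seed,
)
dataset = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "valid": test_valid["train"],
})
print({split: len(ds) for split, ds in dataset.items()})  # roughly 900 / 50 / 50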

208
code/src/tree_codebert.py Normal file
View File

@@ -0,0 +1,208 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
wandb.log(
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
step=kwargs.get("global_step", None)
)
wandb.log(
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
step=kwargs.get("global_step", None)
)
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
}
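
A hypothetical smoke test for the new module, showing the extra depths and sibling_idxs inputs its forward pass expects; the tiny config, random tensors, and disabled wandb run are assumptions for illustration, not part of the repo:

```python
# Hypothetical smoke test for TreeCodeBERTForPreTraining, run from code/src.
import torch
import wandb
from transformers import RobertaConfig

from tree_codebert import TreeCodeBERTForPreTraining

wandb.init(mode="disabled")  # forward() logs embedding weights through wandb

config = RobertaConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                       num_attention_heads=2, intermediate_size=128)
model = TreeCodeBERTForPreTraining(config, max_depth=32, max_seq_length=512)

batch, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
depths = torch.randint(0, 32, (batch, seq_len))        # AST depth per token
sibling_idxs = torch.randint(0, 32, (batch, seq_len))  # sibling index per token

out = model(input_ids=input_ids, depths=depths, sibling_idxs=sibling_idxs,
            labels=input_ids.clone())
print(out["loss"], out["logits"].shape)  # loss includes the weight-balance regularizers
```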

View File

@@ -1,107 +1,48 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import AutoConfig, BertForMaskedLM, GenerationMixin
from transformers import AutoConfig, BertForMaskedLM
class TreePositionalEmbedding(nn.Module):
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
from tree_codebert import TreePositionalEmbedding
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
depth_embeddings = self.depth_embedding(depths)
sibling_embeddings = self.sibling_embedding(sibling_idxs)
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
return self.layer_norm(embeddings)
class NewEmbeddings(nn.Module):
"""Construct the embeddings from word, position and tree-based embeddings.
"""
def __init__(self, config, yaml_config):
super().__init__()
self.yaml_config = yaml_config
if self.yaml_config['model']['concat_embeddings']:
self.fusion_layer = nn.Sequential(
nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
nn.GELU(),
nn.Linear(config.hidden_size * 3, config.hidden_size),
)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=yaml_config['model']['max_depth'],
dropout=config.hidden_dropout_prob
)
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids=None, depths=None, sibling_idxs=None):
input_shape = input_ids.size()
seq_length = input_shape[1]
device = input_ids.device
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
if self.yaml_config['model']['sum_embeddings']:
embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
if self.yaml_config['model']['concat_embeddings']:
embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
embeddings = self.fusion_layer(embeddings)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
def __init__(self, config: AutoConfig, yaml_config: Dict):
class TreeStarEncoderForPreTraining(BertForMaskedLM):
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.config = config
self.yaml_config = yaml_config
self.embeddings = NewEmbeddings(config, yaml_config)
# self.fusion_layer = nn.Sequential(
# nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
# nn.Linear(config.hidden_size * 4, config.hidden_size),
# nn.GELU(),
# nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
# nn.Dropout(config.hidden_dropout_prob),
# nn.LayerNorm(config.hidden_size)
# )
# Override config to set max_seq_length
config.max_position_embeddings = max_seq_length
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# # Initialize sequential position embeddings with sinusoidal pattern
# position = torch.arange(max_seq_length).unsqueeze(1)
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
# pe = torch.zeros(max_seq_length, config.hidden_size)
# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# self.seq_pos_embeddings.weight.data.copy_(pe)
# New node type embeddings
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
self.norm = nn.LayerNorm(config.hidden_size)
def forward(
self,
input_ids: torch.Tensor,
@@ -112,10 +53,34 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
embedding_output = self.embeddings(input_ids, depths, sibling_idxs)
device = input_ids.device
# Get token embeddings
token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# # Get node type embeddings
# node_type_embeddings = self.node_type_embeddings(node_types)
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
# combined_embeddings = self.fusion_layer(combined)
# Add the embeddings instead of concatenating
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
combined_embeddings = self.norm(combined_embeddings)
outputs = self.bert(
inputs_embeds=embedding_output,
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
@@ -135,7 +100,6 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
return {
"loss": masked_lm_loss,
"logits": prediction_scores,
"last_hidden_state": sequence_output,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
}
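
The net effect of this file's diff is that the configurable NewEmbeddings module (sum, or concat plus a fusion MLP) is replaced by a fixed sum of token, sequential, and tree embeddings followed by LayerNorm. A standalone sketch contrasting the two fusion paths; the hidden size and random tensors are illustrative only:

```python
# Standalone contrast of the two embedding-fusion strategies touched by this diff.
import torch
import torch.nn as nn

batch, seq_len, d = 2, 16, 64
tok = torch.randn(batch, seq_len, d)   # word embeddings
pos = torch.randn(batch, seq_len, d)   # sequential position embeddings
tree = torch.randn(batch, seq_len, d)  # tree positional embeddings

# Removed path (NewEmbeddings with concat_embeddings=True): concatenate, then
# project back to d_model with a small MLP before LayerNorm.
fusion = nn.Sequential(nn.Linear(d * 3, d * 3), nn.GELU(), nn.Linear(d * 3, d))
concat_fused = nn.LayerNorm(d)(fusion(torch.cat([tok, pos, tree], dim=-1)))

# New path (TreeStarEncoderForPreTraining.forward): plain element-wise sum + LayerNorm.
sum_fused = nn.LayerNorm(d)(tok + pos + tree)

print(concat_fused.shape, sum_fused.shape)  # both torch.Size([2, 16, 64])
```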