Compare commits

..

3 Commits

Author SHA1 Message Date
Patryk Bartkowiak
35e5d3e8fa original 2025-01-03 10:16:18 +00:00
Patryk Bartkowiak
3d6826f058 using hf to load online dataset 2025-01-03 06:12:17 +00:00
Patryk Bartkowiak
dfb1e669bd ready for runpod 2025-01-02 20:36:05 +00:00
22 changed files with 4143 additions and 401 deletions

1
code/.gitignore vendored
View File

@ -164,4 +164,3 @@ cython_debug/
# Weights & Biases # Weights & Biases
wandb/ wandb/
outputs/ outputs/
cache/

View File

@ -20,9 +20,9 @@ pdm install
``` ```
### 4. Run training code ### 4. Run training code
```bash ```bash
pdm train --config {CONFIG FILE} pdm run_training
``` ```
Example: or
```bash ```
pdm train --config ./configs/original.yaml pdm run src/train_codebert_mlm.py
``` ```

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: false
wandb_project: "prof-gralinski"
training:
seed: 42
batch_size: 32
epochs: 1
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 0
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/1-fold-clone-detection-600k-5fold"
num_samples: 1000 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
path: "outputs/tmp/final-model"
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: false # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -1,33 +0,0 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

2
code/data/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

2
code/models/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

View File

@ -19,7 +19,6 @@ dependencies = [
"tree-sitter-python==0.23.4", "tree-sitter-python==0.23.4",
"ipykernel==6.29.5", "ipykernel==6.29.5",
"ipywidgets==8.1.5", "ipywidgets==8.1.5",
"pyyaml==6.0.2",
] ]
requires-python = "==3.11.*" requires-python = "==3.11.*"
readme = "README.md" readme = "README.md"
@ -42,4 +41,3 @@ distribution = true
parse_dataset = {cmd = "src/parse_dataset.py"} parse_dataset = {cmd = "src/parse_dataset.py"}
train = {cmd = "src/training.py"} train = {cmd = "src/training.py"}
eval = {cmd = "src/eval_model.py"} eval = {cmd = "src/eval_model.py"}
finetune = {cmd = "src_finetune/finetune.py"}

21
code/src/config.json Normal file
View File

@ -0,0 +1,21 @@
{
"project": "runpod",
"run_name": "original",
"dataset": "patrykbart/codeparrot-clean-no-comments-starencoder-small",
"output_dir": "./outputs/long-no-comments-starencoder-original",
"extra_embeddings": false,
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 192,
"epochs": 3,
"eval_every": 2500,
"learning_rate": 5e-4,
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 500,
"bf16": true,
"logging_steps": 100,
"valid_size": 0.05,
"test_size": 0.05,
"num_samples": -1
}

View File

@ -120,7 +120,7 @@ def main():
# Setup paths # Setup paths
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
config = load_config(current_dir / 'eval_config.json') config = load_config(current_dir / 'eval_config.json')
model_dir = Path(config['model_dir']) / 'final-model' model_dir = Path(config['model_dir'])
data_dir = Path(config['data_dir']) data_dir = Path(config['data_dir'])
results_dir = Path(config['model_dir']) / 'evaluation_results' results_dir = Path(config['model_dir']) / 'evaluation_results'
results_dir.mkdir(exist_ok=True) results_dir.mkdir(exist_ok=True)

3754
code/src/node_types.json Normal file

File diff suppressed because it is too large Load Diff

View File

@ -84,6 +84,7 @@ def process_example(code, tokenizer):
depths = [-1] * len(input_ids) depths = [-1] * len(input_ids)
sibling_idxs = [-1] * len(input_ids) sibling_idxs = [-1] * len(input_ids)
node_types = [None] * len(input_ids)
node_texts = [''] * len(input_ids) node_texts = [''] * len(input_ids)
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids) tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
@ -95,6 +96,7 @@ def process_example(code, tokenizer):
if node.start_byte <= start < node.end_byte: if node.start_byte <= start < node.end_byte:
depths[i] = depth depths[i] = depth
sibling_idxs[i] = sibling_idx sibling_idxs[i] = sibling_idx
node_types[i] = node.type
node_texts[i] = code[node.start_byte:node.end_byte] node_texts[i] = code[node.start_byte:node.end_byte]
for i, child in enumerate(node.children): for i, child in enumerate(node.children):
traverse(child, depth + 1, i) traverse(child, depth + 1, i)
@ -107,6 +109,7 @@ def process_example(code, tokenizer):
'attention_mask': attention_mask, 'attention_mask': attention_mask,
'depths': depths, 'depths': depths,
'sibling_idxs': sibling_idxs, 'sibling_idxs': sibling_idxs,
'node_types': node_types,
'node_texts': node_texts 'node_texts': node_texts
} }
@ -118,6 +121,7 @@ def process_batch(batch, tokenizer):
processed_depths = [] processed_depths = []
processed_sibling_idxs = [] processed_sibling_idxs = []
processed_node_texts = [] processed_node_texts = []
processed_node_types = []
for content in contents: for content in contents:
try: try:
@ -130,6 +134,7 @@ def process_batch(batch, tokenizer):
processed_depths.append([]) processed_depths.append([])
processed_sibling_idxs.append([]) processed_sibling_idxs.append([])
processed_node_texts.append([]) processed_node_texts.append([])
processed_node_types.append([])
else: else:
processed_input_ids.append(result['input_ids']) processed_input_ids.append(result['input_ids'])
processed_attention_mask.append(result['attention_mask']) processed_attention_mask.append(result['attention_mask'])
@ -137,6 +142,7 @@ def process_batch(batch, tokenizer):
processed_depths.append(result['depths']) processed_depths.append(result['depths'])
processed_sibling_idxs.append(result['sibling_idxs']) processed_sibling_idxs.append(result['sibling_idxs'])
processed_node_texts.append(result['node_texts']) processed_node_texts.append(result['node_texts'])
processed_node_types.append(result['node_types'])
except Exception: except Exception:
# If something unexpected happens # If something unexpected happens
processed_input_ids.append([]) processed_input_ids.append([])
@ -145,6 +151,7 @@ def process_batch(batch, tokenizer):
processed_depths.append([]) processed_depths.append([])
processed_sibling_idxs.append([]) processed_sibling_idxs.append([])
processed_node_texts.append([]) processed_node_texts.append([])
processed_node_types.append([])
return { return {
'input_ids': processed_input_ids, 'input_ids': processed_input_ids,
@ -152,6 +159,7 @@ def process_batch(batch, tokenizer):
'tokens': processed_tokens, 'tokens': processed_tokens,
'depths': processed_depths, 'depths': processed_depths,
'sibling_idxs': processed_sibling_idxs, 'sibling_idxs': processed_sibling_idxs,
'node_types': processed_node_types,
'node_texts': processed_node_texts 'node_texts': processed_node_texts
} }

View File

@ -1,13 +1,9 @@
import os
import wandb import wandb
import argparse import json
import yaml
import torch
import random
import logging import logging
import numpy as np import zipfile
from pathlib import Path from pathlib import Path
from datasets import load_dataset, DatasetDict from datasets import load_from_disk, DatasetDict, load_dataset
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
@ -16,29 +12,20 @@ from transformers import (
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
AutoModelForMaskedLM AutoModelForMaskedLM
) )
import random
import numpy as np
import torch
from tree_codebert import TreeCodeBERTForPreTraining
from tree_starencoder import TreeStarEncoderForPreTraining from tree_starencoder import TreeStarEncoderForPreTraining
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MonitoringTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
# Perform the regular training step
outputs = super().training_step(model, inputs, num_items_in_batch)
# Log gradient norms
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm(2).item()
wandb.log({f'grad_norm/{name}': grad_norm})
# Log weight norms
for name, param in model.named_parameters():
weight_norm = param.data.norm(2).item()
wandb.log({f'weight_norm/{name}': weight_norm})
return outputs
def set_seed(seed: int): def set_seed(seed: int):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
@ -50,100 +37,90 @@ def set_seed(seed: int):
def load_config(config_path: Path) -> dict: def load_config(config_path: Path) -> dict:
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
return yaml.safe_load(f) return json.load(f)
def initialize_wandb(config, name, files_to_save): def main():
wandb.init(project=config['experiment']['wandb_project'], config=config, name=name) # Setup paths
for file in files_to_save: current_dir = Path(__file__).parent
wandb.save(file) config = load_config(current_dir / 'config.json')
output_dir = Path(config['output_dir'])
def prepare_dataset(config, cache_dir): # Set seed
dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir) set_seed(config['seed'])
if config['data']['num_samples'] > 0:
dataset = dataset.select(range(config['data']['num_samples'])) # Initialize W&B and save files
train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size']) wandb.init(project=config['project'], config=config, name=config['run_name'])
for file in [__file__, 'config.json', 'tree_starencoder.py']:
if config['extra_embeddings'] or file != 'tree_starencoder.py':
wandb.save(current_dir / file)
# Simplified dataset splitting
dataset = load_dataset(config['dataset'], split='train')
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split( test_valid = train_testvalid['test'].train_test_split(
test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']), test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
seed=config['training']['seed'] seed=config['seed']
) )
dataset = DatasetDict({ dataset = DatasetDict({
'train': train_testvalid['train'], 'train': train_testvalid['train'],
'test': test_valid['test'], 'test': test_valid['test'],
'valid': test_valid['train'], 'valid': test_valid['train'],
}) })
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
if config['model']['extra_embeddings']:
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']] # Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if config['extra_embeddings']:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
dataset = dataset.remove_columns(columns_to_remove) dataset = dataset.remove_columns(columns_to_remove)
return dataset logger.info(f'Loaded dataset:\n{dataset}')
def main():
parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
args = parser.parse_args()
current_dir = Path(__file__).parent
config_path = Path(args.config)
config = load_config(config_path)
cache_dir = Path(config['experiment']['cache_dir'])
output_dir = Path('./outputs') / config_path.stem
if config['experiment']['use_wandb']:
os.environ['WANDB_MODE'] = 'online'
initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
else:
os.environ['WANDB_MODE'] = 'offline'
logger.info('Wandb is not used.')
set_seed(config['training']['seed'])
dataset = prepare_dataset(config, cache_dir)
logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
logger.info(f'Dataset columns: {dataset["train"].column_names}')
# Simplify tokenizer setup
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder') tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
if tokenizer.mask_token is None: tokenizer.add_special_tokens({'mask_token': '<mask>'}) if tokenizer.mask_token is None else None
tokenizer.add_special_tokens({'mask_token': '<mask>'}) tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
tokenizer.mask_token = '<mask>'
logger.info("Added '<mask>' as the mask token.")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Set padding token to be the same as the EOS token.")
model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir) model_config = AutoConfig.from_pretrained('bigcode/starencoder')
if config['model']['extra_embeddings']: if config['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config, yaml_config=config) model = TreeStarEncoderForPreTraining(model_config)
else: else:
model = AutoModelForMaskedLM.from_config(model_config) model = AutoModelForMaskedLM.from_config(model_config)
logger.info(f'Loaded model: {model.__class__.__name__}') logger.info(f'Loaded model: {model.__class__.__name__}')
# Setup training arguments
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=str(output_dir), output_dir=str(output_dir),
per_device_train_batch_size=config['training']['batch_size'], per_device_train_batch_size=config['batch_size'],
per_device_eval_batch_size=config['training']['batch_size'], per_device_eval_batch_size=config['batch_size'],
learning_rate=config['training']['learning_rate'], learning_rate=config['learning_rate'],
weight_decay=config['training']['weight_decay'], weight_decay=config['weight_decay'],
num_train_epochs=config['training']['epochs'], num_train_epochs=config['epochs'],
warmup_steps=config['training']['warmup_steps'], warmup_steps=config['warmup_steps'],
max_grad_norm=config['training']['max_grad_norm'], max_grad_norm=config['max_grad_norm'],
logging_steps=config['evaluation']['logging_steps'], logging_steps=config['logging_steps'],
eval_steps=config['evaluation']['eval_every'], eval_steps=config['eval_every'],
save_steps=config['evaluation']['eval_every'], save_steps=config['eval_every'],
eval_strategy='steps', eval_strategy='steps',
save_strategy='steps', save_strategy='steps',
save_total_limit=5,
load_best_model_at_end=True, load_best_model_at_end=True,
report_to='wandb' if config['experiment']['use_wandb'] else None, report_to='wandb',
run_name=config_path.stem, run_name=config['run_name'],
seed=config['training']['seed'], seed=config['seed'],
fp16=config['training']['fp16'], bf16=config['bf16'],
dataloader_num_workers=8, dataloader_num_workers=8,
gradient_checkpointing=True, gradient_checkpointing=True,
metric_for_best_model='eval_loss', metric_for_best_model='eval_loss',
greater_is_better=False, greater_is_better=False,
save_total_limit=3,
) )
trainer = MonitoringTrainer( # Create trainer
trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
train_dataset=dataset['train'], train_dataset=dataset['train'],
@ -151,25 +128,32 @@ def main():
data_collator=DataCollatorForLanguageModeling( data_collator=DataCollatorForLanguageModeling(
tokenizer=tokenizer, tokenizer=tokenizer,
mlm=True, mlm=True,
mlm_probability=config['data']['mlm_probability'] mlm_probability=config['mlm_probability']
), ),
) )
# Train
logger.info('Starting training...') logger.info('Starting training...')
trainer.train() trainer.train()
logger.info('Training completed.')
# Evaluate
logger.info('Evaluating on test set...') logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(dataset['test']) eval_results = trainer.evaluate(dataset['test'])
logger.info(f'Evaluation results: {eval_results}')
wandb.log({'test_loss': eval_results['eval_loss']}) wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')
# Save final model
logger.info('Saving final model...') logger.info('Saving final model...')
trainer.save_model(output_dir / 'final-model') trainer.save_model(output_dir / 'final-model')
tokenizer.save_pretrained(output_dir / 'final-model') tokenizer.save_pretrained(output_dir / 'final-model')
# Upload to W&B # Zip and upload the final model to W&B
wandb.save(output_dir / 'final-model/*') with zipfile.ZipFile(output_dir / 'final-model.zip', 'w') as zipf:
for file in (output_dir / 'final-model').glob('**/*'):
zipf.write(file, arcname=file.name)
wandb.save(output_dir / 'final-model.zip')
logger.info('Training completed!')
if __name__ == '__main__': if __name__ == '__main__':
main() main()

208
code/src/tree_codebert.py Normal file
View File

@ -0,0 +1,208 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM
class TreePositionalEmbedding(nn.Module):
    """Tree-aware positional embedding built from per-token depth and sibling index.

    Each of the two integer features is looked up in its own embedding table;
    the two vectors are concatenated, fused by a small MLP and layer-normalised.
    """

    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        # One lookup table per structural feature; both tables have max_depth rows.
        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)

        # MLP that maps the concatenated (depth, sibling) vector back to d_model.
        wide = d_model * 2
        self.node_projection = nn.Sequential(
            nn.Linear(wide, wide),
            nn.GELU(),
            nn.Linear(wide, d_model),
            nn.Dropout(dropout),
        )

        # Final normalisation of the fused embedding.
        self.layer_norm = nn.LayerNorm(d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        """Draw embedding and linear weights from a normal (mean 0, std 0.02); zero linear biases."""
        std = 0.02
        nn.init.normal_(self.depth_embedding.weight, mean=0.0, std=std)
        nn.init.normal_(self.sibling_embedding.weight, mean=0.0, std=std)

        # Only the Linear members of the projection carry weights to initialise.
        for module in self.node_projection:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, mean=0.0, std=std)
            nn.init.zeros_(module.bias)

    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
        """Compute tree-aware embeddings.

        Args:
            depths: Tensor of shape [batch_size, seq_len] containing depth values.
            sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions.

        Returns:
            Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings.
        """
        # Out-of-range indices (negative or >= max_depth) are clamped into the table.
        upper = self.max_depth - 1
        fused = torch.cat(
            [
                self.depth_embedding(torch.clamp(depths, 0, upper)),
                self.sibling_embedding(torch.clamp(sibling_idxs, 0, upper)),
            ],
            dim=-1,
        )
        return self.layer_norm(self.node_projection(fused))
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """CodeBERT masked-LM model enhanced with tree-structural information.

    Token embeddings, tree positional embeddings (depth + sibling index) and
    sequential position embeddings are mixed with three learned, softmax-
    normalised scalar weights before being fed to the RoBERTa encoder via
    ``inputs_embeds``.
    """

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
        super().__init__(config)

        self.tree_pos_embeddings = TreePositionalEmbedding(
            d_model=config.hidden_size,
            max_depth=max_depth,
            dropout=config.hidden_dropout_prob
        )

        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)

        # Initialize sequential position embeddings with sinusoidal pattern
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
        pe = torch.zeros(max_seq_length, config.hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.seq_pos_embeddings.weight.data.copy_(pe)

        # Initialize embedding weights equally
        initial_weight = math.log(1/3)  # log(1/3) because we use softmax later
        self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))

        # Layer norms for embedding combination
        self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
        self.post_combination_norm = nn.LayerNorm(config.hidden_size)

    def get_normalized_weights(self):
        """Get softmaxed weights for embedding combination."""
        return torch.softmax(self.embedding_weights, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        sibling_idxs: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the masked-LM forward pass with tree-aware input embeddings.

        Args:
            input_ids: Token ids; indexed as [batch, seq_len] below.
            attention_mask: Optional mask forwarded unchanged to the encoder.
            labels: Optional MLM targets; when given, a loss is returned.
            depths: Optional per-token tree depths.
            sibling_idxs: Optional per-token sibling positions.
            output_attentions: Forwarded to the encoder; also gates the
                ``attentions`` entry of the result.
            **kwargs: Extra arguments forwarded to the encoder. An optional
                ``global_step`` entry is consumed here and used only as the
                W&B logging step.

        Returns:
            Dict with ``loss`` (MLM loss plus the two weight regularisers, or
            ``None`` without labels), ``logits``, ``hidden_states`` and
            ``attentions``.
        """
        device = input_ids.device

        # `global_step` is a logging hint only; pop it so it is not passed on to
        # the encoder, which does not expect such an argument.
        global_step = kwargs.pop("global_step", None)

        # Get normalized weights for embedding combination and calculate regularization
        weights = self.get_normalized_weights()

        # Calculate weight variance regularization
        # We want weights to remain somewhat balanced, so penalize high variance
        weight_variance = torch.var(weights)
        weight_reg_loss = 0.1 * weight_variance  # Adjustable coefficient

        # Add L2 regularization to prevent any weight from getting too close to 1
        # This helps maintain a more balanced contribution from each embedding type
        max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2)  # Penalize weights > 0.8
        l2_reg_loss = 0.05 * max_weight_penalty  # Adjustable coefficient

        # Get token embeddings
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
        token_embeddings = self.pre_combination_norm(token_embeddings)

        # Get sequential position embeddings
        seq_positions = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)

        # Get tree positional embeddings if tree information is provided
        if depths is not None and sibling_idxs is not None:
            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
        else:
            tree_embeddings = torch.zeros_like(token_embeddings)

        # Combine all embeddings using learned weights
        combined_embeddings = (
            weights[0] * token_embeddings +
            weights[1] * tree_embeddings +
            weights[2] * seq_embeddings
        )

        combined_embeddings = self.post_combination_norm(combined_embeddings)

        # Forward pass through base model
        outputs = self.roberta(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # MLM loss plus both regularisers when labels are provided; otherwise no loss.
        final_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )
            final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss

        # Log only when a W&B run is active: wandb.log fails without wandb.init(),
        # and the .item() calls below force a device sync that is wasted otherwise.
        if wandb.run is not None:
            weights_cpu = weights.detach().cpu()
            metrics = {
                "embedding_weights/token": weights_cpu[0].item(),
                "embedding_weights/tree": weights_cpu[1].item(),
                "embedding_weights/sequential": weights_cpu[2].item(),
                "regularization_metrics/weight_variance": weight_variance.item(),
                "regularization_metrics/max_weight_penalty": max_weight_penalty.item(),
                "regularization_metrics/weight_reg_loss": weight_reg_loss.item(),
                "regularization_metrics/l2_reg_loss": l2_reg_loss.item(),
            }
            # Single call so one forward pass consumes at most one W&B step.
            wandb.log(metrics, step=global_step)

        return {
            "loss": final_loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            "attentions": outputs.attentions if output_attentions else None,
        }

View File

@ -1,107 +1,48 @@
import wandb
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Dict, Optional from typing import Dict, Optional
from transformers import AutoConfig, BertForMaskedLM, GenerationMixin from transformers import AutoConfig, BertForMaskedLM
class TreePositionalEmbedding(nn.Module): from tree_codebert import TreePositionalEmbedding
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self): class TreeStarEncoderForPreTraining(BertForMaskedLM):
std = 0.02 def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
depth_embeddings = self.depth_embedding(depths)
sibling_embeddings = self.sibling_embedding(sibling_idxs)
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
return self.layer_norm(embeddings)
class NewEmbeddings(nn.Module):
"""Construct the embeddings from word, position and tree-based embeddings.
"""
def __init__(self, config, yaml_config):
super().__init__()
self.yaml_config = yaml_config
if self.yaml_config['model']['concat_embeddings']:
self.fusion_layer = nn.Sequential(
nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
nn.GELU(),
nn.Linear(config.hidden_size * 3, config.hidden_size),
)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=yaml_config['model']['max_depth'],
dropout=config.hidden_dropout_prob
)
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids=None, depths=None, sibling_idxs=None):
input_shape = input_ids.size()
seq_length = input_shape[1]
device = input_ids.device
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
if self.yaml_config['model']['sum_embeddings']:
embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
if self.yaml_config['model']['concat_embeddings']:
embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
embeddings = self.fusion_layer(embeddings)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
def __init__(self, config: AutoConfig, yaml_config: Dict):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.yaml_config = yaml_config
self.embeddings = NewEmbeddings(config, yaml_config)
# self.fusion_layer = nn.Sequential( # self.fusion_layer = nn.Sequential(
# nn.Linear(config.hidden_size * 3, config.hidden_size * 3), # nn.Linear(config.hidden_size * 4, config.hidden_size),
# nn.GELU(), # nn.GELU(),
# nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
# nn.Dropout(config.hidden_dropout_prob), # nn.Dropout(config.hidden_dropout_prob),
# nn.LayerNorm(config.hidden_size) # nn.LayerNorm(config.hidden_size)
# ) # )
# Override config to set max_seq_length
config.max_position_embeddings = max_seq_length
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# # Initialize sequential position embeddings with sinusoidal pattern
# position = torch.arange(max_seq_length).unsqueeze(1)
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
# pe = torch.zeros(max_seq_length, config.hidden_size)
# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# self.seq_pos_embeddings.weight.data.copy_(pe)
# New node type embeddings
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
self.norm = nn.LayerNorm(config.hidden_size)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -112,10 +53,34 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
output_attentions: bool = False, output_attentions: bool = False,
**kwargs **kwargs
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
embedding_output = self.embeddings(input_ids, depths, sibling_idxs) device = input_ids.device
# Get token embeddings
token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# # Get node type embeddings
# node_type_embeddings = self.node_type_embeddings(node_types)
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
# combined_embeddings = self.fusion_layer(combined)
# Add the embeddings instead of concatenating
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
combined_embeddings = self.norm(combined_embeddings)
outputs = self.bert( outputs = self.bert(
inputs_embeds=embedding_output, inputs_embeds=combined_embeddings,
attention_mask=attention_mask, attention_mask=attention_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs **kwargs
@ -135,7 +100,6 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
return { return {
"loss": masked_lm_loss, "loss": masked_lm_loss,
"logits": prediction_scores, "logits": prediction_scores,
"last_hidden_state": sequence_output,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None, "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None, "attentions": outputs.attentions if output_attentions else None,
} }