Compare commits
3 Commits
master
...
runpod-exp
Author | SHA1 | Date | |
---|---|---|---|
|
35e5d3e8fa | ||
|
3d6826f058 | ||
|
dfb1e669bd |
1
code/.gitignore
vendored
1
code/.gitignore
vendored
@ -164,4 +164,3 @@ cython_debug/
|
||||
# Weights & Biases
|
||||
wandb/
|
||||
outputs/
|
||||
cache/
|
||||
|
@ -20,9 +20,9 @@ pdm install
|
||||
```
|
||||
### 4. Run training code
|
||||
```bash
|
||||
pdm train --config {CONFIG FILE}
|
||||
pdm run_training
|
||||
```
|
||||
Example:
|
||||
```bash
|
||||
pdm train --config ./configs/original.yaml
|
||||
or
|
||||
```
|
||||
pdm run src/train_codebert_mlm.py
|
||||
```
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: false
|
||||
wandb_project: "prof-gralinski"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 1
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 0
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/1-fold-clone-detection-600k-5fold"
|
||||
num_samples: 1000 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
path: "outputs/tmp/final-model"
|
||||
extra_embeddings: true # if false, use the original star encoder
|
||||
sinusoidal_init: false
|
||||
concat_embeddings: true
|
||||
sum_embeddings: false
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: true
|
||||
wandb_project: "runpod"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 3
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 1000
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
||||
mlm_probability: 0.15
|
||||
num_samples: -1 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
extra_embeddings: false # if false, use the original star encoder
|
||||
sinusoidal_init: false
|
||||
concat_embeddings: false
|
||||
sum_embeddings: false
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: true
|
||||
wandb_project: "runpod"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 3
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 1000
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
||||
mlm_probability: 0.15
|
||||
num_samples: -1 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
extra_embeddings: true # if false, use the original star encoder
|
||||
sinusoidal_init: true
|
||||
concat_embeddings: false
|
||||
sum_embeddings: true
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: true
|
||||
wandb_project: "runpod"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 3
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 1000
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
||||
mlm_probability: 0.15
|
||||
num_samples: -1 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
extra_embeddings: true # if false, use the original star encoder
|
||||
sinusoidal_init: false
|
||||
concat_embeddings: false
|
||||
sum_embeddings: true
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: true
|
||||
wandb_project: "runpod"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 3
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 1000
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
||||
mlm_probability: 0.15
|
||||
num_samples: -1 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
extra_embeddings: true # if false, use the original star encoder
|
||||
sinusoidal_init: true
|
||||
concat_embeddings: true
|
||||
sum_embeddings: false
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
@ -1,33 +0,0 @@
|
||||
experiment:
|
||||
cache_dir: "./cache/"
|
||||
use_wandb: true
|
||||
wandb_project: "runpod"
|
||||
|
||||
training:
|
||||
seed: 42
|
||||
batch_size: 32
|
||||
epochs: 3
|
||||
learning_rate: 0.0005
|
||||
weight_decay: 0.1
|
||||
max_grad_norm: 1.0
|
||||
warmup_steps: 1000
|
||||
fp16: true
|
||||
|
||||
evaluation:
|
||||
eval_every: 10000
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
||||
mlm_probability: 0.15
|
||||
num_samples: -1 # -1 means all samples
|
||||
valid_size: 0.05
|
||||
test_size: 0.05
|
||||
|
||||
model:
|
||||
extra_embeddings: true # if false, use the original star encoder
|
||||
sinusoidal_init: false
|
||||
concat_embeddings: true
|
||||
sum_embeddings: false
|
||||
max_depth: 32
|
||||
max_seq_length: 512
|
2
code/data/.gitignore
vendored
Normal file
2
code/data/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
2
code/models/.gitignore
vendored
Normal file
2
code/models/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
@ -19,7 +19,6 @@ dependencies = [
|
||||
"tree-sitter-python==0.23.4",
|
||||
"ipykernel==6.29.5",
|
||||
"ipywidgets==8.1.5",
|
||||
"pyyaml==6.0.2",
|
||||
]
|
||||
requires-python = "==3.11.*"
|
||||
readme = "README.md"
|
||||
@ -42,4 +41,3 @@ distribution = true
|
||||
parse_dataset = {cmd = "src/parse_dataset.py"}
|
||||
train = {cmd = "src/training.py"}
|
||||
eval = {cmd = "src/eval_model.py"}
|
||||
finetune = {cmd = "src_finetune/finetune.py"}
|
||||
|
21
code/src/config.json
Normal file
21
code/src/config.json
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"project": "runpod",
|
||||
"run_name": "original",
|
||||
"dataset": "patrykbart/codeparrot-clean-no-comments-starencoder-small",
|
||||
"output_dir": "./outputs/long-no-comments-starencoder-original",
|
||||
"extra_embeddings": false,
|
||||
"seed": 420,
|
||||
"mlm_probability": 0.15,
|
||||
"batch_size": 192,
|
||||
"epochs": 3,
|
||||
"eval_every": 2500,
|
||||
"learning_rate": 5e-4,
|
||||
"weight_decay": 0.1,
|
||||
"max_grad_norm": 1.0,
|
||||
"warmup_steps": 500,
|
||||
"bf16": true,
|
||||
"logging_steps": 100,
|
||||
"valid_size": 0.05,
|
||||
"test_size": 0.05,
|
||||
"num_samples": -1
|
||||
}
|
@ -120,7 +120,7 @@ def main():
|
||||
# Setup paths
|
||||
current_dir = Path(__file__).parent
|
||||
config = load_config(current_dir / 'eval_config.json')
|
||||
model_dir = Path(config['model_dir']) / 'final-model'
|
||||
model_dir = Path(config['model_dir'])
|
||||
data_dir = Path(config['data_dir'])
|
||||
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
|
3754
code/src/node_types.json
Normal file
3754
code/src/node_types.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -84,6 +84,7 @@ def process_example(code, tokenizer):
|
||||
|
||||
depths = [-1] * len(input_ids)
|
||||
sibling_idxs = [-1] * len(input_ids)
|
||||
node_types = [None] * len(input_ids)
|
||||
node_texts = [''] * len(input_ids)
|
||||
|
||||
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
@ -95,6 +96,7 @@ def process_example(code, tokenizer):
|
||||
if node.start_byte <= start < node.end_byte:
|
||||
depths[i] = depth
|
||||
sibling_idxs[i] = sibling_idx
|
||||
node_types[i] = node.type
|
||||
node_texts[i] = code[node.start_byte:node.end_byte]
|
||||
for i, child in enumerate(node.children):
|
||||
traverse(child, depth + 1, i)
|
||||
@ -107,6 +109,7 @@ def process_example(code, tokenizer):
|
||||
'attention_mask': attention_mask,
|
||||
'depths': depths,
|
||||
'sibling_idxs': sibling_idxs,
|
||||
'node_types': node_types,
|
||||
'node_texts': node_texts
|
||||
}
|
||||
|
||||
@ -118,6 +121,7 @@ def process_batch(batch, tokenizer):
|
||||
processed_depths = []
|
||||
processed_sibling_idxs = []
|
||||
processed_node_texts = []
|
||||
processed_node_types = []
|
||||
|
||||
for content in contents:
|
||||
try:
|
||||
@ -130,6 +134,7 @@ def process_batch(batch, tokenizer):
|
||||
processed_depths.append([])
|
||||
processed_sibling_idxs.append([])
|
||||
processed_node_texts.append([])
|
||||
processed_node_types.append([])
|
||||
else:
|
||||
processed_input_ids.append(result['input_ids'])
|
||||
processed_attention_mask.append(result['attention_mask'])
|
||||
@ -137,6 +142,7 @@ def process_batch(batch, tokenizer):
|
||||
processed_depths.append(result['depths'])
|
||||
processed_sibling_idxs.append(result['sibling_idxs'])
|
||||
processed_node_texts.append(result['node_texts'])
|
||||
processed_node_types.append(result['node_types'])
|
||||
except Exception:
|
||||
# If something unexpected happens
|
||||
processed_input_ids.append([])
|
||||
@ -145,6 +151,7 @@ def process_batch(batch, tokenizer):
|
||||
processed_depths.append([])
|
||||
processed_sibling_idxs.append([])
|
||||
processed_node_texts.append([])
|
||||
processed_node_types.append([])
|
||||
|
||||
return {
|
||||
'input_ids': processed_input_ids,
|
||||
@ -152,6 +159,7 @@ def process_batch(batch, tokenizer):
|
||||
'tokens': processed_tokens,
|
||||
'depths': processed_depths,
|
||||
'sibling_idxs': processed_sibling_idxs,
|
||||
'node_types': processed_node_types,
|
||||
'node_texts': processed_node_texts
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,9 @@
|
||||
import os
|
||||
import wandb
|
||||
import argparse
|
||||
import yaml
|
||||
import torch
|
||||
import random
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from datasets import load_from_disk, DatasetDict, load_dataset
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@ -16,29 +12,20 @@ from transformers import (
|
||||
DataCollatorForLanguageModeling,
|
||||
AutoModelForMaskedLM
|
||||
)
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tree_codebert import TreeCodeBERTForPreTraining
|
||||
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MonitoringTrainer(Trainer):
|
||||
def training_step(self, model, inputs, num_items_in_batch=None):
|
||||
# Perform the regular training step
|
||||
outputs = super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
# Log gradient norms
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
grad_norm = param.grad.norm(2).item()
|
||||
wandb.log({f'grad_norm/{name}': grad_norm})
|
||||
|
||||
# Log weight norms
|
||||
for name, param in model.named_parameters():
|
||||
weight_norm = param.data.norm(2).item()
|
||||
wandb.log({f'weight_norm/{name}': weight_norm})
|
||||
|
||||
return outputs
|
||||
|
||||
def set_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@ -50,100 +37,90 @@ def set_seed(seed: int):
|
||||
|
||||
def load_config(config_path: Path) -> dict:
|
||||
with open(config_path, 'r') as f:
|
||||
return yaml.safe_load(f)
|
||||
return json.load(f)
|
||||
|
||||
def initialize_wandb(config, name, files_to_save):
|
||||
wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
|
||||
for file in files_to_save:
|
||||
wandb.save(file)
|
||||
def main():
|
||||
# Setup paths
|
||||
current_dir = Path(__file__).parent
|
||||
config = load_config(current_dir / 'config.json')
|
||||
output_dir = Path(config['output_dir'])
|
||||
|
||||
def prepare_dataset(config, cache_dir):
|
||||
dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
|
||||
if config['data']['num_samples'] > 0:
|
||||
dataset = dataset.select(range(config['data']['num_samples']))
|
||||
train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
|
||||
# Set seed
|
||||
set_seed(config['seed'])
|
||||
|
||||
# Initialize W&B and save files
|
||||
wandb.init(project=config['project'], config=config, name=config['run_name'])
|
||||
for file in [__file__, 'config.json', 'tree_starencoder.py']:
|
||||
if config['extra_embeddings'] or file != 'tree_starencoder.py':
|
||||
wandb.save(current_dir / file)
|
||||
|
||||
# Simplified dataset splitting
|
||||
dataset = load_dataset(config['dataset'], split='train')
|
||||
if config['num_samples'] > 0:
|
||||
dataset = dataset.select(range(config['num_samples']))
|
||||
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
||||
test_valid = train_testvalid['test'].train_test_split(
|
||||
test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
|
||||
seed=config['training']['seed']
|
||||
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
|
||||
seed=config['seed']
|
||||
)
|
||||
dataset = DatasetDict({
|
||||
'train': train_testvalid['train'],
|
||||
'test': test_valid['test'],
|
||||
'valid': test_valid['train'],
|
||||
})
|
||||
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
|
||||
if config['model']['extra_embeddings']:
|
||||
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
|
||||
|
||||
|
||||
# Continue with the rest of processing
|
||||
columns_to_remove = dataset['train'].column_names
|
||||
columns_to_remove.remove('input_ids')
|
||||
columns_to_remove.remove('attention_mask')
|
||||
if config['extra_embeddings']:
|
||||
columns_to_remove.remove('depths')
|
||||
columns_to_remove.remove('sibling_idxs')
|
||||
dataset = dataset.remove_columns(columns_to_remove)
|
||||
return dataset
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
|
||||
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
|
||||
args = parser.parse_args()
|
||||
|
||||
current_dir = Path(__file__).parent
|
||||
config_path = Path(args.config)
|
||||
config = load_config(config_path)
|
||||
cache_dir = Path(config['experiment']['cache_dir'])
|
||||
output_dir = Path('./outputs') / config_path.stem
|
||||
|
||||
if config['experiment']['use_wandb']:
|
||||
os.environ['WANDB_MODE'] = 'online'
|
||||
initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
|
||||
else:
|
||||
os.environ['WANDB_MODE'] = 'offline'
|
||||
logger.info('Wandb is not used.')
|
||||
|
||||
set_seed(config['training']['seed'])
|
||||
|
||||
dataset = prepare_dataset(config, cache_dir)
|
||||
logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
|
||||
logger.info(f'Dataset columns: {dataset["train"].column_names}')
|
||||
|
||||
logger.info(f'Loaded dataset:\n{dataset}')
|
||||
|
||||
# Simplify tokenizer setup
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
|
||||
if tokenizer.mask_token is None:
|
||||
tokenizer.add_special_tokens({'mask_token': '<mask>'})
|
||||
tokenizer.mask_token = '<mask>'
|
||||
logger.info("Added '<mask>' as the mask token.")
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Set padding token to be the same as the EOS token.")
|
||||
tokenizer.add_special_tokens({'mask_token': '<mask>'}) if tokenizer.mask_token is None else None
|
||||
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
|
||||
|
||||
model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
|
||||
if config['model']['extra_embeddings']:
|
||||
model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
|
||||
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
|
||||
if config['extra_embeddings']:
|
||||
model = TreeStarEncoderForPreTraining(model_config)
|
||||
else:
|
||||
model = AutoModelForMaskedLM.from_config(model_config)
|
||||
logger.info(f'Loaded model: {model.__class__.__name__}')
|
||||
|
||||
|
||||
# Setup training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=str(output_dir),
|
||||
per_device_train_batch_size=config['training']['batch_size'],
|
||||
per_device_eval_batch_size=config['training']['batch_size'],
|
||||
learning_rate=config['training']['learning_rate'],
|
||||
weight_decay=config['training']['weight_decay'],
|
||||
num_train_epochs=config['training']['epochs'],
|
||||
warmup_steps=config['training']['warmup_steps'],
|
||||
max_grad_norm=config['training']['max_grad_norm'],
|
||||
logging_steps=config['evaluation']['logging_steps'],
|
||||
eval_steps=config['evaluation']['eval_every'],
|
||||
save_steps=config['evaluation']['eval_every'],
|
||||
per_device_train_batch_size=config['batch_size'],
|
||||
per_device_eval_batch_size=config['batch_size'],
|
||||
learning_rate=config['learning_rate'],
|
||||
weight_decay=config['weight_decay'],
|
||||
num_train_epochs=config['epochs'],
|
||||
warmup_steps=config['warmup_steps'],
|
||||
max_grad_norm=config['max_grad_norm'],
|
||||
logging_steps=config['logging_steps'],
|
||||
eval_steps=config['eval_every'],
|
||||
save_steps=config['eval_every'],
|
||||
eval_strategy='steps',
|
||||
save_strategy='steps',
|
||||
save_total_limit=5,
|
||||
load_best_model_at_end=True,
|
||||
report_to='wandb' if config['experiment']['use_wandb'] else None,
|
||||
run_name=config_path.stem,
|
||||
seed=config['training']['seed'],
|
||||
fp16=config['training']['fp16'],
|
||||
report_to='wandb',
|
||||
run_name=config['run_name'],
|
||||
seed=config['seed'],
|
||||
bf16=config['bf16'],
|
||||
dataloader_num_workers=8,
|
||||
gradient_checkpointing=True,
|
||||
metric_for_best_model='eval_loss',
|
||||
greater_is_better=False,
|
||||
save_total_limit=3,
|
||||
)
|
||||
|
||||
trainer = MonitoringTrainer(
|
||||
|
||||
# Create trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset['train'],
|
||||
@ -151,25 +128,32 @@ def main():
|
||||
data_collator=DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True,
|
||||
mlm_probability=config['data']['mlm_probability']
|
||||
mlm_probability=config['mlm_probability']
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Train
|
||||
logger.info('Starting training...')
|
||||
trainer.train()
|
||||
logger.info('Training completed.')
|
||||
|
||||
# Evaluate
|
||||
logger.info('Evaluating on test set...')
|
||||
eval_results = trainer.evaluate(dataset['test'])
|
||||
logger.info(f'Evaluation results: {eval_results}')
|
||||
wandb.log({'test_loss': eval_results['eval_loss']})
|
||||
|
||||
logger.info(f'Test loss: {eval_results["eval_loss"]}')
|
||||
|
||||
# Save final model
|
||||
logger.info('Saving final model...')
|
||||
trainer.save_model(output_dir / 'final-model')
|
||||
tokenizer.save_pretrained(output_dir / 'final-model')
|
||||
|
||||
# Upload to W&B
|
||||
wandb.save(output_dir / 'final-model/*')
|
||||
# Zip and upload the final model to W&B
|
||||
with zipfile.ZipFile(output_dir / 'final-model.zip', 'w') as zipf:
|
||||
for file in (output_dir / 'final-model').glob('**/*'):
|
||||
zipf.write(file, arcname=file.name)
|
||||
wandb.save(output_dir / 'final-model.zip')
|
||||
|
||||
logger.info('Training completed!')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
208
code/src/tree_codebert.py
Normal file
208
code/src/tree_codebert.py
Normal file
@ -0,0 +1,208 @@
|
||||
import wandb
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from transformers import RobertaConfig, RobertaForMaskedLM
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
|
||||
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
|
||||
|
||||
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.max_depth = max_depth
|
||||
|
||||
# Separate embeddings for different features
|
||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||
|
||||
# Improved projection layers
|
||||
self.node_projection = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# Layer norm for stability
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
self._initialize_embeddings()
|
||||
|
||||
def _initialize_embeddings(self):
|
||||
std = 0.02
|
||||
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||
|
||||
# Initialize projection layers
|
||||
for layer in self.node_projection:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||
nn.init.zeros_(layer.bias)
|
||||
|
||||
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
depths: Tensor of shape [batch_size, seq_len] containing depth values
|
||||
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
|
||||
Returns:
|
||||
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
|
||||
"""
|
||||
# Clamp values to max_depth
|
||||
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||
|
||||
# Get embeddings for each feature
|
||||
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
|
||||
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
|
||||
|
||||
# Combine features
|
||||
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||
embeddings = self.node_projection(combined)
|
||||
|
||||
# Apply layer norm
|
||||
normalized_embeddings = self.layer_norm(embeddings)
|
||||
|
||||
return normalized_embeddings
|
||||
|
||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||
"""CodeBERT model enhanced with tree-structural information."""
|
||||
|
||||
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||
super().__init__(config)
|
||||
|
||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||
d_model=config.hidden_size,
|
||||
max_depth=max_depth,
|
||||
dropout=config.hidden_dropout_prob
|
||||
)
|
||||
|
||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||
|
||||
# Initialize sequential position embeddings with sinusoidal pattern
|
||||
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||
|
||||
# Initialize embedding weights equally
|
||||
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
|
||||
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
|
||||
|
||||
# Layer norms for embedding combination
|
||||
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def get_normalized_weights(self):
|
||||
"""Get softmaxed weights for embedding combination."""
|
||||
return torch.softmax(self.embedding_weights, dim=0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
depths: Optional[torch.Tensor] = None,
|
||||
sibling_idxs: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
device = input_ids.device
|
||||
|
||||
# Get normalized weights for embedding combination and calculate regularization
|
||||
weights = self.get_normalized_weights()
|
||||
|
||||
# Calculate weight variance regularization
|
||||
# We want weights to remain somewhat balanced, so penalize high variance
|
||||
weight_variance = torch.var(weights)
|
||||
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
|
||||
|
||||
# Add L2 regularization to prevent any weight from getting too close to 1
|
||||
# This helps maintain a more balanced contribution from each embedding type
|
||||
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
|
||||
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
|
||||
|
||||
# Get token embeddings
|
||||
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
||||
token_embeddings = self.pre_combination_norm(token_embeddings)
|
||||
|
||||
# Get sequential position embeddings
|
||||
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||
|
||||
# Get tree positional embeddings if tree information is provided
|
||||
if depths is not None and sibling_idxs is not None:
|
||||
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||
else:
|
||||
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||
|
||||
# Combine all embeddings using learned weights
|
||||
combined_embeddings = (
|
||||
weights[0] * token_embeddings +
|
||||
weights[1] * tree_embeddings +
|
||||
weights[2] * seq_embeddings
|
||||
)
|
||||
|
||||
combined_embeddings = self.post_combination_norm(combined_embeddings)
|
||||
|
||||
# Forward pass through base model
|
||||
outputs = self.roberta(
|
||||
inputs_embeds=combined_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
# Calculate MLM loss if labels are provided
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
masked_lm_loss = loss_fct(
|
||||
prediction_scores.view(-1, self.config.vocab_size),
|
||||
labels.view(-1)
|
||||
)
|
||||
|
||||
# Add regularization losses to final loss
|
||||
if masked_lm_loss is not None:
|
||||
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = weight_reg_loss + l2_reg_loss
|
||||
else:
|
||||
final_loss = None
|
||||
|
||||
# Prepare embedding weights for logging
|
||||
weights_cpu = weights.detach().cpu()
|
||||
embedding_weights = {
|
||||
"token": weights_cpu[0].item(),
|
||||
"tree": weights_cpu[1].item(),
|
||||
"sequential": weights_cpu[2].item()
|
||||
}
|
||||
reg_metrics = {
|
||||
"weight_variance": weight_variance.item(),
|
||||
"max_weight_penalty": max_weight_penalty.item(),
|
||||
"weight_reg_loss": weight_reg_loss.item(),
|
||||
"l2_reg_loss": l2_reg_loss.item()
|
||||
}
|
||||
|
||||
wandb.log(
|
||||
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
|
||||
step=kwargs.get("global_step", None)
|
||||
)
|
||||
wandb.log(
|
||||
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
|
||||
step=kwargs.get("global_step", None)
|
||||
)
|
||||
|
||||
return {
|
||||
"loss": final_loss,
|
||||
"logits": prediction_scores,
|
||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||
"attentions": outputs.attentions if output_attentions else None,
|
||||
}
|
@ -1,107 +1,48 @@
|
||||
import wandb
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Optional
|
||||
from transformers import AutoConfig, BertForMaskedLM, GenerationMixin
|
||||
from transformers import AutoConfig, BertForMaskedLM
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.max_depth = max_depth
|
||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||
self.node_projection = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
self._initialize_embeddings()
|
||||
|
||||
def _initialize_embeddings(self):
|
||||
std = 0.02
|
||||
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||
for layer in self.node_projection:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||
nn.init.zeros_(layer.bias)
|
||||
from tree_codebert import TreePositionalEmbedding
|
||||
|
||||
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||
depth_embeddings = self.depth_embedding(depths)
|
||||
sibling_embeddings = self.sibling_embedding(sibling_idxs)
|
||||
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||
embeddings = self.node_projection(combined)
|
||||
return self.layer_norm(embeddings)
|
||||
|
||||
class NewEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and tree-based embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config, yaml_config):
|
||||
super().__init__()
|
||||
self.yaml_config = yaml_config
|
||||
|
||||
if self.yaml_config['model']['concat_embeddings']:
|
||||
self.fusion_layer = nn.Sequential(
|
||||
nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
|
||||
nn.GELU(),
|
||||
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
||||
)
|
||||
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||
d_model=config.hidden_size,
|
||||
max_depth=yaml_config['model']['max_depth'],
|
||||
dropout=config.hidden_dropout_prob
|
||||
)
|
||||
|
||||
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids=None, depths=None, sibling_idxs=None):
|
||||
input_shape = input_ids.size()
|
||||
|
||||
seq_length = input_shape[1]
|
||||
device = input_ids.device
|
||||
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||
|
||||
if self.yaml_config['model']['sum_embeddings']:
|
||||
embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
|
||||
if self.yaml_config['model']['concat_embeddings']:
|
||||
embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
|
||||
embeddings = self.fusion_layer(embeddings)
|
||||
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
||||
def __init__(self, config: AutoConfig, yaml_config: Dict):
|
||||
class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
||||
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.yaml_config = yaml_config
|
||||
self.embeddings = NewEmbeddings(config, yaml_config)
|
||||
|
||||
# self.fusion_layer = nn.Sequential(
|
||||
# nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
|
||||
# nn.Linear(config.hidden_size * 4, config.hidden_size),
|
||||
# nn.GELU(),
|
||||
# nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
|
||||
# nn.Dropout(config.hidden_dropout_prob),
|
||||
# nn.LayerNorm(config.hidden_size)
|
||||
# )
|
||||
|
||||
# Override config to set max_seq_length
|
||||
config.max_position_embeddings = max_seq_length
|
||||
|
||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||
d_model=config.hidden_size,
|
||||
max_depth=max_depth,
|
||||
dropout=config.hidden_dropout_prob
|
||||
)
|
||||
|
||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||
|
||||
# # Initialize sequential position embeddings with sinusoidal pattern
|
||||
# position = torch.arange(max_seq_length).unsqueeze(1)
|
||||
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||
# pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||
# pe[:, 0::2] = torch.sin(position * div_term)
|
||||
# pe[:, 1::2] = torch.cos(position * div_term)
|
||||
# self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||
|
||||
# New node type embeddings
|
||||
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
|
||||
self.norm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -112,10 +53,34 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
||||
output_attentions: bool = False,
|
||||
**kwargs
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
embedding_output = self.embeddings(input_ids, depths, sibling_idxs)
|
||||
device = input_ids.device
|
||||
|
||||
# Get token embeddings
|
||||
token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
|
||||
|
||||
# Get sequential position embeddings
|
||||
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||
|
||||
# Get tree positional embeddings
|
||||
if depths is not None and sibling_idxs is not None:
|
||||
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||
else:
|
||||
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||
|
||||
# # Get node type embeddings
|
||||
# node_type_embeddings = self.node_type_embeddings(node_types)
|
||||
|
||||
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
|
||||
# combined_embeddings = self.fusion_layer(combined)
|
||||
|
||||
# Add the embeddings instead of concatenating
|
||||
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
|
||||
|
||||
combined_embeddings = self.norm(combined_embeddings)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs_embeds=embedding_output,
|
||||
inputs_embeds=combined_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs
|
||||
@ -135,7 +100,6 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
||||
return {
|
||||
"loss": masked_lm_loss,
|
||||
"logits": prediction_scores,
|
||||
"last_hidden_state": sequence_output,
|
||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||
"attentions": outputs.attentions if output_attentions else None,
|
||||
}
|
Loading…
Reference in New Issue
Block a user