Compare commits

...

9 Commits

Author SHA1 Message Date
Patryk Bartkowiak
a6ebe812cf added more monitoring to the training 2025-01-21 10:56:16 +00:00
Patryk Bartkowiak
19864729cf max 3 checkpoints 2025-01-20 23:46:03 +00:00
Patryk Bartkowiak
2639e8dca2 changed configs for runpod 2025-01-20 21:14:20 +00:00
Patryk Bartkowiak
2ec7cc263d updated README 2025-01-16 17:06:36 +00:00
Patryk Bartkowiak
03b2502af0 simplified all the files and prepared config for each experiment 2025-01-16 17:03:53 +00:00
Patryk Bartkowiak
3b93a7cc8a prepared for run by prof Filip Gralinski (done) 2025-01-07 15:59:30 +00:00
Patryk Bartkowiak
76a89dc236 3 pochs 2025-01-07 12:58:20 +00:00
Patryk Bartkowiak
eed2096400 prepared for run by prof Filip Gralinski 2025-01-07 11:45:55 +00:00
Patryk Bartkowiak
f0679ab861 on this commit i continued to train original starencoder model 2025-01-04 21:02:30 +00:00
24 changed files with 401 additions and 4346 deletions

1
code/.gitignore vendored
View File

@ -164,3 +164,4 @@ cython_debug/
# Weights & Biases
wandb/
outputs/
cache/

View File

@ -20,9 +20,9 @@ pdm install
```
### 4. Run training code
```bash
pdm run_training
pdm train --config {CONFIG FILE}
```
or
```
pdm run src/train_codebert_mlm.py
Example:
```bash
pdm train --config ./configs/original.yaml
```

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: false
wandb_project: "prof-gralinski"
training:
seed: 42
batch_size: 32
epochs: 1
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 0
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/1-fold-clone-detection-600k-5fold"
num_samples: 1000 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
path: "outputs/tmp/final-model"
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: false # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: false
sum_embeddings: true
max_depth: 32
max_seq_length: 512

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: true
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -0,0 +1,33 @@
experiment:
cache_dir: "./cache/"
use_wandb: true
wandb_project: "runpod"
training:
seed: 42
batch_size: 32
epochs: 3
learning_rate: 0.0005
weight_decay: 0.1
max_grad_norm: 1.0
warmup_steps: 1000
fp16: true
evaluation:
eval_every: 10000
logging_steps: 100
data:
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
mlm_probability: 0.15
num_samples: -1 # -1 means all samples
valid_size: 0.05
test_size: 0.05
model:
extra_embeddings: true # if false, use the original star encoder
sinusoidal_init: false
concat_embeddings: true
sum_embeddings: false
max_depth: 32
max_seq_length: 512

View File

@ -1,2 +0,0 @@
*
!.gitignore

View File

@ -1,2 +0,0 @@
*
!.gitignore

View File

@ -19,6 +19,7 @@ dependencies = [
"tree-sitter-python==0.23.4",
"ipykernel==6.29.5",
"ipywidgets==8.1.5",
"pyyaml==6.0.2",
]
requires-python = "==3.11.*"
readme = "README.md"
@ -41,3 +42,4 @@ distribution = true
parse_dataset = {cmd = "src/parse_dataset.py"}
train = {cmd = "src/training.py"}
eval = {cmd = "src/eval_model.py"}
finetune = {cmd = "src_finetune/finetune.py"}

View File

@ -1,20 +0,0 @@
{
"extra_embeddings": true,
"run_name": "no-sinusoidal",
"data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
"output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 32,
"epochs": 3,
"eval_every": 10000,
"learning_rate": 5e-4,
"weight_decay": 0.1,
"max_grad_norm": 1.0,
"warmup_steps": 1000,
"fp16": true,
"logging_steps": 100,
"valid_size": 0.05,
"test_size": 0.05,
"num_samples": -1
}

View File

@ -1,118 +0,0 @@
import json
import logging
import multiprocessing
from pathlib import Path
from datasets import load_from_disk
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
def load_node_types_from_json(json_path: Path):
"""
Load node types from the Tree-sitter grammar's `node_types.json` and include UNK as the 0 index.
Args:
json_path (Path): Path to the `node_types.json` file.
Returns:
dict: A mapping from node type strings to unique integer IDs.
"""
if not json_path.exists():
raise FileNotFoundError(f"{json_path} not found.")
logger.info(f"Loading node types from {json_path}...")
with open(json_path, "r", encoding="utf-8") as f:
node_types_data = json.load(f)
# Extract all unique "type" entries
node_types = set()
def extract_types(data):
if isinstance(data, list):
for item in data:
extract_types(item)
elif isinstance(data, dict):
if "type" in data and isinstance(data["type"], str):
node_types.add(data["type"])
for key, value in data.items():
extract_types(value)
extract_types(node_types_data)
# Create mapping and add 'UNK' at index 0
node_type2id = {"<UNK>": 0}
for i, node_type in enumerate(sorted(node_types), start=1):
node_type2id[node_type] = i
logger.info(f"Loaded {len(node_type2id)} node types, including UNK.")
return node_type2id
def encode_node_types(examples, node_type2id):
"""
Batched function to replace node type strings with their integer IDs using a preloaded mapping.
"""
encoded_node_types = []
for node_list in examples["node_types"]:
try:
encoded_node_list = [node_type2id[nt] if nt is not None and nt != 'ERROR' else node_type2id['<UNK>'] for nt in node_list]
encoded_node_types.append(encoded_node_list)
except KeyError as e:
raise KeyError(f"Unknown node type encountered: {e}")
examples["node_types_encoded"] = encoded_node_types
return examples
def main():
"""
Main script to load, process, and save a dataset with node types encoded as integers.
"""
# ------------------------------------------------------------------------------
# 1. Setup paths & load dataset
# ------------------------------------------------------------------------------
current_dir = Path(__file__).parent
input_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-padded"
output_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-encoded"
node_types_path = current_dir / "node_types.json"
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Loading dataset from {input_dir}...")
dataset = load_from_disk(str(input_dir))
logger.info("Dataset loaded.")
# Determine number of processes to use
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes.")
# ------------------------------------------------------------------------------
# 2. Load node types from JSON
# ------------------------------------------------------------------------------
node_type2id = load_node_types_from_json(node_types_path)
logger.info(f"Loaded {len(node_type2id)} node types.")
# Save node_type2id to disk
with open(output_dir / "node_type2id.json", "w") as f:
json.dump(node_type2id, f)
# ------------------------------------------------------------------------------
# 3. Convert node types in the dataset to integer IDs
# ------------------------------------------------------------------------------
logger.info("Converting node type strings to integer IDs...")
dataset = dataset.map(
lambda examples: encode_node_types(examples, node_type2id),
batched=True,
num_proc=num_proc,
desc="Encoding node types to integer IDs",
)
# ------------------------------------------------------------------------------
# 4. Save the modified dataset to disk
# ------------------------------------------------------------------------------
logger.info(f"Saving updated dataset to {output_dir}...")
dataset.save_to_disk(str(output_dir))
logger.info("Dataset saved successfully.")
if __name__ == "__main__":
main()

View File

@ -133,7 +133,7 @@ def main():
model_config.max_position_embeddings = 1024
if config['extra_embeddings']:
model = TreeStarEncoderForPreTraining(config=model_config, log=False)
model = TreeStarEncoderForPreTraining(config=model_config)
else:
model = AutoModelForMaskedLM.from_config(model_config)

File diff suppressed because it is too large Load Diff

View File

@ -1,77 +0,0 @@
import logging
from pathlib import Path
from datasets import load_from_disk
from transformers import AutoTokenizer
import multiprocessing
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def pad_and_save_dataset(input_dir, output_dir, tokenizer_name='bigcode/starencoder', max_length=512):
# Load the processed dataset
logger.info(f"Loading processed dataset from {input_dir}...")
dataset = load_from_disk(input_dir)
logger.info(f"Loaded dataset with {len(dataset)} examples")
# Initialize tokenizer
logger.info("Initializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.pad_token = tokenizer.eos_token
logger.info("Loaded StarEncoder tokenizer")
# Define number of processes
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes")
# Define a function to pad the sequences
def pad_sequences(batch):
# Convert input_ids back to text if necessary
texts = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)
# Use the tokenizer's __call__ method for padding
padded_inputs = tokenizer(
texts,
padding='max_length',
max_length=max_length,
return_tensors='pt',
truncation=True
)
# Pad other fields with default values
padded_depths = [seq + [-1] * (max_length - len(seq)) for seq in batch['depths']]
padded_sibling_idxs = [seq + [-1] * (max_length - len(seq)) for seq in batch['sibling_idxs']]
padded_node_types = [seq + [None] * (max_length - len(seq)) for seq in batch['node_types']]
padded_node_texts = [seq + [''] * (max_length - len(seq)) for seq in batch['node_texts']]
return {
'input_ids': padded_inputs['input_ids'].tolist(),
'attention_mask': padded_inputs['attention_mask'].tolist(),
'depths': padded_depths,
'sibling_idxs': padded_sibling_idxs,
'node_types': padded_node_types,
'node_texts': padded_node_texts
}
# Apply padding
logger.info("Applying padding to dataset...")
padded_dataset = dataset.map(
pad_sequences,
batched=True,
desc="Padding dataset",
num_proc=num_proc
)
# Save the padded dataset
logger.info(f"Saving padded dataset to {output_dir}...")
padded_dataset.save_to_disk(output_dir)
logger.info(f"Saved padded dataset to {output_dir}")
if __name__ == "__main__":
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes-padded'
pad_and_save_dataset(input_dir, output_dir)

View File

@ -84,7 +84,6 @@ def process_example(code, tokenizer):
depths = [-1] * len(input_ids)
sibling_idxs = [-1] * len(input_ids)
node_types = [None] * len(input_ids)
node_texts = [''] * len(input_ids)
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
@ -96,7 +95,6 @@ def process_example(code, tokenizer):
if node.start_byte <= start < node.end_byte:
depths[i] = depth
sibling_idxs[i] = sibling_idx
node_types[i] = node.type
node_texts[i] = code[node.start_byte:node.end_byte]
for i, child in enumerate(node.children):
traverse(child, depth + 1, i)
@ -109,7 +107,6 @@ def process_example(code, tokenizer):
'attention_mask': attention_mask,
'depths': depths,
'sibling_idxs': sibling_idxs,
'node_types': node_types,
'node_texts': node_texts
}
@ -121,7 +118,6 @@ def process_batch(batch, tokenizer):
processed_depths = []
processed_sibling_idxs = []
processed_node_texts = []
processed_node_types = []
for content in contents:
try:
@ -134,7 +130,6 @@ def process_batch(batch, tokenizer):
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
processed_node_types.append([])
else:
processed_input_ids.append(result['input_ids'])
processed_attention_mask.append(result['attention_mask'])
@ -142,7 +137,6 @@ def process_batch(batch, tokenizer):
processed_depths.append(result['depths'])
processed_sibling_idxs.append(result['sibling_idxs'])
processed_node_texts.append(result['node_texts'])
processed_node_types.append(result['node_types'])
except Exception:
# If something unexpected happens
processed_input_ids.append([])
@ -151,7 +145,6 @@ def process_batch(batch, tokenizer):
processed_depths.append([])
processed_sibling_idxs.append([])
processed_node_texts.append([])
processed_node_types.append([])
return {
'input_ids': processed_input_ids,
@ -159,7 +152,6 @@ def process_batch(batch, tokenizer):
'tokens': processed_tokens,
'depths': processed_depths,
'sibling_idxs': processed_sibling_idxs,
'node_types': processed_node_types,
'node_texts': processed_node_texts
}

View File

@ -1,32 +1,44 @@
import os
import wandb
import json
import argparse
import yaml
import torch
import random
import logging
import numpy as np
from pathlib import Path
from datasets import load_from_disk, DatasetDict
from datasets import load_dataset, DatasetDict
from transformers import (
RobertaConfig,
AutoConfig,
RobertaForMaskedLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
AutoModelForMaskedLM
)
import random
import numpy as np
import torch
from tree_codebert import TreeCodeBERTForPreTraining
from tree_starencoder import TreeStarEncoderForPreTraining
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MonitoringTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
# Perform the regular training step
outputs = super().training_step(model, inputs, num_items_in_batch)
# Log gradient norms
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm(2).item()
wandb.log({f'grad_norm/{name}': grad_norm})
# Log weight norms
for name, param in model.named_parameters():
weight_norm = param.data.norm(2).item()
wandb.log({f'weight_norm/{name}': weight_norm})
return outputs
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
@ -38,60 +50,57 @@ def set_seed(seed: int):
def load_config(config_path: Path) -> dict:
with open(config_path, 'r') as f:
return json.load(f)
return yaml.safe_load(f)
def main():
# Setup paths
current_dir = Path(__file__).parent
config = load_config(current_dir / 'config.json')
data_dir = Path(config['data_dir'])
output_dir = Path(config['output_dir'])
def initialize_wandb(config, name, files_to_save):
wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
for file in files_to_save:
wandb.save(file)
# Set seed
set_seed(config['seed'])
# Initialize W&B
wandb.init(project='codeparrot-starencoder-no-comments', config=config, name=config['run_name'])
# Upload the training files to W&B
wandb.save(__file__)
wandb.save(Path(__file__).parent / 'config.json')
if config['extra_embeddings']:
wandb.save(current_dir / 'tree_starencoder.py')
if 'CodeSearchNet' in config['data_dir']:
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
else:
dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
def prepare_dataset(config, cache_dir):
dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
if config['data']['num_samples'] > 0:
dataset = dataset.select(range(config['data']['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
test_valid = train_testvalid['test'].train_test_split(
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
seed=config['seed']
test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
seed=config['training']['seed']
)
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
# Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
if config['extra_embeddings']:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
if config['model']['extra_embeddings']:
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
dataset = dataset.remove_columns(columns_to_remove)
logger.info(f'Loaded dataset:\n{dataset}')
return dataset
def main():
parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
args = parser.parse_args()
current_dir = Path(__file__).parent
config_path = Path(args.config)
config = load_config(config_path)
cache_dir = Path(config['experiment']['cache_dir'])
output_dir = Path('./outputs') / config_path.stem
if config['experiment']['use_wandb']:
os.environ['WANDB_MODE'] = 'online'
initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
else:
os.environ['WANDB_MODE'] = 'offline'
logger.info('Wandb is not used.')
set_seed(config['training']['seed'])
dataset = prepare_dataset(config, cache_dir)
logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
logger.info(f'Dataset columns: {dataset["train"].column_names}')
# Initialize model from scratch
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
if tokenizer.mask_token is None:
tokenizer.add_special_tokens({'mask_token': '<mask>'})
@ -101,41 +110,40 @@ def main():
tokenizer.pad_token = tokenizer.eos_token
logger.info("Set padding token to be the same as the EOS token.")
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
if config['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config)
model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
if config['model']['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
else:
model = AutoModelForMaskedLM.from_config(model_config)
logger.info(f'Loaded model: {model.__class__.__name__}')
# Setup training arguments
training_args = TrainingArguments(
output_dir=str(output_dir),
per_device_train_batch_size=config['batch_size'],
per_device_eval_batch_size=config['batch_size'],
learning_rate=config['learning_rate'],
weight_decay=config['weight_decay'],
num_train_epochs=config['epochs'],
warmup_steps=config['warmup_steps'],
max_grad_norm=config['max_grad_norm'],
logging_steps=config['logging_steps'],
eval_steps=config['eval_every'],
save_steps=config['eval_every'],
per_device_train_batch_size=config['training']['batch_size'],
per_device_eval_batch_size=config['training']['batch_size'],
learning_rate=config['training']['learning_rate'],
weight_decay=config['training']['weight_decay'],
num_train_epochs=config['training']['epochs'],
warmup_steps=config['training']['warmup_steps'],
max_grad_norm=config['training']['max_grad_norm'],
logging_steps=config['evaluation']['logging_steps'],
eval_steps=config['evaluation']['eval_every'],
save_steps=config['evaluation']['eval_every'],
eval_strategy='steps',
save_strategy='steps',
load_best_model_at_end=True,
report_to='wandb',
run_name=config['run_name'],
seed=config['seed'],
fp16=config['fp16'],
report_to='wandb' if config['experiment']['use_wandb'] else None,
run_name=config_path.stem,
seed=config['training']['seed'],
fp16=config['training']['fp16'],
dataloader_num_workers=8,
gradient_checkpointing=True,
metric_for_best_model='eval_loss',
greater_is_better=False,
save_total_limit=3,
)
# Create trainer
trainer = Trainer(
trainer = MonitoringTrainer(
model=model,
args=training_args,
train_dataset=dataset['train'],
@ -143,26 +151,25 @@ def main():
data_collator=DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['mlm_probability']
mlm_probability=config['data']['mlm_probability']
),
)
# Train
logger.info('Starting training...')
trainer.train()
logger.info('Training completed.')
# Evaluate
logger.info('Evaluating on test set...')
eval_results = trainer.evaluate(dataset['test'])
logger.info(f'Evaluation results: {eval_results}')
wandb.log({'test_loss': eval_results['eval_loss']})
logger.info(f'Test loss: {eval_results["eval_loss"]}')
# Save final model
logger.info('Saving final model...')
trainer.save_model(output_dir / 'final-model')
tokenizer.save_pretrained(output_dir / 'final-model')
logger.info('Training completed!')
# Upload to W&B
wandb.save(output_dir / 'final-model/*')
if __name__ == '__main__':
main()

View File

@ -1,208 +0,0 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM
class TreePositionalEmbedding(nn.Module):
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
# Separate embeddings for different features
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
# Improved projection layers
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
# Layer norm for stability
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
# Initialize projection layers
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
"""
Args:
depths: Tensor of shape [batch_size, seq_len] containing depth values
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
Returns:
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
"""
# Clamp values to max_depth
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
# Get embeddings for each feature
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
# Combine features
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
# Apply layer norm
normalized_embeddings = self.layer_norm(embeddings)
return normalized_embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with tree-structural information."""
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize embedding weights equally
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
# Layer norms for embedding combination
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
def get_normalized_weights(self):
"""Get softmaxed weights for embedding combination."""
return torch.softmax(self.embedding_weights, dim=0)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
sibling_idxs: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get normalized weights for embedding combination and calculate regularization
weights = self.get_normalized_weights()
# Calculate weight variance regularization
# We want weights to remain somewhat balanced, so penalize high variance
weight_variance = torch.var(weights)
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
# Add L2 regularization to prevent any weight from getting too close to 1
# This helps maintain a more balanced contribution from each embedding type
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
# Get token embeddings
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
token_embeddings = self.pre_combination_norm(token_embeddings)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings if tree information is provided
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# Combine all embeddings using learned weights
combined_embeddings = (
weights[0] * token_embeddings +
weights[1] * tree_embeddings +
weights[2] * seq_embeddings
)
combined_embeddings = self.post_combination_norm(combined_embeddings)
# Forward pass through base model
outputs = self.roberta(
inputs_embeds=combined_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Calculate MLM loss if labels are provided
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.view(-1)
)
# Add regularization losses to final loss
if masked_lm_loss is not None:
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
else:
final_loss = weight_reg_loss + l2_reg_loss
else:
final_loss = None
# Prepare embedding weights for logging
weights_cpu = weights.detach().cpu()
embedding_weights = {
"token": weights_cpu[0].item(),
"tree": weights_cpu[1].item(),
"sequential": weights_cpu[2].item()
}
reg_metrics = {
"weight_variance": weight_variance.item(),
"max_weight_penalty": max_weight_penalty.item(),
"weight_reg_loss": weight_reg_loss.item(),
"l2_reg_loss": l2_reg_loss.item()
}
wandb.log(
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
step=kwargs.get("global_step", None)
)
wandb.log(
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
step=kwargs.get("global_step", None)
)
return {
"loss": final_loss,
"logits": prediction_scores,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
}

View File

@ -1,47 +1,106 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import AutoConfig, BertForMaskedLM
from transformers import AutoConfig, BertForMaskedLM, GenerationMixin
from tree_codebert import TreePositionalEmbedding
class TreePositionalEmbedding(nn.Module):
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_depth = max_depth
self.depth_embedding = nn.Embedding(max_depth, d_model)
self.sibling_embedding = nn.Embedding(max_depth, d_model)
self.node_projection = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(dropout)
)
self.layer_norm = nn.LayerNorm(d_model)
self._initialize_embeddings()
class TreeStarEncoderForPreTraining(BertForMaskedLM):
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
super().__init__(config)
self.config = config
def _initialize_embeddings(self):
std = 0.02
for embedding in [self.depth_embedding, self.sibling_embedding]:
nn.init.normal_(embedding.weight, mean=0.0, std=std)
for layer in self.node_projection:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0.0, std=std)
nn.init.zeros_(layer.bias)
# self.fusion_layer = nn.Sequential(
# nn.Linear(config.hidden_size * 4, config.hidden_size),
# nn.GELU(),
# nn.Dropout(config.hidden_dropout_prob),
# nn.LayerNorm(config.hidden_size)
# )
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
depths = torch.clamp(depths, 0, self.max_depth - 1)
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
depth_embeddings = self.depth_embedding(depths)
sibling_embeddings = self.sibling_embedding(sibling_idxs)
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
embeddings = self.node_projection(combined)
return self.layer_norm(embeddings)
# Override config to set max_seq_length
config.max_position_embeddings = max_seq_length
class NewEmbeddings(nn.Module):
"""Construct the embeddings from word, position and tree-based embeddings.
"""
def __init__(self, config, yaml_config):
super().__init__()
self.yaml_config = yaml_config
if self.yaml_config['model']['concat_embeddings']:
self.fusion_layer = nn.Sequential(
nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
nn.GELU(),
nn.Linear(config.hidden_size * 3, config.hidden_size),
)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.tree_pos_embeddings = TreePositionalEmbedding(
d_model=config.hidden_size,
max_depth=max_depth,
max_depth=yaml_config['model']['max_depth'],
dropout=config.hidden_dropout_prob
)
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# # Initialize sequential position embeddings with sinusoidal pattern
# position = torch.arange(max_seq_length).unsqueeze(1)
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
# pe = torch.zeros(max_seq_length, config.hidden_size)
# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# self.seq_pos_embeddings.weight.data.copy_(pe)
def forward(self, input_ids=None, depths=None, sibling_idxs=None):
input_shape = input_ids.size()
# New node type embeddings
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
self.norm = nn.LayerNorm(config.hidden_size)
seq_length = input_shape[1]
device = input_ids.device
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
if self.yaml_config['model']['sum_embeddings']:
embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
if self.yaml_config['model']['concat_embeddings']:
embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
embeddings = self.fusion_layer(embeddings)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
def __init__(self, config: AutoConfig, yaml_config: Dict):
super().__init__(config)
self.config = config
self.yaml_config = yaml_config
self.embeddings = NewEmbeddings(config, yaml_config)
# self.fusion_layer = nn.Sequential(
# nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
# nn.GELU(),
# nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
# nn.Dropout(config.hidden_dropout_prob),
# nn.LayerNorm(config.hidden_size)
# )
def forward(
self,
@ -53,34 +112,10 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
output_attentions: bool = False,
**kwargs
) -> Dict[str, torch.Tensor]:
device = input_ids.device
# Get token embeddings
token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
# Get sequential position embeddings
seq_positions = torch.arange(input_ids.size(1), device=device)
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
# Get tree positional embeddings
if depths is not None and sibling_idxs is not None:
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
else:
tree_embeddings = torch.zeros_like(token_embeddings)
# # Get node type embeddings
# node_type_embeddings = self.node_type_embeddings(node_types)
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
# combined_embeddings = self.fusion_layer(combined)
# Add the embeddings instead of concatenating
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
combined_embeddings = self.norm(combined_embeddings)
embedding_output = self.embeddings(input_ids, depths, sibling_idxs)
outputs = self.bert(
inputs_embeds=combined_embeddings,
inputs_embeds=embedding_output,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs
@ -100,6 +135,7 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
return {
"loss": masked_lm_loss,
"logits": prediction_scores,
"last_hidden_state": sequence_output,
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
"attentions": outputs.attentions if output_attentions else None,
}