Compare commits

...

9 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Patryk Bartkowiak | a6ebe812cf | added more monitoring to the training | 2025-01-21 10:56:16 +00:00 |
| Patryk Bartkowiak | 19864729cf | max 3 checkpoints | 2025-01-20 23:46:03 +00:00 |
| Patryk Bartkowiak | 2639e8dca2 | changed configs for runpod | 2025-01-20 21:14:20 +00:00 |
| Patryk Bartkowiak | 2ec7cc263d | updated README | 2025-01-16 17:06:36 +00:00 |
| Patryk Bartkowiak | 03b2502af0 | simplified all the files and prepared config for each experiment | 2025-01-16 17:03:53 +00:00 |
| Patryk Bartkowiak | 3b93a7cc8a | prepared for run by prof Filip Gralinski (done) | 2025-01-07 15:59:30 +00:00 |
| Patryk Bartkowiak | 76a89dc236 | 3 pochs | 2025-01-07 12:58:20 +00:00 |
| Patryk Bartkowiak | eed2096400 | prepared for run by prof Filip Gralinski | 2025-01-07 11:45:55 +00:00 |
| Patryk Bartkowiak | f0679ab861 | on this commit i continued to train original starencoder model | 2025-01-04 21:02:30 +00:00 |
24 changed files with 401 additions and 4346 deletions

code/.gitignore vendored
View File

@@ -164,3 +164,4 @@ cython_debug/
 # Weights & Biases
 wandb/
 outputs/
+cache/

View File

@@ -20,9 +20,9 @@ pdm install
 ```
 ### 4. Run training code
 ```bash
-pdm run_training
+pdm train --config {CONFIG FILE}
 ```
-or
-```
-pdm run src/train_codebert_mlm.py
+Example:
+```bash
+pdm train --config ./configs/original.yaml
 ```

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: false
  wandb_project: "prof-gralinski"

training:
  seed: 42
  batch_size: 32
  epochs: 1
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 0
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/1-fold-clone-detection-600k-5fold"
  num_samples: 1000 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  path: "outputs/tmp/final-model"
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
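The training script changed later in this compare reads these YAML files with `yaml.safe_load` and then indexes into the nested sections. A minimal sketch of that pattern (the path below is just an example; `./configs/original.yaml` is the one mentioned in the README):

```python
from pathlib import Path

import yaml

# Example path; the README in this compare uses ./configs/original.yaml.
config_path = Path("./configs/original.yaml")

with open(config_path, "r") as f:
    config = yaml.safe_load(f)  # nested dict mirroring the YAML sections

# Values are grouped by section: experiment / training / evaluation / data / model.
print(config["training"]["learning_rate"])
print(config["model"]["extra_embeddings"])
```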

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: false # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: false
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: true
  concat_embeddings: false
  sum_embeddings: true
  max_depth: 32
  max_seq_length: 512

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: false
  sum_embeddings: true
  max_depth: 32
  max_seq_length: 512

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: true
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512

View File

@@ -0,0 +1,33 @@
experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
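Read side by side, the five "runpod" configs above differ only in the model flags, which looks like a small grid over how the extra tree embeddings are combined. A rough summary as plain data (the descriptive labels are mine; the original file names are not visible in this compare):

```python
# Rough summary of the model flags across the five "runpod" configs above
# (descriptive labels are mine; the original file names are not shown here).
experiment_grid = [
    {"extra_embeddings": False, "sinusoidal_init": False, "concat_embeddings": False, "sum_embeddings": False},  # plain StarEncoder baseline
    {"extra_embeddings": True,  "sinusoidal_init": True,  "concat_embeddings": False, "sum_embeddings": True},   # sum fusion, sinusoidal init
    {"extra_embeddings": True,  "sinusoidal_init": False, "concat_embeddings": False, "sum_embeddings": True},   # sum fusion, learned init
    {"extra_embeddings": True,  "sinusoidal_init": True,  "concat_embeddings": True,  "sum_embeddings": False},  # concat fusion, sinusoidal init
    {"extra_embeddings": True,  "sinusoidal_init": False, "concat_embeddings": True,  "sum_embeddings": False},  # concat fusion, learned init
]
```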

View File

@@ -1,2 +0,0 @@
*
!.gitignore

View File

@@ -1,2 +0,0 @@
*
!.gitignore

View File

@@ -19,6 +19,7 @@ dependencies = [
     "tree-sitter-python==0.23.4",
     "ipykernel==6.29.5",
     "ipywidgets==8.1.5",
+    "pyyaml==6.0.2",
 ]
 requires-python = "==3.11.*"
 readme = "README.md"
@@ -41,3 +42,4 @@ distribution = true
 parse_dataset = {cmd = "src/parse_dataset.py"}
 train = {cmd = "src/training.py"}
 eval = {cmd = "src/eval_model.py"}
+finetune = {cmd = "src_finetune/finetune.py"}

View File

@@ -1,20 +0,0 @@
{
    "extra_embeddings": true,
    "run_name": "no-sinusoidal",
    "data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
    "output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
    "seed": 420,
    "mlm_probability": 0.15,
    "batch_size": 32,
    "epochs": 3,
    "eval_every": 10000,
    "learning_rate": 5e-4,
    "weight_decay": 0.1,
    "max_grad_norm": 1.0,
    "warmup_steps": 1000,
    "fp16": true,
    "logging_steps": 100,
    "valid_size": 0.05,
    "test_size": 0.05,
    "num_samples": -1
}

View File

@@ -1,118 +0,0 @@
import json
import logging
import multiprocessing
from pathlib import Path

from datasets import load_from_disk

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

def load_node_types_from_json(json_path: Path):
    """
    Load node types from the Tree-sitter grammar's `node_types.json` and include UNK as the 0 index.

    Args:
        json_path (Path): Path to the `node_types.json` file.

    Returns:
        dict: A mapping from node type strings to unique integer IDs.
    """
    if not json_path.exists():
        raise FileNotFoundError(f"{json_path} not found.")

    logger.info(f"Loading node types from {json_path}...")
    with open(json_path, "r", encoding="utf-8") as f:
        node_types_data = json.load(f)

    # Extract all unique "type" entries
    node_types = set()

    def extract_types(data):
        if isinstance(data, list):
            for item in data:
                extract_types(item)
        elif isinstance(data, dict):
            if "type" in data and isinstance(data["type"], str):
                node_types.add(data["type"])
            for key, value in data.items():
                extract_types(value)

    extract_types(node_types_data)

    # Create mapping and add 'UNK' at index 0
    node_type2id = {"<UNK>": 0}
    for i, node_type in enumerate(sorted(node_types), start=1):
        node_type2id[node_type] = i

    logger.info(f"Loaded {len(node_type2id)} node types, including UNK.")
    return node_type2id

def encode_node_types(examples, node_type2id):
    """
    Batched function to replace node type strings with their integer IDs using a preloaded mapping.
    """
    encoded_node_types = []
    for node_list in examples["node_types"]:
        try:
            encoded_node_list = [node_type2id[nt] if nt is not None and nt != 'ERROR' else node_type2id['<UNK>'] for nt in node_list]
            encoded_node_types.append(encoded_node_list)
        except KeyError as e:
            raise KeyError(f"Unknown node type encountered: {e}")
    examples["node_types_encoded"] = encoded_node_types
    return examples

def main():
    """
    Main script to load, process, and save a dataset with node types encoded as integers.
    """
    # ------------------------------------------------------------------------------
    # 1. Setup paths & load dataset
    # ------------------------------------------------------------------------------
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-padded"
    output_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-encoded"
    node_types_path = current_dir / "node_types.json"
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Loading dataset from {input_dir}...")
    dataset = load_from_disk(str(input_dir))
    logger.info("Dataset loaded.")

    # Determine number of processes to use
    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes.")

    # ------------------------------------------------------------------------------
    # 2. Load node types from JSON
    # ------------------------------------------------------------------------------
    node_type2id = load_node_types_from_json(node_types_path)
    logger.info(f"Loaded {len(node_type2id)} node types.")

    # Save node_type2id to disk
    with open(output_dir / "node_type2id.json", "w") as f:
        json.dump(node_type2id, f)

    # ------------------------------------------------------------------------------
    # 3. Convert node types in the dataset to integer IDs
    # ------------------------------------------------------------------------------
    logger.info("Converting node type strings to integer IDs...")
    dataset = dataset.map(
        lambda examples: encode_node_types(examples, node_type2id),
        batched=True,
        num_proc=num_proc,
        desc="Encoding node types to integer IDs",
    )

    # ------------------------------------------------------------------------------
    # 4. Save the modified dataset to disk
    # ------------------------------------------------------------------------------
    logger.info(f"Saving updated dataset to {output_dir}...")
    dataset.save_to_disk(str(output_dir))
    logger.info("Dataset saved successfully.")

if __name__ == "__main__":
    main()
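For the record, since this script is deleted here: `load_node_types_from_json` walks the grammar JSON, collects every unique `"type"` string, sorts them, and assigns IDs from 1 upward while reserving 0 for `<UNK>`. A toy illustration of the resulting mapping (invented grammar snippet, not the real `node_types.json`):

```python
# Toy illustration of the mapping built by load_node_types_from_json
# (invented node types, not the real tree-sitter node_types.json).
node_types = {"module", "function_definition", "identifier"}

node_type2id = {"<UNK>": 0}
for i, node_type in enumerate(sorted(node_types), start=1):
    node_type2id[node_type] = i

print(node_type2id)
# {'<UNK>': 0, 'function_definition': 1, 'identifier': 2, 'module': 3}
```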

View File

@@ -133,7 +133,7 @@ def main():
     model_config.max_position_embeddings = 1024
     if config['extra_embeddings']:
-        model = TreeStarEncoderForPreTraining(config=model_config, log=False)
+        model = TreeStarEncoderForPreTraining(config=model_config)
     else:
         model = AutoModelForMaskedLM.from_config(model_config)

File diff suppressed because it is too large.

View File

@@ -1,77 +0,0 @@
import logging
from pathlib import Path
from datasets import load_from_disk
from transformers import AutoTokenizer
import multiprocessing

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def pad_and_save_dataset(input_dir, output_dir, tokenizer_name='bigcode/starencoder', max_length=512):
    # Load the processed dataset
    logger.info(f"Loading processed dataset from {input_dir}...")
    dataset = load_from_disk(input_dir)
    logger.info(f"Loaded dataset with {len(dataset)} examples")

    # Initialize tokenizer
    logger.info("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Loaded StarEncoder tokenizer")

    # Define number of processes
    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes")

    # Define a function to pad the sequences
    def pad_sequences(batch):
        # Convert input_ids back to text if necessary
        texts = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

        # Use the tokenizer's __call__ method for padding
        padded_inputs = tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
            truncation=True
        )

        # Pad other fields with default values
        padded_depths = [seq + [-1] * (max_length - len(seq)) for seq in batch['depths']]
        padded_sibling_idxs = [seq + [-1] * (max_length - len(seq)) for seq in batch['sibling_idxs']]
        padded_node_types = [seq + [None] * (max_length - len(seq)) for seq in batch['node_types']]
        padded_node_texts = [seq + [''] * (max_length - len(seq)) for seq in batch['node_texts']]

        return {
            'input_ids': padded_inputs['input_ids'].tolist(),
            'attention_mask': padded_inputs['attention_mask'].tolist(),
            'depths': padded_depths,
            'sibling_idxs': padded_sibling_idxs,
            'node_types': padded_node_types,
            'node_texts': padded_node_texts
        }

    # Apply padding
    logger.info("Applying padding to dataset...")
    padded_dataset = dataset.map(
        pad_sequences,
        batched=True,
        desc="Padding dataset",
        num_proc=num_proc
    )

    # Save the padded dataset
    logger.info(f"Saving padded dataset to {output_dir}...")
    padded_dataset.save_to_disk(output_dir)
    logger.info(f"Saved padded dataset to {output_dir}")

if __name__ == "__main__":
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes-padded'
    pad_and_save_dataset(input_dir, output_dir)
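The padding convention in this (also deleted) script is worth noting: token inputs are re-padded by the tokenizer, while the tree features are right-padded with sentinel values (-1 for depths and sibling indices, None/'' for node metadata). A tiny example of the list padding with a made-up max_length of 8:

```python
# Made-up max_length for illustration; the deleted script used 512.
max_length = 8

depths = [0, 1, 1, 2]        # per-token AST depths of a short example
sibling_idxs = [0, 0, 1, 0]

padded_depths = depths + [-1] * (max_length - len(depths))
padded_sibling_idxs = sibling_idxs + [-1] * (max_length - len(sibling_idxs))

print(padded_depths)         # [0, 1, 1, 2, -1, -1, -1, -1]
print(padded_sibling_idxs)   # [0, 0, 1, 0, -1, -1, -1, -1]
```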

View File

@@ -84,7 +84,6 @@ def process_example(code, tokenizer):
     depths = [-1] * len(input_ids)
     sibling_idxs = [-1] * len(input_ids)
-    node_types = [None] * len(input_ids)
     node_texts = [''] * len(input_ids)

     tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
@@ -96,7 +95,6 @@ def process_example(code, tokenizer):
             if node.start_byte <= start < node.end_byte:
                 depths[i] = depth
                 sibling_idxs[i] = sibling_idx
-                node_types[i] = node.type
                 node_texts[i] = code[node.start_byte:node.end_byte]
         for i, child in enumerate(node.children):
             traverse(child, depth + 1, i)
@@ -109,7 +107,6 @@ def process_example(code, tokenizer):
         'attention_mask': attention_mask,
         'depths': depths,
         'sibling_idxs': sibling_idxs,
-        'node_types': node_types,
         'node_texts': node_texts
     }
@@ -121,7 +118,6 @@ def process_batch(batch, tokenizer):
     processed_depths = []
     processed_sibling_idxs = []
     processed_node_texts = []
-    processed_node_types = []

     for content in contents:
         try:
@@ -134,7 +130,6 @@ def process_batch(batch, tokenizer):
                 processed_depths.append([])
                 processed_sibling_idxs.append([])
                 processed_node_texts.append([])
-                processed_node_types.append([])
             else:
                 processed_input_ids.append(result['input_ids'])
                 processed_attention_mask.append(result['attention_mask'])
@@ -142,7 +137,6 @@ def process_batch(batch, tokenizer):
                 processed_depths.append(result['depths'])
                 processed_sibling_idxs.append(result['sibling_idxs'])
                 processed_node_texts.append(result['node_texts'])
-                processed_node_types.append(result['node_types'])
         except Exception:
             # If something unexpected happens
             processed_input_ids.append([])
@@ -151,7 +145,6 @@ def process_batch(batch, tokenizer):
             processed_depths.append([])
             processed_sibling_idxs.append([])
             processed_node_texts.append([])
-            processed_node_types.append([])

     return {
@@ -159,7 +152,6 @@ def process_batch(batch, tokenizer):
         'tokens': processed_tokens,
         'depths': processed_depths,
         'sibling_idxs': processed_sibling_idxs,
-        'node_types': processed_node_types,
         'node_texts': processed_node_texts
     }
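After this change `process_example` keeps input_ids, attention_mask, tokens, depths, sibling_idxs and node_texts, all index-aligned per token, and drops node_types. A hand-made illustration of the aligned fields (values invented, not actual parser output):

```python
# Invented illustration of the aligned per-token fields (not real parser output).
example = {
    "tokens":       ["def", "add", "(", "a", ")", ":"],
    "depths":       [2, 3, 3, 4, 3, 2],
    "sibling_idxs": [0, 1, 2, 0, 4, 5],
    "node_texts":   ["def", "add", "(", "a", ")", ":"],
}

# Every list stays aligned with input_ids, so the features can be padded
# and batched together later.
assert len({len(v) for v in example.values()}) == 1
```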

View File

@@ -1,32 +1,44 @@
+import os
 import wandb
-import json
+import argparse
+import yaml
+import torch
+import random
 import logging
+import numpy as np
 from pathlib import Path
-from datasets import load_from_disk, DatasetDict
+from datasets import load_dataset, DatasetDict
 from transformers import (
-    RobertaConfig,
     AutoConfig,
-    RobertaForMaskedLM,
     AutoTokenizer,
     TrainingArguments,
     Trainer,
     DataCollatorForLanguageModeling,
     AutoModelForMaskedLM
 )
-import random
-import numpy as np
-import torch

-from tree_codebert import TreeCodeBERTForPreTraining
 from tree_starencoder import TreeStarEncoderForPreTraining

-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S'
-)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

+class MonitoringTrainer(Trainer):
+    def training_step(self, model, inputs, num_items_in_batch=None):
+        # Perform the regular training step
+        outputs = super().training_step(model, inputs, num_items_in_batch)
+
+        # Log gradient norms
+        for name, param in model.named_parameters():
+            if param.grad is not None:
+                grad_norm = param.grad.norm(2).item()
+                wandb.log({f'grad_norm/{name}': grad_norm})
+
+        # Log weight norms
+        for name, param in model.named_parameters():
+            weight_norm = param.data.norm(2).item()
+            wandb.log({f'weight_norm/{name}': weight_norm})
+
+        return outputs
+
 def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
@@ -38,60 +50,57 @@ def set_seed(seed: int):
 def load_config(config_path: Path) -> dict:
     with open(config_path, 'r') as f:
-        return json.load(f)
+        return yaml.safe_load(f)

-def main():
-    # Setup paths
-    current_dir = Path(__file__).parent
-    config = load_config(current_dir / 'config.json')
-    data_dir = Path(config['data_dir'])
-    output_dir = Path(config['output_dir'])
-
-    # Set seed
-    set_seed(config['seed'])
-
-    # Initialize W&B
-    wandb.init(project='codeparrot-starencoder-no-comments', config=config, name=config['run_name'])
-
-    # Upload the training files to W&B
-    wandb.save(__file__)
-    wandb.save(Path(__file__).parent / 'config.json')
-    if config['extra_embeddings']:
-        wandb.save(current_dir / 'tree_starencoder.py')
-
-    if 'CodeSearchNet' in config['data_dir']:
-        dataset = DatasetDict({
-            'train': load_from_disk(data_dir / 'train'),
-            'valid': load_from_disk(data_dir / 'valid'),
-            'test': load_from_disk(data_dir / 'test')
-        })
-    else:
-        dataset = load_from_disk(data_dir)
-        if config['num_samples'] > 0:
-            dataset = dataset.select(range(config['num_samples']))
-        train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
-        test_valid = train_testvalid['test'].train_test_split(
-            test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
-            seed=config['seed']
-        )
-        dataset = DatasetDict({
-            'train': train_testvalid['train'],
-            'test': test_valid['test'],
-            'valid': test_valid['train'],
-        })
-
-    # Continue with the rest of processing
-    columns_to_remove = dataset['train'].column_names
-    columns_to_remove.remove('input_ids')
-    columns_to_remove.remove('attention_mask')
-    if config['extra_embeddings']:
-        columns_to_remove.remove('depths')
-        columns_to_remove.remove('sibling_idxs')
+def initialize_wandb(config, name, files_to_save):
+    wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
+    for file in files_to_save:
+        wandb.save(file)
+
+def prepare_dataset(config, cache_dir):
+    dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
+    if config['data']['num_samples'] > 0:
+        dataset = dataset.select(range(config['data']['num_samples']))
+    train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
+    test_valid = train_testvalid['test'].train_test_split(
+        test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
+        seed=config['training']['seed']
+    )
+    dataset = DatasetDict({
+        'train': train_testvalid['train'],
+        'test': test_valid['test'],
+        'valid': test_valid['train'],
+    })
+    columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
+    if config['model']['extra_embeddings']:
+        columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
     dataset = dataset.remove_columns(columns_to_remove)
-    logger.info(f'Loaded dataset:\n{dataset}')
+    return dataset

-    # Initialize model from scratch
+def main():
+    parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
+    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
+    args = parser.parse_args()
+
+    current_dir = Path(__file__).parent
+    config_path = Path(args.config)
+    config = load_config(config_path)
+    cache_dir = Path(config['experiment']['cache_dir'])
+    output_dir = Path('./outputs') / config_path.stem
+
+    if config['experiment']['use_wandb']:
+        os.environ['WANDB_MODE'] = 'online'
+        initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
+    else:
+        os.environ['WANDB_MODE'] = 'offline'
+        logger.info('Wandb is not used.')
+
+    set_seed(config['training']['seed'])
+
+    dataset = prepare_dataset(config, cache_dir)
+    logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
+    logger.info(f'Dataset columns: {dataset["train"].column_names}')
+
     tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
     if tokenizer.mask_token is None:
         tokenizer.add_special_tokens({'mask_token': '<mask>'})
@@ -101,41 +110,40 @@ def main():
         tokenizer.pad_token = tokenizer.eos_token
         logger.info("Set padding token to be the same as the EOS token.")

-    model_config = AutoConfig.from_pretrained('bigcode/starencoder')
-    if config['extra_embeddings']:
-        model = TreeStarEncoderForPreTraining(model_config)
+    model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
+    if config['model']['extra_embeddings']:
+        model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
     else:
         model = AutoModelForMaskedLM.from_config(model_config)
     logger.info(f'Loaded model: {model.__class__.__name__}')

-    # Setup training arguments
     training_args = TrainingArguments(
         output_dir=str(output_dir),
-        per_device_train_batch_size=config['batch_size'],
-        per_device_eval_batch_size=config['batch_size'],
-        learning_rate=config['learning_rate'],
-        weight_decay=config['weight_decay'],
-        num_train_epochs=config['epochs'],
-        warmup_steps=config['warmup_steps'],
-        max_grad_norm=config['max_grad_norm'],
-        logging_steps=config['logging_steps'],
-        eval_steps=config['eval_every'],
-        save_steps=config['eval_every'],
+        per_device_train_batch_size=config['training']['batch_size'],
+        per_device_eval_batch_size=config['training']['batch_size'],
+        learning_rate=config['training']['learning_rate'],
+        weight_decay=config['training']['weight_decay'],
+        num_train_epochs=config['training']['epochs'],
+        warmup_steps=config['training']['warmup_steps'],
+        max_grad_norm=config['training']['max_grad_norm'],
+        logging_steps=config['evaluation']['logging_steps'],
+        eval_steps=config['evaluation']['eval_every'],
+        save_steps=config['evaluation']['eval_every'],
         eval_strategy='steps',
         save_strategy='steps',
         load_best_model_at_end=True,
-        report_to='wandb',
-        run_name=config['run_name'],
-        seed=config['seed'],
-        fp16=config['fp16'],
+        report_to='wandb' if config['experiment']['use_wandb'] else None,
+        run_name=config_path.stem,
+        seed=config['training']['seed'],
+        fp16=config['training']['fp16'],
         dataloader_num_workers=8,
         gradient_checkpointing=True,
         metric_for_best_model='eval_loss',
         greater_is_better=False,
+        save_total_limit=3,
     )

-    # Create trainer
-    trainer = Trainer(
+    trainer = MonitoringTrainer(
         model=model,
         args=training_args,
         train_dataset=dataset['train'],
@@ -143,26 +151,25 @@ def main():
         data_collator=DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
             mlm=True,
-            mlm_probability=config['mlm_probability']
+            mlm_probability=config['data']['mlm_probability']
         ),
     )

-    # Train
     logger.info('Starting training...')
     trainer.train()
+    logger.info('Training completed.')

-    # Evaluate
     logger.info('Evaluating on test set...')
     eval_results = trainer.evaluate(dataset['test'])
-    logger.info(f'Evaluation results: {eval_results}')
     wandb.log({'test_loss': eval_results['eval_loss']})
+    logger.info(f'Test loss: {eval_results["eval_loss"]}')

-    # Save final model
     logger.info('Saving final model...')
     trainer.save_model(output_dir / 'final-model')
     tokenizer.save_pretrained(output_dir / 'final-model')
-    logger.info('Training completed!')
+    # Upload to W&B
+    wandb.save(output_dir / 'final-model/*')

 if __name__ == '__main__':
     main()
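The MonitoringTrainer added here (the "added more monitoring to the training" commit) logs the L2 norm of every parameter's gradient and weight to W&B on each training step. A minimal standalone sketch of the same bookkeeping on a toy module, with `wandb.log` replaced by `print`:

```python
import torch
import torch.nn as nn

# Toy stand-in for the model; wandb.log is replaced by print for this sketch.
model = nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"grad_norm/{name}", param.grad.norm(2).item())
    print(f"weight_norm/{name}", param.data.norm(2).item())
```

Note also that with the 0.05/0.05 sizes used in these configs, the nested `train_test_split` in `prepare_dataset` works out to roughly a 90/5/5 train/valid/test split.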

View File

@@ -1,208 +0,0 @@
import wandb
import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM

class TreePositionalEmbedding(nn.Module):
    """Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""

    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        # Separate embeddings for different features
        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)

        # Improved projection layers
        self.node_projection = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.Dropout(dropout)
        )

        # Layer norm for stability
        self.layer_norm = nn.LayerNorm(d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        std = 0.02
        for embedding in [self.depth_embedding, self.sibling_embedding]:
            nn.init.normal_(embedding.weight, mean=0.0, std=std)

        # Initialize projection layers
        for layer in self.node_projection:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=std)
                nn.init.zeros_(layer.bias)

    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            depths: Tensor of shape [batch_size, seq_len] containing depth values
            sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
        Returns:
            Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
        """
        # Clamp values to max_depth
        depths = torch.clamp(depths, 0, self.max_depth - 1)
        sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)

        # Get embeddings for each feature
        depth_embeddings = self.depth_embedding(depths)            # [batch, seq_len, d_model]
        sibling_embeddings = self.sibling_embedding(sibling_idxs)  # [batch, seq_len, d_model]

        # Combine features
        combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
        embeddings = self.node_projection(combined)

        # Apply layer norm
        normalized_embeddings = self.layer_norm(embeddings)

        return normalized_embeddings

class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """CodeBERT model enhanced with tree-structural information."""

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
        super().__init__(config)

        self.tree_pos_embeddings = TreePositionalEmbedding(
            d_model=config.hidden_size,
            max_depth=max_depth,
            dropout=config.hidden_dropout_prob
        )

        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)

        # Initialize sequential position embeddings with sinusoidal pattern
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
        pe = torch.zeros(max_seq_length, config.hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.seq_pos_embeddings.weight.data.copy_(pe)

        # Initialize embedding weights equally
        initial_weight = math.log(1/3)  # log(1/3) because we use softmax later
        self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))

        # Layer norms for embedding combination
        self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
        self.post_combination_norm = nn.LayerNorm(config.hidden_size)

    def get_normalized_weights(self):
        """Get softmaxed weights for embedding combination."""
        return torch.softmax(self.embedding_weights, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        sibling_idxs: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        device = input_ids.device

        # Get normalized weights for embedding combination and calculate regularization
        weights = self.get_normalized_weights()

        # Calculate weight variance regularization
        # We want weights to remain somewhat balanced, so penalize high variance
        weight_variance = torch.var(weights)
        weight_reg_loss = 0.1 * weight_variance  # Adjustable coefficient

        # Add L2 regularization to prevent any weight from getting too close to 1
        # This helps maintain a more balanced contribution from each embedding type
        max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2)  # Penalize weights > 0.8
        l2_reg_loss = 0.05 * max_weight_penalty  # Adjustable coefficient

        # Get token embeddings
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
        token_embeddings = self.pre_combination_norm(token_embeddings)

        # Get sequential position embeddings
        seq_positions = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)

        # Get tree positional embeddings if tree information is provided
        if depths is not None and sibling_idxs is not None:
            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
        else:
            tree_embeddings = torch.zeros_like(token_embeddings)

        # Combine all embeddings using learned weights
        combined_embeddings = (
            weights[0] * token_embeddings +
            weights[1] * tree_embeddings +
            weights[2] * seq_embeddings
        )
        combined_embeddings = self.post_combination_norm(combined_embeddings)

        # Forward pass through base model
        outputs = self.roberta(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # Calculate MLM loss if labels are provided
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

            # Add regularization losses to final loss
            if masked_lm_loss is not None:
                final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
            else:
                final_loss = weight_reg_loss + l2_reg_loss
        else:
            final_loss = None

        # Prepare embedding weights for logging
        weights_cpu = weights.detach().cpu()
        embedding_weights = {
            "token": weights_cpu[0].item(),
            "tree": weights_cpu[1].item(),
            "sequential": weights_cpu[2].item()
        }
        reg_metrics = {
            "weight_variance": weight_variance.item(),
            "max_weight_penalty": max_weight_penalty.item(),
            "weight_reg_loss": weight_reg_loss.item(),
            "l2_reg_loss": l2_reg_loss.item()
        }

        wandb.log(
            {f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
            step=kwargs.get("global_step", None)
        )
        wandb.log(
            {f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
            step=kwargs.get("global_step", None)
        )

        return {
            "loss": final_loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            "attentions": outputs.attentions if output_attentions else None,
        }
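Since this file is removed in the compare, one detail of the scheme it implemented is worth recording: the three embedding logits start equal, so the softmax starts at 1/3 for the token, tree and sequential embeddings, and both regularization terms are zero at initialization. A standalone check:

```python
import math

import torch

# Three equal logits, as in the deleted TreeCodeBERTForPreTraining.__init__.
embedding_weights = torch.full((3,), math.log(1 / 3))
weights = torch.softmax(embedding_weights, dim=0)
print(weights)  # tensor([0.3333, 0.3333, 0.3333]) -> token / tree / sequential

# Both regularizers vanish at this starting point.
weight_variance = torch.var(weights)                            # ~0.0
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2)  # 0.0, no weight exceeds 0.8
print(0.1 * weight_variance.item(), 0.05 * max_weight_penalty.item())
```

Any equal initial value would give the same 1/3 split after the softmax; the specific log(1/3) constant only matters if the raw logits are inspected directly.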

View File

@@ -1,48 +1,107 @@
-import wandb
 import math
 import torch
 import torch.nn as nn
 from typing import Dict, Optional
-from transformers import AutoConfig, BertForMaskedLM
+from transformers import AutoConfig, BertForMaskedLM, GenerationMixin

-from tree_codebert import TreePositionalEmbedding
+class TreePositionalEmbedding(nn.Module):
+    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.max_depth = max_depth
+        self.depth_embedding = nn.Embedding(max_depth, d_model)
+        self.sibling_embedding = nn.Embedding(max_depth, d_model)
+        self.node_projection = nn.Sequential(
+            nn.Linear(d_model * 2, d_model * 2),
+            nn.GELU(),
+            nn.Linear(d_model * 2, d_model),
+            nn.Dropout(dropout)
+        )
+        self.layer_norm = nn.LayerNorm(d_model)
+        self._initialize_embeddings()
+
+    def _initialize_embeddings(self):
+        std = 0.02
+        for embedding in [self.depth_embedding, self.sibling_embedding]:
+            nn.init.normal_(embedding.weight, mean=0.0, std=std)
+        for layer in self.node_projection:
+            if isinstance(layer, nn.Linear):
+                nn.init.normal_(layer.weight, mean=0.0, std=std)
+                nn.init.zeros_(layer.bias)
+
+    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
+        depths = torch.clamp(depths, 0, self.max_depth - 1)
+        sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
+        depth_embeddings = self.depth_embedding(depths)
+        sibling_embeddings = self.sibling_embedding(sibling_idxs)
+        combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
+        embeddings = self.node_projection(combined)
+        return self.layer_norm(embeddings)
+
+class NewEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and tree-based embeddings."""
+    def __init__(self, config, yaml_config):
+        super().__init__()
+        self.yaml_config = yaml_config
+        if self.yaml_config['model']['concat_embeddings']:
+            self.fusion_layer = nn.Sequential(
+                nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
+                nn.GELU(),
+                nn.Linear(config.hidden_size * 3, config.hidden_size),
+            )
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.tree_pos_embeddings = TreePositionalEmbedding(
+            d_model=config.hidden_size,
+            max_depth=yaml_config['model']['max_depth'],
+            dropout=config.hidden_dropout_prob
+        )
+        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids=None, depths=None, sibling_idxs=None):
+        input_shape = input_ids.size()
+        seq_length = input_shape[1]
+        device = input_ids.device
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).expand(input_shape)
+        inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
+        if self.yaml_config['model']['sum_embeddings']:
+            embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
+        if self.yaml_config['model']['concat_embeddings']:
+            embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
+            embeddings = self.fusion_layer(embeddings)
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings

-class TreeStarEncoderForPreTraining(BertForMaskedLM):
-    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
+class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
+    def __init__(self, config: AutoConfig, yaml_config: Dict):
         super().__init__(config)
         self.config = config
+        self.yaml_config = yaml_config
+        self.embeddings = NewEmbeddings(config, yaml_config)

         # self.fusion_layer = nn.Sequential(
-        #     nn.Linear(config.hidden_size * 4, config.hidden_size),
+        #     nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
         #     nn.GELU(),
+        #     nn.Linear(config.hidden_size * 3, config.hidden_size),  # Reduce back to hidden_size
         #     nn.Dropout(config.hidden_dropout_prob),
        #     nn.LayerNorm(config.hidden_size)
         # )

-        # Override config to set max_seq_length
-        config.max_position_embeddings = max_seq_length
-
-        self.tree_pos_embeddings = TreePositionalEmbedding(
-            d_model=config.hidden_size,
-            max_depth=max_depth,
-            dropout=config.hidden_dropout_prob
-        )
-
-        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
-
-        # # Initialize sequential position embeddings with sinusoidal pattern
-        # position = torch.arange(max_seq_length).unsqueeze(1)
-        # div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
-        # pe = torch.zeros(max_seq_length, config.hidden_size)
-        # pe[:, 0::2] = torch.sin(position * div_term)
-        # pe[:, 1::2] = torch.cos(position * div_term)
-        # self.seq_pos_embeddings.weight.data.copy_(pe)
-
-        # New node type embeddings
-        self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
-
-        self.norm = nn.LayerNorm(config.hidden_size)

     def forward(
         self,
         input_ids: torch.Tensor,
@@ -53,34 +112,10 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         output_attentions: bool = False,
         **kwargs
     ) -> Dict[str, torch.Tensor]:
-        device = input_ids.device
-
-        # Get token embeddings
-        token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
-
-        # Get sequential position embeddings
-        seq_positions = torch.arange(input_ids.size(1), device=device)
-        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
-
-        # Get tree positional embeddings
-        if depths is not None and sibling_idxs is not None:
-            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
-        else:
-            tree_embeddings = torch.zeros_like(token_embeddings)
-
-        # # Get node type embeddings
-        # node_type_embeddings = self.node_type_embeddings(node_types)
-
-        # combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
-        # combined_embeddings = self.fusion_layer(combined)
-
-        # Add the embeddings instead of concatenating
-        combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
-        combined_embeddings = self.norm(combined_embeddings)
+        embedding_output = self.embeddings(input_ids, depths, sibling_idxs)

         outputs = self.bert(
-            inputs_embeds=combined_embeddings,
+            inputs_embeds=embedding_output,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
             **kwargs
@@ -100,6 +135,7 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         return {
             "loss": masked_lm_loss,
             "logits": prediction_scores,
+            "last_hidden_state": sequence_output,
             "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
             "attentions": outputs.attentions if output_attentions else None,
         }
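To make the new embedding path concrete: in `NewEmbeddings.forward`, `sum_embeddings` adds the word, position and tree embeddings elementwise, while `concat_embeddings` concatenates them and projects back to the hidden size through the fusion MLP (as written, if both flags are set the concat result overwrites the sum). A shape-level sketch with assumed sizes (batch 2, sequence 16, hidden 768):

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration only.
batch, seq_len, hidden = 2, 16, 768
inputs_embeds = torch.randn(batch, seq_len, hidden)
position_embeddings = torch.randn(batch, seq_len, hidden)
tree_pos_embeddings = torch.randn(batch, seq_len, hidden)

# sum_embeddings=True: elementwise sum keeps the hidden size.
summed = inputs_embeds + position_embeddings + tree_pos_embeddings
print(summed.shape)  # torch.Size([2, 16, 768])

# concat_embeddings=True: concatenate along the feature dim, then project 3*hidden -> hidden.
fusion_layer = nn.Sequential(
    nn.Linear(hidden * 3, hidden * 3),
    nn.GELU(),
    nn.Linear(hidden * 3, hidden),
)
concatenated = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
print(fusion_layer(concatenated).shape)  # torch.Size([2, 16, 768])
```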