Compare commits
3 Commits
master
...
runpod-exp
Author | SHA1 | Date | |
---|---|---|---|
|
35e5d3e8fa | ||
|
3d6826f058 | ||
|
dfb1e669bd |
1
code/.gitignore
vendored
1
code/.gitignore
vendored
@ -164,4 +164,3 @@ cython_debug/
|
|||||||
# Weights & Biases
|
# Weights & Biases
|
||||||
wandb/
|
wandb/
|
||||||
outputs/
|
outputs/
|
||||||
cache/
|
|
||||||
|
@ -20,9 +20,9 @@ pdm install
|
|||||||
```
|
```
|
||||||
### 4. Run training code
|
### 4. Run training code
|
||||||
```bash
|
```bash
|
||||||
pdm train --config {CONFIG FILE}
|
pdm run_training
|
||||||
```
|
```
|
||||||
Example:
|
or
|
||||||
```bash
|
```
|
||||||
pdm train --config ./configs/original.yaml
|
pdm run src/train_codebert_mlm.py
|
||||||
```
|
```
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: false
|
|
||||||
wandb_project: "prof-gralinski"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 1
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 0
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/1-fold-clone-detection-600k-5fold"
|
|
||||||
num_samples: 1000 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
path: "outputs/tmp/final-model"
|
|
||||||
extra_embeddings: true # if false, use the original star encoder
|
|
||||||
sinusoidal_init: false
|
|
||||||
concat_embeddings: true
|
|
||||||
sum_embeddings: false
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_project: "runpod"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 3
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 1000
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
|
||||||
mlm_probability: 0.15
|
|
||||||
num_samples: -1 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
extra_embeddings: false # if false, use the original star encoder
|
|
||||||
sinusoidal_init: false
|
|
||||||
concat_embeddings: false
|
|
||||||
sum_embeddings: false
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_project: "runpod"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 3
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 1000
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
|
||||||
mlm_probability: 0.15
|
|
||||||
num_samples: -1 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
extra_embeddings: true # if false, use the original star encoder
|
|
||||||
sinusoidal_init: true
|
|
||||||
concat_embeddings: false
|
|
||||||
sum_embeddings: true
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_project: "runpod"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 3
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 1000
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
|
||||||
mlm_probability: 0.15
|
|
||||||
num_samples: -1 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
extra_embeddings: true # if false, use the original star encoder
|
|
||||||
sinusoidal_init: false
|
|
||||||
concat_embeddings: false
|
|
||||||
sum_embeddings: true
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_project: "runpod"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 3
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 1000
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
|
||||||
mlm_probability: 0.15
|
|
||||||
num_samples: -1 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
extra_embeddings: true # if false, use the original star encoder
|
|
||||||
sinusoidal_init: true
|
|
||||||
concat_embeddings: true
|
|
||||||
sum_embeddings: false
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
@ -1,33 +0,0 @@
|
|||||||
experiment:
|
|
||||||
cache_dir: "./cache/"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_project: "runpod"
|
|
||||||
|
|
||||||
training:
|
|
||||||
seed: 42
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 3
|
|
||||||
learning_rate: 0.0005
|
|
||||||
weight_decay: 0.1
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
warmup_steps: 1000
|
|
||||||
fp16: true
|
|
||||||
|
|
||||||
evaluation:
|
|
||||||
eval_every: 10000
|
|
||||||
logging_steps: 100
|
|
||||||
|
|
||||||
data:
|
|
||||||
source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
|
|
||||||
mlm_probability: 0.15
|
|
||||||
num_samples: -1 # -1 means all samples
|
|
||||||
valid_size: 0.05
|
|
||||||
test_size: 0.05
|
|
||||||
|
|
||||||
model:
|
|
||||||
extra_embeddings: true # if false, use the original star encoder
|
|
||||||
sinusoidal_init: false
|
|
||||||
concat_embeddings: true
|
|
||||||
sum_embeddings: false
|
|
||||||
max_depth: 32
|
|
||||||
max_seq_length: 512
|
|
2
code/data/.gitignore
vendored
Normal file
2
code/data/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
2
code/models/.gitignore
vendored
Normal file
2
code/models/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
*
|
||||||
|
!.gitignore
|
@ -19,7 +19,6 @@ dependencies = [
|
|||||||
"tree-sitter-python==0.23.4",
|
"tree-sitter-python==0.23.4",
|
||||||
"ipykernel==6.29.5",
|
"ipykernel==6.29.5",
|
||||||
"ipywidgets==8.1.5",
|
"ipywidgets==8.1.5",
|
||||||
"pyyaml==6.0.2",
|
|
||||||
]
|
]
|
||||||
requires-python = "==3.11.*"
|
requires-python = "==3.11.*"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -42,4 +41,3 @@ distribution = true
|
|||||||
parse_dataset = {cmd = "src/parse_dataset.py"}
|
parse_dataset = {cmd = "src/parse_dataset.py"}
|
||||||
train = {cmd = "src/training.py"}
|
train = {cmd = "src/training.py"}
|
||||||
eval = {cmd = "src/eval_model.py"}
|
eval = {cmd = "src/eval_model.py"}
|
||||||
finetune = {cmd = "src_finetune/finetune.py"}
|
|
||||||
|
21
code/src/config.json
Normal file
21
code/src/config.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"project": "runpod",
|
||||||
|
"run_name": "original",
|
||||||
|
"dataset": "patrykbart/codeparrot-clean-no-comments-starencoder-small",
|
||||||
|
"output_dir": "./outputs/long-no-comments-starencoder-original",
|
||||||
|
"extra_embeddings": false,
|
||||||
|
"seed": 420,
|
||||||
|
"mlm_probability": 0.15,
|
||||||
|
"batch_size": 192,
|
||||||
|
"epochs": 3,
|
||||||
|
"eval_every": 2500,
|
||||||
|
"learning_rate": 5e-4,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"max_grad_norm": 1.0,
|
||||||
|
"warmup_steps": 500,
|
||||||
|
"bf16": true,
|
||||||
|
"logging_steps": 100,
|
||||||
|
"valid_size": 0.05,
|
||||||
|
"test_size": 0.05,
|
||||||
|
"num_samples": -1
|
||||||
|
}
|
@ -120,7 +120,7 @@ def main():
|
|||||||
# Setup paths
|
# Setup paths
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
config = load_config(current_dir / 'eval_config.json')
|
config = load_config(current_dir / 'eval_config.json')
|
||||||
model_dir = Path(config['model_dir']) / 'final-model'
|
model_dir = Path(config['model_dir'])
|
||||||
data_dir = Path(config['data_dir'])
|
data_dir = Path(config['data_dir'])
|
||||||
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
||||||
results_dir.mkdir(exist_ok=True)
|
results_dir.mkdir(exist_ok=True)
|
||||||
|
3754
code/src/node_types.json
Normal file
3754
code/src/node_types.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -84,6 +84,7 @@ def process_example(code, tokenizer):
|
|||||||
|
|
||||||
depths = [-1] * len(input_ids)
|
depths = [-1] * len(input_ids)
|
||||||
sibling_idxs = [-1] * len(input_ids)
|
sibling_idxs = [-1] * len(input_ids)
|
||||||
|
node_types = [None] * len(input_ids)
|
||||||
node_texts = [''] * len(input_ids)
|
node_texts = [''] * len(input_ids)
|
||||||
|
|
||||||
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
|
tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
|
||||||
@ -95,6 +96,7 @@ def process_example(code, tokenizer):
|
|||||||
if node.start_byte <= start < node.end_byte:
|
if node.start_byte <= start < node.end_byte:
|
||||||
depths[i] = depth
|
depths[i] = depth
|
||||||
sibling_idxs[i] = sibling_idx
|
sibling_idxs[i] = sibling_idx
|
||||||
|
node_types[i] = node.type
|
||||||
node_texts[i] = code[node.start_byte:node.end_byte]
|
node_texts[i] = code[node.start_byte:node.end_byte]
|
||||||
for i, child in enumerate(node.children):
|
for i, child in enumerate(node.children):
|
||||||
traverse(child, depth + 1, i)
|
traverse(child, depth + 1, i)
|
||||||
@ -107,6 +109,7 @@ def process_example(code, tokenizer):
|
|||||||
'attention_mask': attention_mask,
|
'attention_mask': attention_mask,
|
||||||
'depths': depths,
|
'depths': depths,
|
||||||
'sibling_idxs': sibling_idxs,
|
'sibling_idxs': sibling_idxs,
|
||||||
|
'node_types': node_types,
|
||||||
'node_texts': node_texts
|
'node_texts': node_texts
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,6 +121,7 @@ def process_batch(batch, tokenizer):
|
|||||||
processed_depths = []
|
processed_depths = []
|
||||||
processed_sibling_idxs = []
|
processed_sibling_idxs = []
|
||||||
processed_node_texts = []
|
processed_node_texts = []
|
||||||
|
processed_node_types = []
|
||||||
|
|
||||||
for content in contents:
|
for content in contents:
|
||||||
try:
|
try:
|
||||||
@ -130,6 +134,7 @@ def process_batch(batch, tokenizer):
|
|||||||
processed_depths.append([])
|
processed_depths.append([])
|
||||||
processed_sibling_idxs.append([])
|
processed_sibling_idxs.append([])
|
||||||
processed_node_texts.append([])
|
processed_node_texts.append([])
|
||||||
|
processed_node_types.append([])
|
||||||
else:
|
else:
|
||||||
processed_input_ids.append(result['input_ids'])
|
processed_input_ids.append(result['input_ids'])
|
||||||
processed_attention_mask.append(result['attention_mask'])
|
processed_attention_mask.append(result['attention_mask'])
|
||||||
@ -137,6 +142,7 @@ def process_batch(batch, tokenizer):
|
|||||||
processed_depths.append(result['depths'])
|
processed_depths.append(result['depths'])
|
||||||
processed_sibling_idxs.append(result['sibling_idxs'])
|
processed_sibling_idxs.append(result['sibling_idxs'])
|
||||||
processed_node_texts.append(result['node_texts'])
|
processed_node_texts.append(result['node_texts'])
|
||||||
|
processed_node_types.append(result['node_types'])
|
||||||
except Exception:
|
except Exception:
|
||||||
# If something unexpected happens
|
# If something unexpected happens
|
||||||
processed_input_ids.append([])
|
processed_input_ids.append([])
|
||||||
@ -145,6 +151,7 @@ def process_batch(batch, tokenizer):
|
|||||||
processed_depths.append([])
|
processed_depths.append([])
|
||||||
processed_sibling_idxs.append([])
|
processed_sibling_idxs.append([])
|
||||||
processed_node_texts.append([])
|
processed_node_texts.append([])
|
||||||
|
processed_node_types.append([])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'input_ids': processed_input_ids,
|
'input_ids': processed_input_ids,
|
||||||
@ -152,6 +159,7 @@ def process_batch(batch, tokenizer):
|
|||||||
'tokens': processed_tokens,
|
'tokens': processed_tokens,
|
||||||
'depths': processed_depths,
|
'depths': processed_depths,
|
||||||
'sibling_idxs': processed_sibling_idxs,
|
'sibling_idxs': processed_sibling_idxs,
|
||||||
|
'node_types': processed_node_types,
|
||||||
'node_texts': processed_node_texts
|
'node_texts': processed_node_texts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
import os
|
|
||||||
import wandb
|
import wandb
|
||||||
import argparse
|
import json
|
||||||
import yaml
|
|
||||||
import torch
|
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datasets import load_dataset, DatasetDict
|
from datasets import load_from_disk, DatasetDict, load_dataset
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@ -16,29 +12,20 @@ from transformers import (
|
|||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
AutoModelForMaskedLM
|
AutoModelForMaskedLM
|
||||||
)
|
)
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tree_codebert import TreeCodeBERTForPreTraining
|
||||||
from tree_starencoder import TreeStarEncoderForPreTraining
|
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MonitoringTrainer(Trainer):
|
|
||||||
def training_step(self, model, inputs, num_items_in_batch=None):
|
|
||||||
# Perform the regular training step
|
|
||||||
outputs = super().training_step(model, inputs, num_items_in_batch)
|
|
||||||
|
|
||||||
# Log gradient norms
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if param.grad is not None:
|
|
||||||
grad_norm = param.grad.norm(2).item()
|
|
||||||
wandb.log({f'grad_norm/{name}': grad_norm})
|
|
||||||
|
|
||||||
# Log weight norms
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
weight_norm = param.data.norm(2).item()
|
|
||||||
wandb.log({f'weight_norm/{name}': weight_norm})
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@ -50,100 +37,90 @@ def set_seed(seed: int):
|
|||||||
|
|
||||||
def load_config(config_path: Path) -> dict:
|
def load_config(config_path: Path) -> dict:
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
return yaml.safe_load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def initialize_wandb(config, name, files_to_save):
|
def main():
|
||||||
wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
|
# Setup paths
|
||||||
for file in files_to_save:
|
current_dir = Path(__file__).parent
|
||||||
wandb.save(file)
|
config = load_config(current_dir / 'config.json')
|
||||||
|
output_dir = Path(config['output_dir'])
|
||||||
|
|
||||||
def prepare_dataset(config, cache_dir):
|
# Set seed
|
||||||
dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
|
set_seed(config['seed'])
|
||||||
if config['data']['num_samples'] > 0:
|
|
||||||
dataset = dataset.select(range(config['data']['num_samples']))
|
# Initialize W&B and save files
|
||||||
train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
|
wandb.init(project=config['project'], config=config, name=config['run_name'])
|
||||||
|
for file in [__file__, 'config.json', 'tree_starencoder.py']:
|
||||||
|
if config['extra_embeddings'] or file != 'tree_starencoder.py':
|
||||||
|
wandb.save(current_dir / file)
|
||||||
|
|
||||||
|
# Simplified dataset splitting
|
||||||
|
dataset = load_dataset(config['dataset'], split='train')
|
||||||
|
if config['num_samples'] > 0:
|
||||||
|
dataset = dataset.select(range(config['num_samples']))
|
||||||
|
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
||||||
test_valid = train_testvalid['test'].train_test_split(
|
test_valid = train_testvalid['test'].train_test_split(
|
||||||
test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
|
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
|
||||||
seed=config['training']['seed']
|
seed=config['seed']
|
||||||
)
|
)
|
||||||
dataset = DatasetDict({
|
dataset = DatasetDict({
|
||||||
'train': train_testvalid['train'],
|
'train': train_testvalid['train'],
|
||||||
'test': test_valid['test'],
|
'test': test_valid['test'],
|
||||||
'valid': test_valid['train'],
|
'valid': test_valid['train'],
|
||||||
})
|
})
|
||||||
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
|
|
||||||
if config['model']['extra_embeddings']:
|
|
||||||
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
|
# Continue with the rest of processing
|
||||||
|
columns_to_remove = dataset['train'].column_names
|
||||||
|
columns_to_remove.remove('input_ids')
|
||||||
|
columns_to_remove.remove('attention_mask')
|
||||||
|
if config['extra_embeddings']:
|
||||||
|
columns_to_remove.remove('depths')
|
||||||
|
columns_to_remove.remove('sibling_idxs')
|
||||||
dataset = dataset.remove_columns(columns_to_remove)
|
dataset = dataset.remove_columns(columns_to_remove)
|
||||||
return dataset
|
logger.info(f'Loaded dataset:\n{dataset}')
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
|
|
||||||
parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
config_path = Path(args.config)
|
|
||||||
config = load_config(config_path)
|
|
||||||
cache_dir = Path(config['experiment']['cache_dir'])
|
|
||||||
output_dir = Path('./outputs') / config_path.stem
|
|
||||||
|
|
||||||
if config['experiment']['use_wandb']:
|
|
||||||
os.environ['WANDB_MODE'] = 'online'
|
|
||||||
initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
|
|
||||||
else:
|
|
||||||
os.environ['WANDB_MODE'] = 'offline'
|
|
||||||
logger.info('Wandb is not used.')
|
|
||||||
|
|
||||||
set_seed(config['training']['seed'])
|
|
||||||
|
|
||||||
dataset = prepare_dataset(config, cache_dir)
|
|
||||||
logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
|
|
||||||
logger.info(f'Dataset columns: {dataset["train"].column_names}')
|
|
||||||
|
|
||||||
|
# Simplify tokenizer setup
|
||||||
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
|
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
|
||||||
if tokenizer.mask_token is None:
|
tokenizer.add_special_tokens({'mask_token': '<mask>'}) if tokenizer.mask_token is None else None
|
||||||
tokenizer.add_special_tokens({'mask_token': '<mask>'})
|
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
|
||||||
tokenizer.mask_token = '<mask>'
|
|
||||||
logger.info("Added '<mask>' as the mask token.")
|
|
||||||
if tokenizer.pad_token is None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
logger.info("Set padding token to be the same as the EOS token.")
|
|
||||||
|
|
||||||
model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
|
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
|
||||||
if config['model']['extra_embeddings']:
|
if config['extra_embeddings']:
|
||||||
model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
|
model = TreeStarEncoderForPreTraining(model_config)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForMaskedLM.from_config(model_config)
|
model = AutoModelForMaskedLM.from_config(model_config)
|
||||||
logger.info(f'Loaded model: {model.__class__.__name__}')
|
logger.info(f'Loaded model: {model.__class__.__name__}')
|
||||||
|
|
||||||
|
# Setup training arguments
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=str(output_dir),
|
output_dir=str(output_dir),
|
||||||
per_device_train_batch_size=config['training']['batch_size'],
|
per_device_train_batch_size=config['batch_size'],
|
||||||
per_device_eval_batch_size=config['training']['batch_size'],
|
per_device_eval_batch_size=config['batch_size'],
|
||||||
learning_rate=config['training']['learning_rate'],
|
learning_rate=config['learning_rate'],
|
||||||
weight_decay=config['training']['weight_decay'],
|
weight_decay=config['weight_decay'],
|
||||||
num_train_epochs=config['training']['epochs'],
|
num_train_epochs=config['epochs'],
|
||||||
warmup_steps=config['training']['warmup_steps'],
|
warmup_steps=config['warmup_steps'],
|
||||||
max_grad_norm=config['training']['max_grad_norm'],
|
max_grad_norm=config['max_grad_norm'],
|
||||||
logging_steps=config['evaluation']['logging_steps'],
|
logging_steps=config['logging_steps'],
|
||||||
eval_steps=config['evaluation']['eval_every'],
|
eval_steps=config['eval_every'],
|
||||||
save_steps=config['evaluation']['eval_every'],
|
save_steps=config['eval_every'],
|
||||||
eval_strategy='steps',
|
eval_strategy='steps',
|
||||||
save_strategy='steps',
|
save_strategy='steps',
|
||||||
|
save_total_limit=5,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
report_to='wandb' if config['experiment']['use_wandb'] else None,
|
report_to='wandb',
|
||||||
run_name=config_path.stem,
|
run_name=config['run_name'],
|
||||||
seed=config['training']['seed'],
|
seed=config['seed'],
|
||||||
fp16=config['training']['fp16'],
|
bf16=config['bf16'],
|
||||||
dataloader_num_workers=8,
|
dataloader_num_workers=8,
|
||||||
gradient_checkpointing=True,
|
gradient_checkpointing=True,
|
||||||
metric_for_best_model='eval_loss',
|
metric_for_best_model='eval_loss',
|
||||||
greater_is_better=False,
|
greater_is_better=False,
|
||||||
save_total_limit=3,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = MonitoringTrainer(
|
# Create trainer
|
||||||
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=dataset['train'],
|
train_dataset=dataset['train'],
|
||||||
@ -151,25 +128,32 @@ def main():
|
|||||||
data_collator=DataCollatorForLanguageModeling(
|
data_collator=DataCollatorForLanguageModeling(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
mlm=True,
|
mlm=True,
|
||||||
mlm_probability=config['data']['mlm_probability']
|
mlm_probability=config['mlm_probability']
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Train
|
||||||
logger.info('Starting training...')
|
logger.info('Starting training...')
|
||||||
trainer.train()
|
trainer.train()
|
||||||
logger.info('Training completed.')
|
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
logger.info('Evaluating on test set...')
|
logger.info('Evaluating on test set...')
|
||||||
eval_results = trainer.evaluate(dataset['test'])
|
eval_results = trainer.evaluate(dataset['test'])
|
||||||
logger.info(f'Evaluation results: {eval_results}')
|
|
||||||
wandb.log({'test_loss': eval_results['eval_loss']})
|
wandb.log({'test_loss': eval_results['eval_loss']})
|
||||||
|
logger.info(f'Test loss: {eval_results["eval_loss"]}')
|
||||||
|
|
||||||
|
# Save final model
|
||||||
logger.info('Saving final model...')
|
logger.info('Saving final model...')
|
||||||
trainer.save_model(output_dir / 'final-model')
|
trainer.save_model(output_dir / 'final-model')
|
||||||
tokenizer.save_pretrained(output_dir / 'final-model')
|
tokenizer.save_pretrained(output_dir / 'final-model')
|
||||||
|
|
||||||
# Upload to W&B
|
# Zip and upload the final model to W&B
|
||||||
wandb.save(output_dir / 'final-model/*')
|
with zipfile.ZipFile(output_dir / 'final-model.zip', 'w') as zipf:
|
||||||
|
for file in (output_dir / 'final-model').glob('**/*'):
|
||||||
|
zipf.write(file, arcname=file.name)
|
||||||
|
wandb.save(output_dir / 'final-model.zip')
|
||||||
|
|
||||||
|
logger.info('Training completed!')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
208
code/src/tree_codebert.py
Normal file
208
code/src/tree_codebert.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
import wandb
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from transformers import RobertaConfig, RobertaForMaskedLM
|
||||||
|
|
||||||
|
class TreePositionalEmbedding(nn.Module):
|
||||||
|
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
|
||||||
|
|
||||||
|
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.max_depth = max_depth
|
||||||
|
|
||||||
|
# Separate embeddings for different features
|
||||||
|
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||||
|
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||||
|
|
||||||
|
# Improved projection layers
|
||||||
|
self.node_projection = nn.Sequential(
|
||||||
|
nn.Linear(d_model * 2, d_model * 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(d_model * 2, d_model),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Layer norm for stability
|
||||||
|
self.layer_norm = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
self._initialize_embeddings()
|
||||||
|
|
||||||
|
def _initialize_embeddings(self):
|
||||||
|
std = 0.02
|
||||||
|
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||||
|
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||||
|
|
||||||
|
# Initialize projection layers
|
||||||
|
for layer in self.node_projection:
|
||||||
|
if isinstance(layer, nn.Linear):
|
||||||
|
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||||
|
nn.init.zeros_(layer.bias)
|
||||||
|
|
||||||
|
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
depths: Tensor of shape [batch_size, seq_len] containing depth values
|
||||||
|
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
|
||||||
|
Returns:
|
||||||
|
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
|
||||||
|
"""
|
||||||
|
# Clamp values to max_depth
|
||||||
|
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||||
|
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||||
|
|
||||||
|
# Get embeddings for each feature
|
||||||
|
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
|
||||||
|
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
|
||||||
|
|
||||||
|
# Combine features
|
||||||
|
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||||
|
embeddings = self.node_projection(combined)
|
||||||
|
|
||||||
|
# Apply layer norm
|
||||||
|
normalized_embeddings = self.layer_norm(embeddings)
|
||||||
|
|
||||||
|
return normalized_embeddings
|
||||||
|
|
||||||
|
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||||
|
"""CodeBERT model enhanced with tree-structural information."""
|
||||||
|
|
||||||
|
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||||
|
d_model=config.hidden_size,
|
||||||
|
max_depth=max_depth,
|
||||||
|
dropout=config.hidden_dropout_prob
|
||||||
|
)
|
||||||
|
|
||||||
|
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||||
|
|
||||||
|
# Initialize sequential position embeddings with sinusoidal pattern
|
||||||
|
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||||
|
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||||
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||||
|
|
||||||
|
# Initialize embedding weights equally
|
||||||
|
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
|
||||||
|
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
|
||||||
|
|
||||||
|
# Layer norms for embedding combination
|
||||||
|
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
|
def get_normalized_weights(self):
|
||||||
|
"""Get softmaxed weights for embedding combination."""
|
||||||
|
return torch.softmax(self.embedding_weights, dim=0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.Tensor] = None,
|
||||||
|
depths: Optional[torch.Tensor] = None,
|
||||||
|
sibling_idxs: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
# Get normalized weights for embedding combination and calculate regularization
|
||||||
|
weights = self.get_normalized_weights()
|
||||||
|
|
||||||
|
# Calculate weight variance regularization
|
||||||
|
# We want weights to remain somewhat balanced, so penalize high variance
|
||||||
|
weight_variance = torch.var(weights)
|
||||||
|
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
|
||||||
|
|
||||||
|
# Add L2 regularization to prevent any weight from getting too close to 1
|
||||||
|
# This helps maintain a more balanced contribution from each embedding type
|
||||||
|
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
|
||||||
|
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
|
||||||
|
|
||||||
|
# Get token embeddings
|
||||||
|
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
||||||
|
token_embeddings = self.pre_combination_norm(token_embeddings)
|
||||||
|
|
||||||
|
# Get sequential position embeddings
|
||||||
|
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||||
|
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||||
|
|
||||||
|
# Get tree positional embeddings if tree information is provided
|
||||||
|
if depths is not None and sibling_idxs is not None:
|
||||||
|
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||||
|
else:
|
||||||
|
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||||
|
|
||||||
|
# Combine all embeddings using learned weights
|
||||||
|
combined_embeddings = (
|
||||||
|
weights[0] * token_embeddings +
|
||||||
|
weights[1] * tree_embeddings +
|
||||||
|
weights[2] * seq_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_embeddings = self.post_combination_norm(combined_embeddings)
|
||||||
|
|
||||||
|
# Forward pass through base model
|
||||||
|
outputs = self.roberta(
|
||||||
|
inputs_embeds=combined_embeddings,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.lm_head(sequence_output)
|
||||||
|
|
||||||
|
# Calculate MLM loss if labels are provided
|
||||||
|
masked_lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
masked_lm_loss = loss_fct(
|
||||||
|
prediction_scores.view(-1, self.config.vocab_size),
|
||||||
|
labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add regularization losses to final loss
|
||||||
|
if masked_lm_loss is not None:
|
||||||
|
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
|
||||||
|
else:
|
||||||
|
final_loss = weight_reg_loss + l2_reg_loss
|
||||||
|
else:
|
||||||
|
final_loss = None
|
||||||
|
|
||||||
|
# Prepare embedding weights for logging
|
||||||
|
weights_cpu = weights.detach().cpu()
|
||||||
|
embedding_weights = {
|
||||||
|
"token": weights_cpu[0].item(),
|
||||||
|
"tree": weights_cpu[1].item(),
|
||||||
|
"sequential": weights_cpu[2].item()
|
||||||
|
}
|
||||||
|
reg_metrics = {
|
||||||
|
"weight_variance": weight_variance.item(),
|
||||||
|
"max_weight_penalty": max_weight_penalty.item(),
|
||||||
|
"weight_reg_loss": weight_reg_loss.item(),
|
||||||
|
"l2_reg_loss": l2_reg_loss.item()
|
||||||
|
}
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
|
||||||
|
step=kwargs.get("global_step", None)
|
||||||
|
)
|
||||||
|
wandb.log(
|
||||||
|
{f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
|
||||||
|
step=kwargs.get("global_step", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loss": final_loss,
|
||||||
|
"logits": prediction_scores,
|
||||||
|
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||||
|
"attentions": outputs.attentions if output_attentions else None,
|
||||||
|
}
|
@ -1,107 +1,48 @@
|
|||||||
|
import wandb
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from transformers import AutoConfig, BertForMaskedLM, GenerationMixin
|
from transformers import AutoConfig, BertForMaskedLM
|
||||||
|
|
||||||
class TreePositionalEmbedding(nn.Module):
|
from tree_codebert import TreePositionalEmbedding
|
||||||
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
|
||||||
super().__init__()
|
|
||||||
self.d_model = d_model
|
|
||||||
self.max_depth = max_depth
|
|
||||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
|
||||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
|
||||||
self.node_projection = nn.Sequential(
|
|
||||||
nn.Linear(d_model * 2, d_model * 2),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(d_model * 2, d_model),
|
|
||||||
nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
self.layer_norm = nn.LayerNorm(d_model)
|
|
||||||
self._initialize_embeddings()
|
|
||||||
|
|
||||||
def _initialize_embeddings(self):
|
class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
||||||
std = 0.02
|
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||||
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
|
||||||
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
|
||||||
for layer in self.node_projection:
|
|
||||||
if isinstance(layer, nn.Linear):
|
|
||||||
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
|
||||||
nn.init.zeros_(layer.bias)
|
|
||||||
|
|
||||||
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
|
||||||
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
|
||||||
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
|
||||||
depth_embeddings = self.depth_embedding(depths)
|
|
||||||
sibling_embeddings = self.sibling_embedding(sibling_idxs)
|
|
||||||
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
|
||||||
embeddings = self.node_projection(combined)
|
|
||||||
return self.layer_norm(embeddings)
|
|
||||||
|
|
||||||
class NewEmbeddings(nn.Module):
|
|
||||||
"""Construct the embeddings from word, position and tree-based embeddings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, yaml_config):
|
|
||||||
super().__init__()
|
|
||||||
self.yaml_config = yaml_config
|
|
||||||
|
|
||||||
if self.yaml_config['model']['concat_embeddings']:
|
|
||||||
self.fusion_layer = nn.Sequential(
|
|
||||||
nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
|
||||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
|
||||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
|
||||||
d_model=config.hidden_size,
|
|
||||||
max_depth=yaml_config['model']['max_depth'],
|
|
||||||
dropout=config.hidden_dropout_prob
|
|
||||||
)
|
|
||||||
|
|
||||||
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
||||||
|
|
||||||
def forward(self, input_ids=None, depths=None, sibling_idxs=None):
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
|
|
||||||
seq_length = input_shape[1]
|
|
||||||
device = input_ids.device
|
|
||||||
|
|
||||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
|
||||||
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
|
||||||
tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
|
||||||
|
|
||||||
if self.yaml_config['model']['sum_embeddings']:
|
|
||||||
embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
|
|
||||||
if self.yaml_config['model']['concat_embeddings']:
|
|
||||||
embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
|
|
||||||
embeddings = self.fusion_layer(embeddings)
|
|
||||||
|
|
||||||
embeddings = self.LayerNorm(embeddings)
|
|
||||||
embeddings = self.dropout(embeddings)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
|
||||||
def __init__(self, config: AutoConfig, yaml_config: Dict):
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.yaml_config = yaml_config
|
|
||||||
self.embeddings = NewEmbeddings(config, yaml_config)
|
|
||||||
|
|
||||||
# self.fusion_layer = nn.Sequential(
|
# self.fusion_layer = nn.Sequential(
|
||||||
# nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
|
# nn.Linear(config.hidden_size * 4, config.hidden_size),
|
||||||
# nn.GELU(),
|
# nn.GELU(),
|
||||||
# nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
|
|
||||||
# nn.Dropout(config.hidden_dropout_prob),
|
# nn.Dropout(config.hidden_dropout_prob),
|
||||||
# nn.LayerNorm(config.hidden_size)
|
# nn.LayerNorm(config.hidden_size)
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# Override config to set max_seq_length
|
||||||
|
config.max_position_embeddings = max_seq_length
|
||||||
|
|
||||||
|
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||||
|
d_model=config.hidden_size,
|
||||||
|
max_depth=max_depth,
|
||||||
|
dropout=config.hidden_dropout_prob
|
||||||
|
)
|
||||||
|
|
||||||
|
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||||
|
|
||||||
|
# # Initialize sequential position embeddings with sinusoidal pattern
|
||||||
|
# position = torch.arange(max_seq_length).unsqueeze(1)
|
||||||
|
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||||
|
# pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||||
|
# pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
# pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
# self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||||
|
|
||||||
|
# New node type embeddings
|
||||||
|
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
|
||||||
|
self.norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -112,10 +53,34 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
embedding_output = self.embeddings(input_ids, depths, sibling_idxs)
|
device = input_ids.device
|
||||||
|
|
||||||
|
# Get token embeddings
|
||||||
|
token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
# Get sequential position embeddings
|
||||||
|
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||||
|
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||||
|
|
||||||
|
# Get tree positional embeddings
|
||||||
|
if depths is not None and sibling_idxs is not None:
|
||||||
|
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||||
|
else:
|
||||||
|
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||||
|
|
||||||
|
# # Get node type embeddings
|
||||||
|
# node_type_embeddings = self.node_type_embeddings(node_types)
|
||||||
|
|
||||||
|
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
|
||||||
|
# combined_embeddings = self.fusion_layer(combined)
|
||||||
|
|
||||||
|
# Add the embeddings instead of concatenating
|
||||||
|
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
|
||||||
|
|
||||||
|
combined_embeddings = self.norm(combined_embeddings)
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
inputs_embeds=embedding_output,
|
inputs_embeds=combined_embeddings,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -135,7 +100,6 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
|
|||||||
return {
|
return {
|
||||||
"loss": masked_lm_loss,
|
"loss": masked_lm_loss,
|
||||||
"logits": prediction_scores,
|
"logits": prediction_scores,
|
||||||
"last_hidden_state": sequence_output,
|
|
||||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||||
"attentions": outputs.attentions if output_attentions else None,
|
"attentions": outputs.attentions if output_attentions else None,
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user