Compare commits

Comparing runpod-exp ... master (9 commits)

Commits:
- a6ebe812cf
- 19864729cf
- 2639e8dca2
- 2ec7cc263d
- 03b2502af0
- 3b93a7cc8a
- 76a89dc236
- eed2096400
- f0679ab861
code/.gitignore (vendored): 1 line added

@@ -164,3 +164,4 @@ cython_debug/
 # Weights & Biases
 wandb/
 outputs/
+cache/
@@ -20,9 +20,9 @@ pdm install
 ```
 ### 4. Run training code
 ```bash
-pdm run_training
+pdm train --config {CONFIG FILE}
 ```
-or
-```
-pdm run src/train_codebert_mlm.py
+Example:
+```bash
+pdm train --config ./configs/original.yaml
 ```
code/configs/finetune.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: false
  wandb_project: "prof-gralinski"

training:
  seed: 42
  batch_size: 32
  epochs: 1
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 0
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/1-fold-clone-detection-600k-5fold"
  num_samples: 1000 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  path: "outputs/tmp/final-model"
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
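Note: a minimal sketch (not part of the diff) of how a nested config like the one above can be read. It assumes PyYAML is installed and that the file exists at ./configs/finetune.yaml, and it mirrors the load_config helper added to the training script further down.

```python
# Sketch: load a nested YAML config and read a few values.
from pathlib import Path
import yaml

def load_config(config_path: Path) -> dict:
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

config = load_config(Path("./configs/finetune.yaml"))
print(config["training"]["batch_size"])     # 32
print(config["model"]["extra_embeddings"])  # True
```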
code/configs/original.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: false # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: false
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
code/configs/tree-add-sinusoidal.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: true
  concat_embeddings: false
  sum_embeddings: true
  max_depth: 32
  max_seq_length: 512
code/configs/tree-add.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: false
  sum_embeddings: true
  max_depth: 32
  max_seq_length: 512
code/configs/tree-concat-sinusoidal.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: true
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
code/configs/tree-concat.yaml (new file, 33 lines)

experiment:
  cache_dir: "./cache/"
  use_wandb: true
  wandb_project: "runpod"

training:
  seed: 42
  batch_size: 32
  epochs: 3
  learning_rate: 0.0005
  weight_decay: 0.1
  max_grad_norm: 1.0
  warmup_steps: 1000
  fp16: true

evaluation:
  eval_every: 10000
  logging_steps: 100

data:
  source: "patrykbart/codeparrot-clean-no-comments-starencoder-small"
  mlm_probability: 0.15
  num_samples: -1 # -1 means all samples
  valid_size: 0.05
  test_size: 0.05

model:
  extra_embeddings: true # if false, use the original star encoder
  sinusoidal_init: false
  concat_embeddings: true
  sum_embeddings: false
  max_depth: 32
  max_seq_length: 512
code/data/.gitignore (vendored): deleted

@@ -1,2 +0,0 @@
-*
-!.gitignore
code/models/.gitignore (vendored): deleted

@@ -1,2 +0,0 @@
-*
-!.gitignore
@@ -19,6 +19,7 @@ dependencies = [
     "tree-sitter-python==0.23.4",
     "ipykernel==6.29.5",
     "ipywidgets==8.1.5",
+    "pyyaml==6.0.2",
 ]
 requires-python = "==3.11.*"
 readme = "README.md"

@@ -41,3 +42,4 @@ distribution = true
 parse_dataset = {cmd = "src/parse_dataset.py"}
 train = {cmd = "src/training.py"}
 eval = {cmd = "src/eval_model.py"}
+finetune = {cmd = "src_finetune/finetune.py"}
@@ -1,20 +0,0 @@ (deleted file)
{
    "extra_embeddings": true,
    "run_name": "no-sinusoidal",
    "data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
    "output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
    "seed": 420,
    "mlm_probability": 0.15,
    "batch_size": 32,
    "epochs": 3,
    "eval_every": 10000,
    "learning_rate": 5e-4,
    "weight_decay": 0.1,
    "max_grad_norm": 1.0,
    "warmup_steps": 1000,
    "fp16": true,
    "logging_steps": 100,
    "valid_size": 0.05,
    "test_size": 0.05,
    "num_samples": -1
}
@@ -1,118 +0,0 @@ (deleted file)

import json
import logging
import multiprocessing
from pathlib import Path
from datasets import load_from_disk

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

def load_node_types_from_json(json_path: Path):
    """
    Load node types from the Tree-sitter grammar's `node_types.json` and include UNK as the 0 index.

    Args:
        json_path (Path): Path to the `node_types.json` file.

    Returns:
        dict: A mapping from node type strings to unique integer IDs.
    """
    if not json_path.exists():
        raise FileNotFoundError(f"{json_path} not found.")

    logger.info(f"Loading node types from {json_path}...")
    with open(json_path, "r", encoding="utf-8") as f:
        node_types_data = json.load(f)

    # Extract all unique "type" entries
    node_types = set()

    def extract_types(data):
        if isinstance(data, list):
            for item in data:
                extract_types(item)
        elif isinstance(data, dict):
            if "type" in data and isinstance(data["type"], str):
                node_types.add(data["type"])
            for key, value in data.items():
                extract_types(value)

    extract_types(node_types_data)

    # Create mapping and add 'UNK' at index 0
    node_type2id = {"<UNK>": 0}
    for i, node_type in enumerate(sorted(node_types), start=1):
        node_type2id[node_type] = i

    logger.info(f"Loaded {len(node_type2id)} node types, including UNK.")
    return node_type2id

def encode_node_types(examples, node_type2id):
    """
    Batched function to replace node type strings with their integer IDs using a preloaded mapping.
    """
    encoded_node_types = []
    for node_list in examples["node_types"]:
        try:
            encoded_node_list = [node_type2id[nt] if nt is not None and nt != 'ERROR' else node_type2id['<UNK>'] for nt in node_list]
            encoded_node_types.append(encoded_node_list)
        except KeyError as e:
            raise KeyError(f"Unknown node type encountered: {e}")
    examples["node_types_encoded"] = encoded_node_types
    return examples

def main():
    """
    Main script to load, process, and save a dataset with node types encoded as integers.
    """
    # ------------------------------------------------------------------------------
    # 1. Setup paths & load dataset
    # ------------------------------------------------------------------------------
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-padded"
    output_dir = current_dir.parent / "data" / "codeparrot-clean-parsed-starencoder-classes-encoded"
    node_types_path = current_dir / "node_types.json"

    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Loading dataset from {input_dir}...")
    dataset = load_from_disk(str(input_dir))
    logger.info("Dataset loaded.")

    # Determine number of processes to use
    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes.")

    # ------------------------------------------------------------------------------
    # 2. Load node types from JSON
    # ------------------------------------------------------------------------------
    node_type2id = load_node_types_from_json(node_types_path)
    logger.info(f"Loaded {len(node_type2id)} node types.")
    # Save node_type2id to disk
    with open(output_dir / "node_type2id.json", "w") as f:
        json.dump(node_type2id, f)

    # ------------------------------------------------------------------------------
    # 3. Convert node types in the dataset to integer IDs
    # ------------------------------------------------------------------------------
    logger.info("Converting node type strings to integer IDs...")

    dataset = dataset.map(
        lambda examples: encode_node_types(examples, node_type2id),
        batched=True,
        num_proc=num_proc,
        desc="Encoding node types to integer IDs",
    )

    # ------------------------------------------------------------------------------
    # 4. Save the modified dataset to disk
    # ------------------------------------------------------------------------------
    logger.info(f"Saving updated dataset to {output_dir}...")
    dataset.save_to_disk(str(output_dir))
    logger.info("Dataset saved successfully.")

if __name__ == "__main__":
    main()
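For illustration, a small hedged sketch of the encoding rule the deleted script applied: unknown or ERROR node types fall back to the `<UNK>` index. The mapping below is a toy table, not the real one built from node_types.json.

```python
# Toy illustration of the encode_node_types rule (not the real mapping).
node_type2id = {"<UNK>": 0, "identifier": 1, "module": 2}

node_list = ["module", "identifier", None, "ERROR"]
encoded = [
    node_type2id[nt] if nt is not None and nt != "ERROR" else node_type2id["<UNK>"]
    for nt in node_list
]
print(encoded)  # [2, 1, 0, 0]
```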
@@ -133,7 +133,7 @@ def main():
     model_config.max_position_embeddings = 1024

     if config['extra_embeddings']:
-        model = TreeStarEncoderForPreTraining(config=model_config, log=False)
+        model = TreeStarEncoderForPreTraining(config=model_config)
     else:
         model = AutoModelForMaskedLM.from_config(model_config)
File diff suppressed because it is too large.
@@ -1,77 +0,0 @@ (deleted file)

import logging
from pathlib import Path
from datasets import load_from_disk
from transformers import AutoTokenizer
import multiprocessing

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def pad_and_save_dataset(input_dir, output_dir, tokenizer_name='bigcode/starencoder', max_length=512):
    # Load the processed dataset
    logger.info(f"Loading processed dataset from {input_dir}...")
    dataset = load_from_disk(input_dir)
    logger.info(f"Loaded dataset with {len(dataset)} examples")

    # Initialize tokenizer
    logger.info("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Loaded StarEncoder tokenizer")

    # Define number of processes
    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes")

    # Define a function to pad the sequences
    def pad_sequences(batch):
        # Convert input_ids back to text if necessary
        texts = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

        # Use the tokenizer's __call__ method for padding
        padded_inputs = tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
            truncation=True
        )

        # Pad other fields with default values
        padded_depths = [seq + [-1] * (max_length - len(seq)) for seq in batch['depths']]
        padded_sibling_idxs = [seq + [-1] * (max_length - len(seq)) for seq in batch['sibling_idxs']]
        padded_node_types = [seq + [None] * (max_length - len(seq)) for seq in batch['node_types']]
        padded_node_texts = [seq + [''] * (max_length - len(seq)) for seq in batch['node_texts']]

        return {
            'input_ids': padded_inputs['input_ids'].tolist(),
            'attention_mask': padded_inputs['attention_mask'].tolist(),
            'depths': padded_depths,
            'sibling_idxs': padded_sibling_idxs,
            'node_types': padded_node_types,
            'node_texts': padded_node_texts
        }

    # Apply padding
    logger.info("Applying padding to dataset...")
    padded_dataset = dataset.map(
        pad_sequences,
        batched=True,
        desc="Padding dataset",
        num_proc=num_proc
    )

    # Save the padded dataset
    logger.info(f"Saving padded dataset to {output_dir}...")
    padded_dataset.save_to_disk(output_dir)
    logger.info(f"Saved padded dataset to {output_dir}")

if __name__ == "__main__":
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes-padded'

    pad_and_save_dataset(input_dir, output_dir)
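As a quick check of the list-padding idiom used for the tree fields above, a minimal sketch with arbitrary values:

```python
# Sketch: right-pad a depths sequence to max_length with -1, as in pad_sequences.
max_length = 8
depths = [0, 1, 1, 2]
padded = depths + [-1] * (max_length - len(depths))
print(padded)  # [0, 1, 1, 2, -1, -1, -1, -1]
```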
@@ -84,7 +84,6 @@ def process_example(code, tokenizer):

     depths = [-1] * len(input_ids)
     sibling_idxs = [-1] * len(input_ids)
-    node_types = [None] * len(input_ids)
     node_texts = [''] * len(input_ids)

     tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids)
@@ -96,7 +95,6 @@ def process_example(code, tokenizer):
            if node.start_byte <= start < node.end_byte:
                depths[i] = depth
                sibling_idxs[i] = sibling_idx
-                node_types[i] = node.type
                node_texts[i] = code[node.start_byte:node.end_byte]
        for i, child in enumerate(node.children):
            traverse(child, depth + 1, i)
@@ -109,7 +107,6 @@ def process_example(code, tokenizer):
        'attention_mask': attention_mask,
        'depths': depths,
        'sibling_idxs': sibling_idxs,
-        'node_types': node_types,
        'node_texts': node_texts
    }

@@ -121,7 +118,6 @@ def process_batch(batch, tokenizer):
    processed_depths = []
    processed_sibling_idxs = []
    processed_node_texts = []
-    processed_node_types = []

    for content in contents:
        try:
@@ -134,7 +130,6 @@ def process_batch(batch, tokenizer):
                processed_depths.append([])
                processed_sibling_idxs.append([])
                processed_node_texts.append([])
-                processed_node_types.append([])
            else:
                processed_input_ids.append(result['input_ids'])
                processed_attention_mask.append(result['attention_mask'])
@@ -142,7 +137,6 @@ def process_batch(batch, tokenizer):
                processed_depths.append(result['depths'])
                processed_sibling_idxs.append(result['sibling_idxs'])
                processed_node_texts.append(result['node_texts'])
-                processed_node_types.append(result['node_types'])
        except Exception:
            # If something unexpected happens
            processed_input_ids.append([])
@@ -151,7 +145,6 @@ def process_batch(batch, tokenizer):
            processed_depths.append([])
            processed_sibling_idxs.append([])
            processed_node_texts.append([])
-            processed_node_types.append([])

    return {
        'input_ids': processed_input_ids,
@@ -159,7 +152,6 @@ def process_batch(batch, tokenizer):
        'tokens': processed_tokens,
        'depths': processed_depths,
        'sibling_idxs': processed_sibling_idxs,
-        'node_types': processed_node_types,
        'node_texts': processed_node_texts
    }
@@ -1,32 +1,44 @@
+import os
 import wandb
-import json
+import argparse
+import yaml
+import torch
+import random
 import logging
+import numpy as np
 from pathlib import Path
-from datasets import load_from_disk, DatasetDict
+from datasets import load_dataset, DatasetDict
 from transformers import (
-    RobertaConfig,
     AutoConfig,
-    RobertaForMaskedLM,
     AutoTokenizer,
     TrainingArguments,
     Trainer,
     DataCollatorForLanguageModeling,
     AutoModelForMaskedLM
 )
-import random
-import numpy as np
-import torch

-from tree_codebert import TreeCodeBERTForPreTraining
 from tree_starencoder import TreeStarEncoderForPreTraining

-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S'
-)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

+class MonitoringTrainer(Trainer):
+    def training_step(self, model, inputs, num_items_in_batch=None):
+        # Perform the regular training step
+        outputs = super().training_step(model, inputs, num_items_in_batch)
+
+        # Log gradient norms
+        for name, param in model.named_parameters():
+            if param.grad is not None:
+                grad_norm = param.grad.norm(2).item()
+                wandb.log({f'grad_norm/{name}': grad_norm})
+
+        # Log weight norms
+        for name, param in model.named_parameters():
+            weight_norm = param.data.norm(2).item()
+            wandb.log({f'weight_norm/{name}': weight_norm})
+
+        return outputs
+
 def set_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
@@ -38,60 +50,57 @@ def set_seed(seed: int):

 def load_config(config_path: Path) -> dict:
     with open(config_path, 'r') as f:
-        return json.load(f)
+        return yaml.safe_load(f)

-def main():
-    # Setup paths
-    current_dir = Path(__file__).parent
-    config = load_config(current_dir / 'config.json')
-    data_dir = Path(config['data_dir'])
-    output_dir = Path(config['output_dir'])
-
-    # Set seed
-    set_seed(config['seed'])
-
-    # Initialize W&B
-    wandb.init(project='codeparrot-starencoder-no-comments', config=config, name=config['run_name'])
-
-    # Upload the training files to W&B
-    wandb.save(__file__)
-    wandb.save(Path(__file__).parent / 'config.json')
-    if config['extra_embeddings']:
-        wandb.save(current_dir / 'tree_starencoder.py')
-
-    if 'CodeSearchNet' in config['data_dir']:
-        dataset = DatasetDict({
-            'train': load_from_disk(data_dir / 'train'),
-            'valid': load_from_disk(data_dir / 'valid'),
-            'test': load_from_disk(data_dir / 'test')
-        })
-    else:
-        dataset = load_from_disk(data_dir)
-        if config['num_samples'] > 0:
-            dataset = dataset.select(range(config['num_samples']))
-        train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
-        test_valid = train_testvalid['test'].train_test_split(
-            test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
-            seed=config['seed']
-        )
-        dataset = DatasetDict({
-            'train': train_testvalid['train'],
-            'test': test_valid['test'],
-            'valid': test_valid['train'],
-        })
-
-    # Continue with the rest of processing
-    columns_to_remove = dataset['train'].column_names
-    columns_to_remove.remove('input_ids')
-    columns_to_remove.remove('attention_mask')
-    if config['extra_embeddings']:
-        columns_to_remove.remove('depths')
-        columns_to_remove.remove('sibling_idxs')
+def initialize_wandb(config, name, files_to_save):
+    wandb.init(project=config['experiment']['wandb_project'], config=config, name=name)
+    for file in files_to_save:
+        wandb.save(file)
+
+def prepare_dataset(config, cache_dir):
+    dataset = load_dataset(config['data']['source'], split='train', num_proc=16, cache_dir=cache_dir)
+    if config['data']['num_samples'] > 0:
+        dataset = dataset.select(range(config['data']['num_samples']))
+    train_testvalid = dataset.train_test_split(test_size=config['data']['test_size'] + config['data']['valid_size'])
+    test_valid = train_testvalid['test'].train_test_split(
+        test_size=config['data']['valid_size'] / (config['data']['test_size'] + config['data']['valid_size']),
+        seed=config['training']['seed']
+    )
+    dataset = DatasetDict({
+        'train': train_testvalid['train'],
+        'test': test_valid['test'],
+        'valid': test_valid['train'],
+    })
+    columns_to_remove = [col for col in dataset['train'].column_names if col not in ['input_ids', 'attention_mask']]
+    if config['model']['extra_embeddings']:
+        columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
     dataset = dataset.remove_columns(columns_to_remove)
-    logger.info(f'Loaded dataset:\n{dataset}')
+    return dataset

-    # Initialize model from scratch
+def main():
+    parser = argparse.ArgumentParser(description='Training script for TreeStarEncoder')
+    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
+    args = parser.parse_args()
+
+    current_dir = Path(__file__).parent
+    config_path = Path(args.config)
+    config = load_config(config_path)
+    cache_dir = Path(config['experiment']['cache_dir'])
+    output_dir = Path('./outputs') / config_path.stem
+
+    if config['experiment']['use_wandb']:
+        os.environ['WANDB_MODE'] = 'online'
+        initialize_wandb(config, config_path.stem, [__file__, args.config, current_dir / 'tree_starencoder.py'])
+    else:
+        os.environ['WANDB_MODE'] = 'offline'
+        logger.info('Wandb is not used.')
+
+    set_seed(config['training']['seed'])
+
+    dataset = prepare_dataset(config, cache_dir)
+    logger.info(f'Dataset sizes - Train: {len(dataset["train"])}, Valid: {len(dataset["valid"])}, Test: {len(dataset["test"])}')
+    logger.info(f'Dataset columns: {dataset["train"].column_names}')
+
     tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
     if tokenizer.mask_token is None:
         tokenizer.add_special_tokens({'mask_token': '<mask>'})
@@ -101,41 +110,40 @@ def main():
         tokenizer.pad_token = tokenizer.eos_token
         logger.info("Set padding token to be the same as the EOS token.")

-    model_config = AutoConfig.from_pretrained('bigcode/starencoder')
-    if config['extra_embeddings']:
-        model = TreeStarEncoderForPreTraining(model_config)
+    model_config = AutoConfig.from_pretrained('bigcode/starencoder', cache_dir=cache_dir)
+    if config['model']['extra_embeddings']:
+        model = TreeStarEncoderForPreTraining(model_config, yaml_config=config)
     else:
         model = AutoModelForMaskedLM.from_config(model_config)
     logger.info(f'Loaded model: {model.__class__.__name__}')

-    # Setup training arguments
     training_args = TrainingArguments(
         output_dir=str(output_dir),
-        per_device_train_batch_size=config['batch_size'],
-        per_device_eval_batch_size=config['batch_size'],
-        learning_rate=config['learning_rate'],
-        weight_decay=config['weight_decay'],
-        num_train_epochs=config['epochs'],
-        warmup_steps=config['warmup_steps'],
-        max_grad_norm=config['max_grad_norm'],
-        logging_steps=config['logging_steps'],
-        eval_steps=config['eval_every'],
-        save_steps=config['eval_every'],
+        per_device_train_batch_size=config['training']['batch_size'],
+        per_device_eval_batch_size=config['training']['batch_size'],
+        learning_rate=config['training']['learning_rate'],
+        weight_decay=config['training']['weight_decay'],
+        num_train_epochs=config['training']['epochs'],
+        warmup_steps=config['training']['warmup_steps'],
+        max_grad_norm=config['training']['max_grad_norm'],
+        logging_steps=config['evaluation']['logging_steps'],
+        eval_steps=config['evaluation']['eval_every'],
+        save_steps=config['evaluation']['eval_every'],
         eval_strategy='steps',
         save_strategy='steps',
         load_best_model_at_end=True,
-        report_to='wandb',
-        run_name=config['run_name'],
-        seed=config['seed'],
-        fp16=config['fp16'],
+        report_to='wandb' if config['experiment']['use_wandb'] else None,
+        run_name=config_path.stem,
+        seed=config['training']['seed'],
+        fp16=config['training']['fp16'],
         dataloader_num_workers=8,
         gradient_checkpointing=True,
         metric_for_best_model='eval_loss',
         greater_is_better=False,
+        save_total_limit=3,
     )

-    # Create trainer
-    trainer = Trainer(
+    trainer = MonitoringTrainer(
         model=model,
         args=training_args,
         train_dataset=dataset['train'],
@@ -143,26 +151,25 @@ def main():
         data_collator=DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
             mlm=True,
-            mlm_probability=config['mlm_probability']
+            mlm_probability=config['data']['mlm_probability']
         ),
     )

-    # Train
     logger.info('Starting training...')
     trainer.train()
+    logger.info('Training completed.')

-    # Evaluate
     logger.info('Evaluating on test set...')
     eval_results = trainer.evaluate(dataset['test'])
+    logger.info(f'Evaluation results: {eval_results}')
     wandb.log({'test_loss': eval_results['eval_loss']})
-    logger.info(f'Test loss: {eval_results["eval_loss"]}')

-    # Save final model
     logger.info('Saving final model...')
     trainer.save_model(output_dir / 'final-model')
     tokenizer.save_pretrained(output_dir / 'final-model')

-    logger.info('Training completed!')
+    # Upload to W&B
+    wandb.save(output_dir / 'final-model/*')

 if __name__ == '__main__':
     main()
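The new prepare_dataset carves off test_size + valid_size in one split, then halves that held-out part. A hedged sketch of the arithmetic with the 0.05/0.05 values from the configs, on a toy dataset (the datasets library is assumed; the real data comes from the Hub source named in the YAML):

```python
# Sketch of the two-stage split used in prepare_dataset (toy data, config values 0.05/0.05).
from datasets import Dataset, DatasetDict

test_size, valid_size = 0.05, 0.05
dataset = Dataset.from_dict({"input_ids": [[i] for i in range(1000)]})

train_testvalid = dataset.train_test_split(test_size=test_size + valid_size)  # 10% held out
test_valid = train_testvalid["test"].train_test_split(
    test_size=valid_size / (test_size + valid_size),  # split the held-out 10% in half
    seed=42,
)
splits = DatasetDict({
    "train": train_testvalid["train"],  # ~900 examples
    "test": test_valid["test"],         # ~50
    "valid": test_valid["train"],       # ~50
})
print({name: len(split) for name, split in splits.items()})
```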
@@ -1,208 +0,0 @@ (deleted file)

import wandb

import math
import torch
import torch.nn as nn
from typing import Dict, Optional
from transformers import RobertaConfig, RobertaForMaskedLM

class TreePositionalEmbedding(nn.Module):
    """Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""

    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        # Separate embeddings for different features
        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)

        # Improved projection layers
        self.node_projection = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.Dropout(dropout)
        )

        # Layer norm for stability
        self.layer_norm = nn.LayerNorm(d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        std = 0.02
        for embedding in [self.depth_embedding, self.sibling_embedding]:
            nn.init.normal_(embedding.weight, mean=0.0, std=std)

        # Initialize projection layers
        for layer in self.node_projection:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=std)
                nn.init.zeros_(layer.bias)

    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            depths: Tensor of shape [batch_size, seq_len] containing depth values
            sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
        Returns:
            Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
        """
        # Clamp values to max_depth
        depths = torch.clamp(depths, 0, self.max_depth - 1)
        sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)

        # Get embeddings for each feature
        depth_embeddings = self.depth_embedding(depths)            # [batch, seq_len, d_model]
        sibling_embeddings = self.sibling_embedding(sibling_idxs)  # [batch, seq_len, d_model]

        # Combine features
        combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
        embeddings = self.node_projection(combined)

        # Apply layer norm
        normalized_embeddings = self.layer_norm(embeddings)

        return normalized_embeddings

class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """CodeBERT model enhanced with tree-structural information."""

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
        super().__init__(config)

        self.tree_pos_embeddings = TreePositionalEmbedding(
            d_model=config.hidden_size,
            max_depth=max_depth,
            dropout=config.hidden_dropout_prob
        )

        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)

        # Initialize sequential position embeddings with sinusoidal pattern
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
        pe = torch.zeros(max_seq_length, config.hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.seq_pos_embeddings.weight.data.copy_(pe)

        # Initialize embedding weights equally
        initial_weight = math.log(1/3)  # log(1/3) because we use softmax later
        self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))

        # Layer norms for embedding combination
        self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
        self.post_combination_norm = nn.LayerNorm(config.hidden_size)

    def get_normalized_weights(self):
        """Get softmaxed weights for embedding combination."""
        return torch.softmax(self.embedding_weights, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        sibling_idxs: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        device = input_ids.device

        # Get normalized weights for embedding combination and calculate regularization
        weights = self.get_normalized_weights()

        # Calculate weight variance regularization
        # We want weights to remain somewhat balanced, so penalize high variance
        weight_variance = torch.var(weights)
        weight_reg_loss = 0.1 * weight_variance  # Adjustable coefficient

        # Add L2 regularization to prevent any weight from getting too close to 1
        # This helps maintain a more balanced contribution from each embedding type
        max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2)  # Penalize weights > 0.8
        l2_reg_loss = 0.05 * max_weight_penalty  # Adjustable coefficient

        # Get token embeddings
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
        token_embeddings = self.pre_combination_norm(token_embeddings)

        # Get sequential position embeddings
        seq_positions = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)

        # Get tree positional embeddings if tree information is provided
        if depths is not None and sibling_idxs is not None:
            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
        else:
            tree_embeddings = torch.zeros_like(token_embeddings)

        # Combine all embeddings using learned weights
        combined_embeddings = (
            weights[0] * token_embeddings +
            weights[1] * tree_embeddings +
            weights[2] * seq_embeddings
        )

        combined_embeddings = self.post_combination_norm(combined_embeddings)

        # Forward pass through base model
        outputs = self.roberta(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # Calculate MLM loss if labels are provided
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

            # Add regularization losses to final loss
            if masked_lm_loss is not None:
                final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
            else:
                final_loss = weight_reg_loss + l2_reg_loss
        else:
            final_loss = None

        # Prepare embedding weights for logging
        weights_cpu = weights.detach().cpu()
        embedding_weights = {
            "token": weights_cpu[0].item(),
            "tree": weights_cpu[1].item(),
            "sequential": weights_cpu[2].item()
        }
        reg_metrics = {
            "weight_variance": weight_variance.item(),
            "max_weight_penalty": max_weight_penalty.item(),
            "weight_reg_loss": weight_reg_loss.item(),
            "l2_reg_loss": l2_reg_loss.item()
        }

        wandb.log(
            {f"embedding_weights/{key}": value for key, value in embedding_weights.items()},
            step=kwargs.get("global_step", None)
        )
        wandb.log(
            {f"regularization_metrics/{key}": value for key, value in reg_metrics.items()},
            step=kwargs.get("global_step", None)
        )

        return {
            "loss": final_loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            "attentions": outputs.attentions if output_attentions else None,
        }
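For reference, a small sketch of the tensor shapes a TreePositionalEmbedding-style module works with: depth and sibling indices per token are embedded separately, concatenated, and projected back to the model dimension. Dummy inputs and a hidden size of 768 are assumed.

```python
# Sketch: shapes flowing through a TreePositionalEmbedding-style module (dummy data).
import torch
import torch.nn as nn

d_model, max_depth = 768, 32
depth_emb = nn.Embedding(max_depth, d_model)
sibling_emb = nn.Embedding(max_depth, d_model)
proj = nn.Sequential(
    nn.Linear(d_model * 2, d_model * 2),
    nn.GELU(),
    nn.Linear(d_model * 2, d_model),
)

batch_size, seq_len = 2, 16
depths = torch.randint(0, max_depth, (batch_size, seq_len))
sibling_idxs = torch.randint(0, max_depth, (batch_size, seq_len))

combined = torch.cat([depth_emb(depths), sibling_emb(sibling_idxs)], dim=-1)  # [2, 16, 1536]
out = proj(combined)                                                          # [2, 16, 768]
print(out.shape)  # torch.Size([2, 16, 768])
```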
@@ -1,48 +1,107 @@
-import wandb
-
 import math
 import torch
 import torch.nn as nn
 from typing import Dict, Optional
-from transformers import AutoConfig, BertForMaskedLM
+from transformers import AutoConfig, BertForMaskedLM, GenerationMixin

-from tree_codebert import TreePositionalEmbedding
+class TreePositionalEmbedding(nn.Module):
+    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.max_depth = max_depth
+        self.depth_embedding = nn.Embedding(max_depth, d_model)
+        self.sibling_embedding = nn.Embedding(max_depth, d_model)
+        self.node_projection = nn.Sequential(
+            nn.Linear(d_model * 2, d_model * 2),
+            nn.GELU(),
+            nn.Linear(d_model * 2, d_model),
+            nn.Dropout(dropout)
+        )
+        self.layer_norm = nn.LayerNorm(d_model)
+        self._initialize_embeddings()
+
+    def _initialize_embeddings(self):
+        std = 0.02
+        for embedding in [self.depth_embedding, self.sibling_embedding]:
+            nn.init.normal_(embedding.weight, mean=0.0, std=std)
+        for layer in self.node_projection:
+            if isinstance(layer, nn.Linear):
+                nn.init.normal_(layer.weight, mean=0.0, std=std)
+                nn.init.zeros_(layer.bias)

-class TreeStarEncoderForPreTraining(BertForMaskedLM):
-    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
+    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
+        depths = torch.clamp(depths, 0, self.max_depth - 1)
+        sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
+        depth_embeddings = self.depth_embedding(depths)
+        sibling_embeddings = self.sibling_embedding(sibling_idxs)
+        combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
+        embeddings = self.node_projection(combined)
+        return self.layer_norm(embeddings)
+
+class NewEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and tree-based embeddings."""
+
+    def __init__(self, config, yaml_config):
+        super().__init__()
+        self.yaml_config = yaml_config
+
+        if self.yaml_config['model']['concat_embeddings']:
+            self.fusion_layer = nn.Sequential(
+                nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
+                nn.GELU(),
+                nn.Linear(config.hidden_size * 3, config.hidden_size),
+            )
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.tree_pos_embeddings = TreePositionalEmbedding(
+            d_model=config.hidden_size,
+            max_depth=yaml_config['model']['max_depth'],
+            dropout=config.hidden_dropout_prob
+        )
+
+        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids=None, depths=None, sibling_idxs=None):
+        input_shape = input_ids.size()
+
+        seq_length = input_shape[1]
+        device = input_ids.device
+
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).expand(input_shape)
+
+        inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        tree_pos_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
+
+        if self.yaml_config['model']['sum_embeddings']:
+            embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
+        if self.yaml_config['model']['concat_embeddings']:
+            embeddings = torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
+            embeddings = self.fusion_layer(embeddings)
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+class TreeStarEncoderForPreTraining(BertForMaskedLM, GenerationMixin):
+    def __init__(self, config: AutoConfig, yaml_config: Dict):
         super().__init__(config)
         self.config = config
+        self.yaml_config = yaml_config
+        self.embeddings = NewEmbeddings(config, yaml_config)

         # self.fusion_layer = nn.Sequential(
-        #     nn.Linear(config.hidden_size * 4, config.hidden_size),
+        #     nn.Linear(config.hidden_size * 3, config.hidden_size * 3),
         #     nn.GELU(),
+        #     nn.Linear(config.hidden_size * 3, config.hidden_size), # Reduce back to hidden_size
         #     nn.Dropout(config.hidden_dropout_prob),
         #     nn.LayerNorm(config.hidden_size)
         # )

-        # Override config to set max_seq_length
-        config.max_position_embeddings = max_seq_length
-
-        self.tree_pos_embeddings = TreePositionalEmbedding(
-            d_model=config.hidden_size,
-            max_depth=max_depth,
-            dropout=config.hidden_dropout_prob
-        )
-
-        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
-
-        # # Initialize sequential position embeddings with sinusoidal pattern
-        # position = torch.arange(max_seq_length).unsqueeze(1)
-        # div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
-        # pe = torch.zeros(max_seq_length, config.hidden_size)
-        # pe[:, 0::2] = torch.sin(position * div_term)
-        # pe[:, 1::2] = torch.cos(position * div_term)
-        # self.seq_pos_embeddings.weight.data.copy_(pe)
-
-        # New node type embeddings
-        self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
-        self.norm = nn.LayerNorm(config.hidden_size)

     def forward(
         self,
         input_ids: torch.Tensor,
@@ -53,34 +112,10 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         output_attentions: bool = False,
         **kwargs
     ) -> Dict[str, torch.Tensor]:
-        device = input_ids.device
-
-        # Get token embeddings
-        token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
-
-        # Get sequential position embeddings
-        seq_positions = torch.arange(input_ids.size(1), device=device)
-        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
-
-        # Get tree positional embeddings
-        if depths is not None and sibling_idxs is not None:
-            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
-        else:
-            tree_embeddings = torch.zeros_like(token_embeddings)
-
-        # # Get node type embeddings
-        # node_type_embeddings = self.node_type_embeddings(node_types)
-
-        # combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
-        # combined_embeddings = self.fusion_layer(combined)
-
-        # Add the embeddings instead of concatenating
-        combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
-
-        combined_embeddings = self.norm(combined_embeddings)
-
+        embedding_output = self.embeddings(input_ids, depths, sibling_idxs)
         outputs = self.bert(
-            inputs_embeds=combined_embeddings,
+            inputs_embeds=embedding_output,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
             **kwargs
@@ -100,6 +135,7 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         return {
             "loss": masked_lm_loss,
             "logits": prediction_scores,
+            "last_hidden_state": sequence_output,
             "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
             "attentions": outputs.attentions if output_attentions else None,
         }
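To make the sum_embeddings / concat_embeddings switch in NewEmbeddings concrete, a hedged sketch with dummy tensors (the real module also applies LayerNorm and dropout, and takes its flags from the YAML config):

```python
# Sketch: the two fusion modes selected by the YAML flags (dummy tensors, hidden size 768).
import torch
import torch.nn as nn

hidden = 768
inputs_embeds = torch.randn(2, 16, hidden)
position_embeddings = torch.randn(2, 16, hidden)
tree_pos_embeddings = torch.randn(2, 16, hidden)

sum_embeddings, concat_embeddings = False, True  # e.g. the tree-concat.yaml setting

if sum_embeddings:
    # Element-wise sum keeps the hidden size unchanged.
    embeddings = inputs_embeds + position_embeddings + tree_pos_embeddings
if concat_embeddings:
    # Concatenation triples the width, then a fusion MLP projects back to hidden.
    fusion_layer = nn.Sequential(
        nn.Linear(hidden * 3, hidden * 3),
        nn.GELU(),
        nn.Linear(hidden * 3, hidden),
    )
    embeddings = fusion_layer(
        torch.cat([inputs_embeds, position_embeddings, tree_pos_embeddings], dim=-1)
    )

print(embeddings.shape)  # torch.Size([2, 16, 768])
```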