Compare commits
4 Commits
runpod-exp
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
3b93a7cc8a | ||
|
76a89dc236 | ||
|
eed2096400 | ||
|
f0679ab861 |
@ -20,9 +20,11 @@ pdm install
|
|||||||
```
|
```
|
||||||
### 4. Run training code
|
### 4. Run training code
|
||||||
```bash
|
```bash
|
||||||
pdm run_training
|
pdm train
|
||||||
```
|
```
|
||||||
or
|
|
||||||
|
## Required secrets
|
||||||
```
|
```
|
||||||
pdm run src/train_codebert_mlm.py
|
export HF_TOKEN=<your-huggingface-token>  # never commit real tokens; revoke and rotate any token that was published here
|
||||||
|
export WANDB_API_KEY=<your-wandb-api-key>  # never commit real keys; revoke and rotate any key that was published here
|
||||||
```
|
```
|
@ -1,21 +1,21 @@
|
|||||||
{
|
{
|
||||||
"project": "runpod",
|
"extra_embeddings": true,
|
||||||
"run_name": "original",
|
"run_name": "tree-continued",
|
||||||
"dataset": "patrykbart/codeparrot-clean-no-comments-starencoder-small",
|
"data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
|
||||||
"output_dir": "./outputs/long-no-comments-starencoder-original",
|
"output_dir": "./outputs/no-comments-starencoder-tree-2",
|
||||||
"extra_embeddings": false,
|
"checkpoint": null,
|
||||||
"seed": 420,
|
"seed": 420,
|
||||||
"mlm_probability": 0.15,
|
"mlm_probability": 0.15,
|
||||||
"batch_size": 192,
|
"batch_size": 32,
|
||||||
"epochs": 3,
|
"epochs": 3,
|
||||||
"eval_every": 2500,
|
"eval_every": 10000,
|
||||||
"learning_rate": 5e-4,
|
"learning_rate": 5e-4,
|
||||||
"weight_decay": 0.1,
|
"weight_decay": 0.1,
|
||||||
"max_grad_norm": 1.0,
|
"max_grad_norm": 1.0,
|
||||||
"warmup_steps": 500,
|
"warmup_steps": 1000,
|
||||||
"bf16": true,
|
"fp16": true,
|
||||||
"logging_steps": 100,
|
"logging_steps": 100,
|
||||||
"valid_size": 0.05,
|
"valid_size": 0.05,
|
||||||
"test_size": 0.05,
|
"test_size": 0.05,
|
||||||
"num_samples": -1
|
"num_samples": -1
|
||||||
}
|
}
|
118
code/src/encode_classes.py
Normal file
118
code/src/encode_classes.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
from pathlib import Path
|
||||||
|
from datasets import load_from_disk
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_node_types_from_json(json_path: Path):
    """
    Load node types from the Tree-sitter grammar's `node_types.json` and include UNK as the 0 index.

    Every string stored under a "type" key, at any nesting depth of the JSON
    document, is collected. The unique names are sorted and numbered starting
    from 1; ID 0 is reserved for the "<UNK>" placeholder.

    Args:
        json_path (Path): Path to the `node_types.json` file. A plain string
            path is accepted as well.

    Returns:
        dict: A mapping from node type strings to unique integer IDs.

    Raises:
        FileNotFoundError: If `json_path` does not exist.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"{json_path} not found.")

    logger.info("Loading node types from %s...", json_path)
    with json_path.open("r", encoding="utf-8") as f:
        node_types_data = json.load(f)

    # Walk the whole JSON tree with an explicit stack (no recursion-depth
    # limit); visiting order is irrelevant because names end up in a set.
    node_types = set()
    pending = [node_types_data]
    while pending:
        data = pending.pop()
        if isinstance(data, list):
            pending.extend(data)
        elif isinstance(data, dict):
            type_name = data.get("type")
            if isinstance(type_name, str):
                node_types.add(type_name)
            # Nested values (e.g. lists of child descriptions) may hold more
            # "type" entries, so every value is visited too.
            pending.extend(data.values())

    # Sorting makes the IDs reproducible across runs; 0 is kept for UNK.
    node_type2id = {"<UNK>": 0}
    node_type2id.update(
        (node_type, i) for i, node_type in enumerate(sorted(node_types), start=1)
    )

    logger.info("Loaded %d node types, including UNK.", len(node_type2id))
    return node_type2id
|
||||||
|
|
||||||
|
def encode_node_types(examples, node_type2id):
    """
    Batched function to replace node type strings with their integer IDs using a preloaded mapping.

    Args:
        examples: Batch mapping whose "node_types" entry is a list of
            per-example lists of node type names. Entries that are `None` or
            the string 'ERROR' are encoded with the '<UNK>' ID.
        node_type2id: Mapping from node type name to integer ID; it must
            contain '<UNK>' whenever a `None`/'ERROR' entry occurs.

    Returns:
        The same `examples` object with an added "node_types_encoded" entry
        holding the integer IDs, parallel to "node_types".

    Raises:
        KeyError: If a node type is missing from `node_type2id`. The original
            lookup error is attached as the exception's cause.
    """
    try:
        # One try around the whole batch: the first missing key aborts the
        # batch either way, and `examples` is only modified on success.
        encoded_node_types = [
            [
                node_type2id["<UNK>" if nt is None or nt == "ERROR" else nt]
                for nt in node_list
            ]
            for node_list in examples["node_types"]
        ]
    except KeyError as e:
        raise KeyError(f"Unknown node type encountered: {e}") from e

    examples["node_types_encoded"] = encoded_node_types
    return examples
|
||||||
|
|
||||||
|
def main():
    """
    Main script to load, process, and save a dataset with node types encoded as integers.

    Reads the padded dataset from `../data/codeparrot-clean-parsed-starencoder-classes-padded`
    (relative to this file), builds the node-type vocabulary from the
    `node_types.json` next to this file, writes that vocabulary as
    `node_type2id.json` into the output directory, and saves the dataset with
    an added "node_types_encoded" column to
    `../data/codeparrot-clean-parsed-starencoder-classes-encoded`.
    """
    # ------------------------------------------------------------------------------
    # 1. Setup paths & load dataset
    # ------------------------------------------------------------------------------
    current_dir = Path(__file__).parent
    data_root = current_dir.parent / "data"
    input_dir = data_root / "codeparrot-clean-parsed-starencoder-classes-padded"
    output_dir = data_root / "codeparrot-clean-parsed-starencoder-classes-encoded"
    node_types_path = current_dir / "node_types.json"

    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Loading dataset from {input_dir}...")
    dataset = load_from_disk(str(input_dir))
    logger.info("Dataset loaded.")

    # Leave one core free and cap at 32 workers, but never drop below 1:
    # on a single-core machine `cpu_count() - 1` is 0, which `map` rejects.
    num_proc = max(1, min(multiprocessing.cpu_count() - 1, 32))
    logger.info(f"Using {num_proc} processes.")

    # ------------------------------------------------------------------------------
    # 2. Load node types from JSON
    # ------------------------------------------------------------------------------
    node_type2id = load_node_types_from_json(node_types_path)
    logger.info(f"Loaded {len(node_type2id)} node types.")

    # Persist the vocabulary next to the encoded data so the IDs can be
    # mapped back to node type names later.
    with open(output_dir / "node_type2id.json", "w", encoding="utf-8") as f:
        json.dump(node_type2id, f)

    # ------------------------------------------------------------------------------
    # 3. Convert node types in the dataset to integer IDs
    # ------------------------------------------------------------------------------
    logger.info("Converting node type strings to integer IDs...")

    # The mapping is passed through `fn_kwargs` instead of a lambda closure,
    # so the mapped callable is a plain module-level function.
    dataset = dataset.map(
        encode_node_types,
        fn_kwargs={"node_type2id": node_type2id},
        batched=True,
        num_proc=num_proc,
        desc="Encoding node types to integer IDs",
    )

    # ------------------------------------------------------------------------------
    # 4. Save the modified dataset to disk
    # ------------------------------------------------------------------------------
    logger.info(f"Saving updated dataset to {output_dir}...")
    dataset.save_to_disk(str(output_dir))
    logger.info("Dataset saved successfully.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -120,7 +120,7 @@ def main():
|
|||||||
# Setup paths
|
# Setup paths
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
config = load_config(current_dir / 'eval_config.json')
|
config = load_config(current_dir / 'eval_config.json')
|
||||||
model_dir = Path(config['model_dir'])
|
model_dir = Path(config['model_dir']) / 'final-model'
|
||||||
data_dir = Path(config['data_dir'])
|
data_dir = Path(config['data_dir'])
|
||||||
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
||||||
results_dir.mkdir(exist_ok=True)
|
results_dir.mkdir(exist_ok=True)
|
||||||
@ -133,7 +133,7 @@ def main():
|
|||||||
model_config.max_position_embeddings = 1024
|
model_config.max_position_embeddings = 1024
|
||||||
|
|
||||||
if config['extra_embeddings']:
|
if config['extra_embeddings']:
|
||||||
model = TreeStarEncoderForPreTraining(config=model_config)
|
model = TreeStarEncoderForPreTraining(config=model_config, log=False)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForMaskedLM.from_config(model_config)
|
model = AutoModelForMaskedLM.from_config(model_config)
|
||||||
|
|
||||||
|
77
code/src/pad_dataset.py
Normal file
77
code/src/pad_dataset.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from datasets import load_from_disk
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import multiprocessing
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def pad_and_save_dataset(input_dir, output_dir, tokenizer_name='bigcode/starencoder', max_length=512):
    """
    Pad every example of a processed dataset to a fixed length and save the result.

    `input_ids` are decoded and re-tokenized with padding/truncation to
    `max_length`; the per-token tree feature columns ('depths',
    'sibling_idxs', 'node_types', 'node_texts') are truncated or right-padded
    to the same length so that all columns stay aligned.

    Args:
        input_dir: Directory of the dataset to load with `load_from_disk`.
        output_dir: Directory the padded dataset is written to.
        tokenizer_name: Name or path passed to `AutoTokenizer.from_pretrained`.
        max_length: Target length of every sequence column.
    """
    # Load the processed dataset
    logger.info(f"Loading processed dataset from {input_dir}...")
    dataset = load_from_disk(input_dir)
    logger.info(f"Loaded dataset with {len(dataset)} examples")

    # Initialize tokenizer; the EOS token doubles as the padding token
    logger.info("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Loaded StarEncoder tokenizer")

    # Leave one core free and cap at 32 workers, but never drop below 1:
    # on a single-core machine `cpu_count() - 1` is 0, which `map` rejects.
    num_proc = max(1, min(multiprocessing.cpu_count() - 1, 32))
    logger.info(f"Using {num_proc} processes")

    def fit_to_length(seq, fill):
        """Return `seq` cut to `max_length` items, right-padded with `fill`."""
        # Truncate first: without this, sequences longer than `max_length`
        # kept their full length while `input_ids` were truncated.
        clipped = list(seq[:max_length])
        return clipped + [fill] * (max_length - len(clipped))

    # Define a function to pad the sequences
    def pad_sequences(batch):
        # NOTE(review): decoding and re-tokenizing may not reproduce the
        # original token boundaries, in which case the tree features would no
        # longer line up with `input_ids` - confirm against the upstream
        # tokenization step.
        texts = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

        # Use the tokenizer's __call__ method for padding
        padded_inputs = tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
            truncation=True
        )

        # Pad the other fields with their default fill values
        return {
            'input_ids': padded_inputs['input_ids'].tolist(),
            'attention_mask': padded_inputs['attention_mask'].tolist(),
            'depths': [fit_to_length(seq, -1) for seq in batch['depths']],
            'sibling_idxs': [fit_to_length(seq, -1) for seq in batch['sibling_idxs']],
            'node_types': [fit_to_length(seq, None) for seq in batch['node_types']],
            'node_texts': [fit_to_length(seq, '') for seq in batch['node_texts']]
        }

    # Apply padding
    logger.info("Applying padding to dataset...")
    padded_dataset = dataset.map(
        pad_sequences,
        batched=True,
        desc="Padding dataset",
        num_proc=num_proc
    )

    # Save the padded dataset
    logger.info(f"Saving padded dataset to {output_dir}...")
    padded_dataset.save_to_disk(output_dir)
    logger.info(f"Saved padded dataset to {output_dir}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
input_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
|
||||||
|
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes-padded'
|
||||||
|
|
||||||
|
pad_and_save_dataset(input_dir, output_dir)
|
@ -1,8 +1,12 @@
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
import logging
|
import logging
|
||||||
import zipfile
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from safetensors.torch import load_file
|
||||||
from datasets import load_from_disk, DatasetDict, load_dataset
|
from datasets import load_from_disk, DatasetDict, load_dataset
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@ -12,11 +16,7 @@ from transformers import (
|
|||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
AutoModelForMaskedLM
|
AutoModelForMaskedLM
|
||||||
)
|
)
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from tree_codebert import TreeCodeBERTForPreTraining
|
|
||||||
from tree_starencoder import TreeStarEncoderForPreTraining
|
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -43,19 +43,22 @@ def main():
|
|||||||
# Setup paths
|
# Setup paths
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
config = load_config(current_dir / 'config.json')
|
config = load_config(current_dir / 'config.json')
|
||||||
|
data_dir = Path(config['data_dir'])
|
||||||
output_dir = Path(config['output_dir'])
|
output_dir = Path(config['output_dir'])
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
set_seed(config['seed'])
|
set_seed(config['seed'])
|
||||||
|
|
||||||
# Initialize W&B and save files
|
# Initialize W&B
|
||||||
wandb.init(project=config['project'], config=config, name=config['run_name'])
|
wandb.init(project='gralinski', config=config, name=config['run_name'])
|
||||||
for file in [__file__, 'config.json', 'tree_starencoder.py']:
|
|
||||||
if config['extra_embeddings'] or file != 'tree_starencoder.py':
|
# Upload the training files to W&B
|
||||||
wandb.save(current_dir / file)
|
wandb.save(__file__)
|
||||||
|
wandb.save(current_dir / 'config.json')
|
||||||
|
if config['extra_embeddings']:
|
||||||
|
wandb.save(current_dir / 'tree_starencoder.py')
|
||||||
|
|
||||||
# Simplified dataset splitting
|
dataset = load_dataset("patrykbart/codeparrot-clean-no-comments-starencoder-small", split='train', num_proc=16, cache_dir=data_dir.parent)
|
||||||
dataset = load_dataset(config['dataset'], split='train')
|
|
||||||
if config['num_samples'] > 0:
|
if config['num_samples'] > 0:
|
||||||
dataset = dataset.select(range(config['num_samples']))
|
dataset = dataset.select(range(config['num_samples']))
|
||||||
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
|
||||||
@ -72,25 +75,36 @@ def main():
|
|||||||
|
|
||||||
# Continue with the rest of processing
|
# Continue with the rest of processing
|
||||||
columns_to_remove = dataset['train'].column_names
|
columns_to_remove = dataset['train'].column_names
|
||||||
columns_to_remove.remove('input_ids')
|
columns_to_remove = [col for col in columns_to_remove if col not in ['input_ids', 'attention_mask']]
|
||||||
columns_to_remove.remove('attention_mask')
|
|
||||||
if config['extra_embeddings']:
|
if config['extra_embeddings']:
|
||||||
columns_to_remove.remove('depths')
|
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
|
||||||
columns_to_remove.remove('sibling_idxs')
|
|
||||||
dataset = dataset.remove_columns(columns_to_remove)
|
dataset = dataset.remove_columns(columns_to_remove)
|
||||||
logger.info(f'Loaded dataset:\n{dataset}')
|
logger.info(f'Loaded dataset:\n{dataset}')
|
||||||
|
|
||||||
# Simplify tokenizer setup
|
# Initialize model from scratch
|
||||||
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
|
tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
|
||||||
tokenizer.add_special_tokens({'mask_token': '<mask>'}) if tokenizer.mask_token is None else None
|
if tokenizer.mask_token is None:
|
||||||
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
|
tokenizer.add_special_tokens({'mask_token': '<mask>'})
|
||||||
|
tokenizer.mask_token = '<mask>'
|
||||||
|
logger.info("Added '<mask>' as the mask token.")
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
logger.info("Set padding token to be the same as the EOS token.")
|
||||||
|
|
||||||
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
|
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
|
||||||
if config['extra_embeddings']:
|
model = TreeStarEncoderForPreTraining(model_config) if config['extra_embeddings'] else AutoModelForMaskedLM.from_config(model_config)
|
||||||
model = TreeStarEncoderForPreTraining(model_config)
|
|
||||||
else:
|
|
||||||
model = AutoModelForMaskedLM.from_config(model_config)
|
|
||||||
logger.info(f'Loaded model: {model.__class__.__name__}')
|
logger.info(f'Loaded model: {model.__class__.__name__}')
|
||||||
|
|
||||||
|
# Load checkpoint if provided
|
||||||
|
if config['checkpoint'] is not None:
|
||||||
|
checkpoint_path = Path(config['checkpoint']) / 'model.safetensors'
|
||||||
|
logger.info(f'Loading checkpoint from {checkpoint_path}')
|
||||||
|
state_dict = load_file(checkpoint_path)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
model.tie_weights()
|
||||||
|
config['warmup_steps'] = 0
|
||||||
|
config['learning_rate'] = 4.8701e-7
|
||||||
|
logger.info('Checkpoint loaded successfully.')
|
||||||
|
|
||||||
# Setup training arguments
|
# Setup training arguments
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
@ -107,12 +121,11 @@ def main():
|
|||||||
save_steps=config['eval_every'],
|
save_steps=config['eval_every'],
|
||||||
eval_strategy='steps',
|
eval_strategy='steps',
|
||||||
save_strategy='steps',
|
save_strategy='steps',
|
||||||
save_total_limit=5,
|
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
report_to='wandb',
|
report_to='wandb',
|
||||||
run_name=config['run_name'],
|
run_name=config['run_name'],
|
||||||
seed=config['seed'],
|
seed=config['seed'],
|
||||||
bf16=config['bf16'],
|
fp16=config['fp16'],
|
||||||
dataloader_num_workers=8,
|
dataloader_num_workers=8,
|
||||||
gradient_checkpointing=True,
|
gradient_checkpointing=True,
|
||||||
metric_for_best_model='eval_loss',
|
metric_for_best_model='eval_loss',
|
||||||
@ -146,14 +159,8 @@ def main():
|
|||||||
logger.info('Saving final model...')
|
logger.info('Saving final model...')
|
||||||
trainer.save_model(output_dir / 'final-model')
|
trainer.save_model(output_dir / 'final-model')
|
||||||
tokenizer.save_pretrained(output_dir / 'final-model')
|
tokenizer.save_pretrained(output_dir / 'final-model')
|
||||||
|
|
||||||
# Zip and upload the final model to W&B
|
|
||||||
with zipfile.ZipFile(output_dir / 'final-model.zip', 'w') as zipf:
|
|
||||||
for file in (output_dir / 'final-model').glob('**/*'):
|
|
||||||
zipf.write(file, arcname=file.name)
|
|
||||||
wandb.save(output_dir / 'final-model.zip')
|
|
||||||
|
|
||||||
logger.info('Training completed!')
|
logger.info('Training completed!')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
@ -13,12 +13,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# self.fusion_layer = nn.Sequential(
|
self.fusion_layer = nn.Sequential(
|
||||||
# nn.Linear(config.hidden_size * 4, config.hidden_size),
|
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
||||||
# nn.GELU(),
|
nn.GELU(),
|
||||||
# nn.Dropout(config.hidden_dropout_prob),
|
nn.Dropout(config.hidden_dropout_prob),
|
||||||
# nn.LayerNorm(config.hidden_size)
|
nn.LayerNorm(config.hidden_size)
|
||||||
# )
|
)
|
||||||
|
|
||||||
# Override config to set max_seq_length
|
# Override config to set max_seq_length
|
||||||
config.max_position_embeddings = max_seq_length
|
config.max_position_embeddings = max_seq_length
|
||||||
@ -31,13 +31,13 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
|||||||
|
|
||||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||||
|
|
||||||
# # Initialize sequential position embeddings with sinusoidal pattern
|
# Initialize sequential position embeddings with sinusoidal pattern
|
||||||
# position = torch.arange(max_seq_length).unsqueeze(1)
|
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||||
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||||
# pe = torch.zeros(max_seq_length, config.hidden_size)
|
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||||
# pe[:, 0::2] = torch.sin(position * div_term)
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
# pe[:, 1::2] = torch.cos(position * div_term)
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
# self.seq_pos_embeddings.weight.data.copy_(pe)
|
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||||
|
|
||||||
# New node type embeddings
|
# New node type embeddings
|
||||||
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
|
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
|
||||||
@ -72,10 +72,11 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
|||||||
# node_type_embeddings = self.node_type_embeddings(node_types)
|
# node_type_embeddings = self.node_type_embeddings(node_types)
|
||||||
|
|
||||||
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
|
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
|
||||||
# combined_embeddings = self.fusion_layer(combined)
|
combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings], dim=-1)
|
||||||
|
combined_embeddings = self.fusion_layer(combined)
|
||||||
|
|
||||||
# Add the embeddings instead of concatenating
|
# Add the embeddings instead of concatenating
|
||||||
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
|
# combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
|
||||||
|
|
||||||
combined_embeddings = self.norm(combined_embeddings)
|
combined_embeddings = self.norm(combined_embeddings)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user