working star_encoder

commit 4006b1fdfd
parent edadd3ee02
@@ -1,8 +1,8 @@
 {
-    "extra_embeddings": false,
-    "run_name": "original",
-    "data_dir": "./data/codeparrot-clean-parsed/",
-    "output_dir": "./outputs/original",
+    "extra_embeddings": true,
+    "run_name": "tree-3",
+    "data_dir": "./data/codeparrot-clean-parsed-starencoder/",
+    "output_dir": "./outputs/tree-starencoder-3",
     "seed": 420,
     "mlm_probability": 0.15,
     "batch_size": 32,
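For reference, the JSON above is the run configuration that the training script loads at startup; a minimal sketch of consuming such a config.json, assuming a plain json.load loader (the function name, default path, and key check are illustrative, not part of this commit):

import json
from pathlib import Path

def load_config(path: str = 'config.json') -> dict:
    """Hypothetical loader; the real script may read the path differently."""
    with open(path) as f:
        config = json.load(f)
    # Keys that the hunks in this commit rely on.
    for key in ('extra_embeddings', 'run_name', 'data_dir', 'output_dir',
                'seed', 'mlm_probability', 'batch_size'):
        assert key in config, f'missing config key: {key}'
    return config

if __name__ == '__main__':
    config = load_config()
    print(config['run_name'], Path(config['output_dir']))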
@@ -122,13 +122,14 @@ def main():

     current_dir = Path(__file__).parent
     input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
-    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
+    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder'
     output_dir.mkdir(parents=True, exist_ok=True)

     # Initialize tokenizer and model from scratch
     logger.info("Initializing tokenizer and model...")
-    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
-    logger.info("Loaded CodeBERT tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
+    tokenizer.pad_token = tokenizer.eos_token
+    logger.info("Loaded StarEncoder tokenizer")

     num_proc = min(multiprocessing.cpu_count() - 1, 32)
     logger.info(f"Using {num_proc} processes for dataset processing")
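The switch to bigcode/starencoder above is paired with `tokenizer.pad_token = tokenizer.eos_token` because the StarEncoder tokenizer may ship without a dedicated padding token; a defensive version of that setup, mirroring the checks a later hunk adds to the training script (only the token literals come from this diff, the rest is a sketch):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')

# Mirror the special-token handling this commit introduces.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token        # reuse EOS for padding
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({'mask_token': '<mask>'})

print(tokenizer.pad_token, tokenizer.mask_token, len(tokenizer))

If add_special_tokens actually grows the vocabulary, the model's input embeddings would also need resize_token_embeddings, a step this diff does not show; if the tokenizer already defines '<mask>', that branch is a no-op.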
@@ -142,7 +143,7 @@ def main():
             examples[code_column],
             padding='max_length',
             truncation=True,
-            max_length=512,
+            max_length=1024,
             return_special_tokens_mask=True
         )

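The tokenizer call above presumably sits inside a function handed to datasets.map; a sketch of that pattern with the new 1024-token limit (the function name and the 'content' column are assumptions, the tokenizer arguments mirror the diff):

from datasets import load_from_disk
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):          # hypothetical name
    return tokenizer(
        examples['content'],              # codeparrot's code column (assumption)
        padding='max_length',
        truncation=True,
        max_length=1024,                  # doubled from 512 in this commit
        return_special_tokens_mask=True,
    )

dataset = load_from_disk('./data/codeparrot-clean')
tokenized = dataset.map(tokenize_function, batched=True, num_proc=8)

Doubling max_length roughly quadruples the attention cost per example, so batch size or gradient accumulation usually has to be revisited alongside it.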
@@ -5,14 +5,17 @@ from pathlib import Path
 from datasets import load_from_disk, DatasetDict
 from transformers import (
-    RobertaConfig,
+    AutoConfig,
     RobertaForMaskedLM,
     AutoTokenizer,
     TrainingArguments,
     Trainer,
-    DataCollatorForLanguageModeling
+    DataCollatorForLanguageModeling,
+    AutoModelForMaskedLM
 )

-from tree_model import TreeCodeBERTForPreTraining
+from tree_codebert import TreeCodeBERTForPreTraining
+from tree_starencoder import TreeStarEncoderForPreTraining

 logging.basicConfig(
     level=logging.INFO,
@@ -33,13 +36,13 @@ def main():
     output_dir = Path(config['output_dir'])

     # Initialize W&B
-    wandb.init(project='codeparrot', config=config, name=config['run_name'])
+    wandb.init(project='codeparrot-starencoder', config=config, name=config['run_name'])

     # Upload the training files to W&B
     wandb.save(__file__)
     wandb.save(Path(__file__).parent / 'config.json')
     if config['extra_embeddings']:
-        wandb.save(current_dir / 'tree_model.py')
+        wandb.save(current_dir / 'tree_starencoder.py')

     if 'CodeSearchNet' in config['data_dir']:
         dataset = DatasetDict({
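Isolated from the rest of the script, the W&B bookkeeping in this hunk amounts to the following sketch (project, run name, and file names mirror the diff; the str() conversions are an assumption for wandb versions whose wandb.save expects a glob string rather than a Path):

import json
from pathlib import Path
import wandb

current_dir = Path(__file__).parent
config = json.loads((current_dir / 'config.json').read_text())

wandb.init(project='codeparrot-starencoder', config=config, name=config['run_name'])

wandb.save(str(Path(__file__)))                        # the training script itself
wandb.save(str(current_dir / 'config.json'))
if config['extra_embeddings']:
    wandb.save(str(current_dir / 'tree_starencoder.py'))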
@@ -48,7 +51,6 @@ def main():
             'test': load_from_disk(data_dir / 'test')
         })
     else:
-
         dataset = load_from_disk(data_dir)
         if config['num_samples'] > 0:
             dataset = dataset.select(range(config['num_samples']))
@@ -72,12 +74,20 @@ def main():
     logger.info(f'Loaded dataset:\n{dataset}')

     # Initialize model from scratch
-    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
-    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
+    tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
+    if tokenizer.mask_token is None:
+        tokenizer.add_special_tokens({'mask_token': '<mask>'})
+        tokenizer.mask_token = '<mask>'
+        logger.info("Added '<mask>' as the mask token.")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Set padding token to be the same as the EOS token.")
+
+    model_config = AutoConfig.from_pretrained('bigcode/starencoder')
     if config['extra_embeddings']:
-        model = TreeCodeBERTForPreTraining(model_config)
+        model = TreeStarEncoderForPreTraining(model_config)
     else:
-        model = RobertaForMaskedLM(model_config)
+        model = AutoModelForMaskedLM.from_config(model_config)
     logger.info(f'Loaded model: {model.__class__.__name__}')

     # Setup training arguments
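Both branches above start from randomly initialized weights: AutoModelForMaskedLM.from_config (and the TreeStarEncoderForPreTraining constructor) reuse only the StarEncoder architecture hyperparameters, not the pretrained checkpoint. A short sketch of that distinction, assuming the bigcode/starencoder config is accessible:

from transformers import AutoConfig, AutoModelForMaskedLM

model_config = AutoConfig.from_pretrained('bigcode/starencoder')

# Fresh weights with StarEncoder's shape: what this training run starts from.
scratch_model = AutoModelForMaskedLM.from_config(model_config)

# Pretrained weights would instead come from from_pretrained (not what the diff does):
# pretrained_model = AutoModelForMaskedLM.from_pretrained('bigcode/starencoder')

print(f"{sum(p.numel() for p in scratch_model.parameters()) / 1e6:.1f}M parameters")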
code/src/tree_starencoder.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+import wandb
+
+import math
+import torch
+import torch.nn as nn
+from typing import Dict, Optional
+from transformers import AutoConfig, BertForMaskedLM
+
+from tree_codebert import TreePositionalEmbedding
+
+class TreeStarEncoderForPreTraining(BertForMaskedLM):
+    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
+        super().__init__(config)
+        self.config = config
+
+        self.fusion_layer = nn.Sequential(
+            nn.Linear(config.hidden_size * 3, config.hidden_size),
+            nn.GELU(),
+            nn.Dropout(config.hidden_dropout_prob),
+            nn.LayerNorm(config.hidden_size)
+        )
+
+        # Override config to set max_seq_length
+        config.max_position_embeddings = max_seq_length
+
+        self.tree_pos_embeddings = TreePositionalEmbedding(
+            d_model=config.hidden_size,
+            max_depth=max_depth,
+            dropout=config.hidden_dropout_prob
+        )
+
+        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
+
+        # Initialize sequential position embeddings with sinusoidal pattern
+        position = torch.arange(max_seq_length).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
+        pe = torch.zeros(max_seq_length, config.hidden_size)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.seq_pos_embeddings.weight.data.copy_(pe)
+
+        self.alpha = nn.Parameter(torch.tensor(0.25))
+        self.beta = nn.Parameter(torch.tensor(0.25))
+        self.gamma = nn.Parameter(torch.tensor(0.5))
+
+        self.norm = nn.LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        depths: Optional[torch.Tensor] = None,
+        sibling_idxs: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        device = input_ids.device
+
+        # Get token embeddings
+        token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
+
+        # Get sequential position embeddings
+        seq_positions = torch.arange(input_ids.size(1), device=device)
+        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
+
+        # Get tree positional embeddings
+        if depths is not None and sibling_idxs is not None:
+            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
+        else:
+            tree_embeddings = torch.zeros_like(token_embeddings)
+
+        combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings], dim=-1)
+        combined_embeddings = self.fusion_layer(combined)
+
+        combined_embeddings = self.norm(combined_embeddings)
+
+        outputs = self.bert(
+            inputs_embeds=combined_embeddings,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1)
+            )
+
+        wandb.log({
+            "embeddings/token": self.alpha.item(),
+            "embeddings/tree": self.beta.item(),
+            "embeddings/seq": self.gamma.item()
+        })
+
+        return {
+            "loss": masked_lm_loss,
+            "logits": prediction_scores,
+            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
+            "attentions": outputs.attentions if output_attentions else None,
+        }
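A quick way to smoke-test the new module on CPU; the miniature BertConfig, the dummy depth/sibling tensors, and wandb.init(mode='disabled') are all assumptions chosen so the snippet runs without the StarEncoder checkpoint or a live W&B run:

import torch
import wandb
from transformers import BertConfig

from tree_starencoder import TreeStarEncoderForPreTraining

wandb.init(mode='disabled')   # forward() calls wandb.log unconditionally

# Tiny BERT-style config; the real run uses AutoConfig.from_pretrained('bigcode/starencoder').
config = BertConfig(vocab_size=1024, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=128)
model = TreeStarEncoderForPreTraining(config)

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
depths = torch.randint(0, 8, (batch_size, seq_len))        # per-token AST depth (dtype is a guess)
sibling_idxs = torch.randint(0, 8, (batch_size, seq_len))  # per-token sibling index (dtype is a guess)
labels = input_ids.clone()

out = model(input_ids, attention_mask=attention_mask, labels=labels,
            depths=depths, sibling_idxs=sibling_idxs)
print(out['loss'].item(), out['logits'].shape)   # scalar loss, (2, 16, vocab_size)

One thing worth noting: the class defaults to max_seq_length=512 for its learned sequential position table, while the parsing script now tokenizes to 1024 tokens, so the constructor argument presumably has to be raised to match when training on the re-parsed data.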