working star_encoder

Patryk Bartkowiak 2024-12-13 10:58:33 +00:00
parent edadd3ee02
commit 4006b1fdfd
5 changed files with 135 additions and 17 deletions

View File

@@ -1,8 +1,8 @@
 {
-    "extra_embeddings": false,
-    "run_name": "original",
-    "data_dir": "./data/codeparrot-clean-parsed/",
-    "output_dir": "./outputs/original",
+    "extra_embeddings": true,
+    "run_name": "tree-3",
+    "data_dir": "./data/codeparrot-clean-parsed-starencoder/",
+    "output_dir": "./outputs/tree-starencoder-3",
     "seed": 420,
     "mlm_probability": 0.15,
     "batch_size": 32,

View File

@@ -122,13 +122,14 @@ def main():
     current_dir = Path(__file__).parent
     input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
-    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed'
+    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder'
     output_dir.mkdir(parents=True, exist_ok=True)

     # Initialize tokenizer and model from scratch
     logger.info("Initializing tokenizer and model...")
-    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
-    logger.info("Loaded CodeBERT tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
+    tokenizer.pad_token = tokenizer.eos_token
+    logger.info("Loaded StarEncoder tokenizer")

     num_proc = min(multiprocessing.cpu_count() - 1, 32)
     logger.info(f"Using {num_proc} processes for dataset processing")
@@ -142,7 +143,7 @@ def main():
             examples[code_column],
             padding='max_length',
             truncation=True,
-            max_length=512,
+            max_length=1024,
             return_special_tokens_mask=True
         )
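For reference, a sketch of how a tokenize function like the one above is typically applied with datasets.map. The 'content' column stands in for the script's code_column, and the source dataset name is an assumption; only the tokenizer settings and the num_proc heuristic mirror this diff:

import multiprocessing
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
tokenizer.pad_token = tokenizer.eos_token  # StarEncoder has no pad token by default, as handled in the diff

def tokenize_fn(examples):
    # Mirrors the tokenizer call in the hunk above (max_length raised to 1024)
    return tokenizer(
        examples['content'],  # hypothetical code column name
        padding='max_length',
        truncation=True,
        max_length=1024,
        return_special_tokens_mask=True,
    )

ds = load_dataset('codeparrot/codeparrot-clean', split='train[:1%]')  # assumed source dataset
num_proc = min(multiprocessing.cpu_count() - 1, 32)  # same heuristic as the script above
ds = ds.map(tokenize_fn, batched=True, num_proc=num_proc, remove_columns=ds.column_names)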

View File

@@ -5,14 +5,17 @@ from pathlib import Path
 from datasets import load_from_disk, DatasetDict
 from transformers import (
-    RobertaConfig,
+    AutoConfig,
     RobertaForMaskedLM,
     AutoTokenizer,
     TrainingArguments,
     Trainer,
-    DataCollatorForLanguageModeling
+    DataCollatorForLanguageModeling,
+    AutoModelForMaskedLM
 )
-from tree_model import TreeCodeBERTForPreTraining
+from tree_codebert import TreeCodeBERTForPreTraining
+from tree_starencoder import TreeStarEncoderForPreTraining

 logging.basicConfig(
     level=logging.INFO,
@@ -33,13 +36,13 @@ def main():
     output_dir = Path(config['output_dir'])

     # Initialize W&B
-    wandb.init(project='codeparrot', config=config, name=config['run_name'])
+    wandb.init(project='codeparrot-starencoder', config=config, name=config['run_name'])

     # Upload the training files to W&B
     wandb.save(__file__)
     wandb.save(Path(__file__).parent / 'config.json')
     if config['extra_embeddings']:
-        wandb.save(current_dir / 'tree_model.py')
+        wandb.save(current_dir / 'tree_starencoder.py')

     if 'CodeSearchNet' in config['data_dir']:
         dataset = DatasetDict({
@@ -48,7 +51,6 @@ def main():
             'test': load_from_disk(data_dir / 'test')
         })
     else:
         dataset = load_from_disk(data_dir)
     if config['num_samples'] > 0:
         dataset = dataset.select(range(config['num_samples']))
@@ -72,12 +74,20 @@ def main():
     logger.info(f'Loaded dataset:\n{dataset}')

     # Initialize model from scratch
-    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
-    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
+    tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
+    if tokenizer.mask_token is None:
+        tokenizer.add_special_tokens({'mask_token': '<mask>'})
+        tokenizer.mask_token = '<mask>'
+        logger.info("Added '<mask>' as the mask token.")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Set padding token to be the same as the EOS token.")
+    model_config = AutoConfig.from_pretrained('bigcode/starencoder')
     if config['extra_embeddings']:
-        model = TreeCodeBERTForPreTraining(model_config)
+        model = TreeStarEncoderForPreTraining(model_config)
     else:
-        model = RobertaForMaskedLM(model_config)
+        model = AutoModelForMaskedLM.from_config(model_config)
     logger.info(f'Loaded model: {model.__class__.__name__}')

     # Setup training arguments
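The "# Setup training arguments" block itself is outside this hunk. A sketch of what it plausibly wires together, given the imported TrainingArguments, Trainer, and DataCollatorForLanguageModeling and the config fields seen above; it assumes the variables tokenizer, model, dataset, config, and output_dir defined earlier in the script, the split names are guesses, and all argument values beyond those in config.json are assumptions:

from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=config['mlm_probability'],  # 0.15 in config.json
)

training_args = TrainingArguments(
    output_dir=str(output_dir),
    run_name=config['run_name'],
    seed=config['seed'],
    per_device_train_batch_size=config['batch_size'],
    per_device_eval_batch_size=config['batch_size'],
    report_to='wandb',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],  # split names assumed
    eval_dataset=dataset['test'],
    data_collator=data_collator,
)
trainer.train()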

View File

@@ -0,0 +1,107 @@
+import wandb
+import math
+import torch
+import torch.nn as nn
+from typing import Dict, Optional
+from transformers import AutoConfig, BertForMaskedLM
+
+from tree_codebert import TreePositionalEmbedding
+
+class TreeStarEncoderForPreTraining(BertForMaskedLM):
+    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
+        super().__init__(config)
+        self.config = config
+
+        self.fusion_layer = nn.Sequential(
+            nn.Linear(config.hidden_size * 3, config.hidden_size),
+            nn.GELU(),
+            nn.Dropout(config.hidden_dropout_prob),
+            nn.LayerNorm(config.hidden_size)
+        )
+
+        # Override config to set max_seq_length
+        config.max_position_embeddings = max_seq_length
+
+        self.tree_pos_embeddings = TreePositionalEmbedding(
+            d_model=config.hidden_size,
+            max_depth=max_depth,
+            dropout=config.hidden_dropout_prob
+        )
+
+        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
+
+        # Initialize sequential position embeddings with sinusoidal pattern
+        position = torch.arange(max_seq_length).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
+        pe = torch.zeros(max_seq_length, config.hidden_size)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.seq_pos_embeddings.weight.data.copy_(pe)
+
+        self.alpha = nn.Parameter(torch.tensor(0.25))
+        self.beta = nn.Parameter(torch.tensor(0.25))
+        self.gamma = nn.Parameter(torch.tensor(0.5))
+
+        self.norm = nn.LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        depths: Optional[torch.Tensor] = None,
+        sibling_idxs: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        device = input_ids.device
+
+        # Get token embeddings
+        token_embeddings = self.bert.embeddings.word_embeddings(input_ids)
+
+        # Get sequential position embeddings
+        seq_positions = torch.arange(input_ids.size(1), device=device)
+        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
+
+        # Get tree positional embeddings
+        if depths is not None and sibling_idxs is not None:
+            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
+        else:
+            tree_embeddings = torch.zeros_like(token_embeddings)
+
+        combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings], dim=-1)
+        combined_embeddings = self.fusion_layer(combined)
+        combined_embeddings = self.norm(combined_embeddings)
+
+        outputs = self.bert(
+            inputs_embeds=combined_embeddings,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1)
+            )
+
+        wandb.log({
+            "embeddings/token": self.alpha.item(),
+            "embeddings/tree": self.beta.item(),
+            "embeddings/seq": self.gamma.item()
+        })
+
+        return {
+            "loss": masked_lm_loss,
+            "logits": prediction_scores,
+            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
+            "attentions": outputs.attentions if output_attentions else None,
+        }
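A quick smoke test of the new TreeStarEncoderForPreTraining, assuming tree_starencoder.py and tree_codebert.py are importable. The zero-filled depths and sibling_idxs tensors are placeholders for the AST features produced by the parsing script (their dtype may need adjusting to match TreePositionalEmbedding), and wandb is disabled because forward() logs unconditionally:

import torch
import wandb
from transformers import AutoConfig, AutoTokenizer
from tree_starencoder import TreeStarEncoderForPreTraining

wandb.init(mode='disabled')  # forward() calls wandb.log on every step

config = AutoConfig.from_pretrained('bigcode/starencoder')
model = TreeStarEncoderForPreTraining(config)

tokenizer = AutoTokenizer.from_pretrained('bigcode/starencoder')
tokenizer.pad_token = tokenizer.eos_token

enc = tokenizer(['def add(a, b):\n    return a + b'], return_tensors='pt',
                padding='max_length', truncation=True, max_length=512)
batch, seq_len = enc['input_ids'].shape

# Placeholder AST features; in the real pipeline these come from the parsed dataset
depths = torch.zeros(batch, seq_len, dtype=torch.long)
sibling_idxs = torch.zeros(batch, seq_len, dtype=torch.long)

out = model(
    input_ids=enc['input_ids'],
    attention_mask=enc['attention_mask'],
    depths=depths,
    sibling_idxs=sibling_idxs,
    labels=enc['input_ids'],
)
print(out['loss'], out['logits'].shape)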