on this commit i continued to train original starencoder model

This commit is contained in:
Patryk Bartkowiak 2025-01-04 21:02:30 +00:00
parent 3b7bc5d6d2
commit f0679ab861
3 changed files with 55 additions and 55 deletions

View File

@ -1,12 +1,13 @@
{
"extra_embeddings": true,
"run_name": "no-sinusoidal",
"extra_embeddings": false,
"run_name": "original-continued",
"data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
"output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
"output_dir": "./outputs/no-comments-starencoder-original-2",
"checkpoint": "./outputs/no-comments-starencoder-original/1_epoch_ckpt/",
"seed": 420,
"mlm_probability": 0.15,
"batch_size": 32,
"epochs": 3,
"epochs": 2,
"eval_every": 10000,
"learning_rate": 5e-4,
"weight_decay": 0.1,

View File

@ -1,23 +1,22 @@
import wandb
import json
import torch
import random
import logging
import numpy as np
from pathlib import Path
from safetensors.torch import load_file
from datasets import load_from_disk, DatasetDict
from transformers import (
RobertaConfig,
AutoConfig,
RobertaForMaskedLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
AutoModelForMaskedLM
)
import random
import numpy as np
import torch
from tree_codebert import TreeCodeBERTForPreTraining
from tree_starencoder import TreeStarEncoderForPreTraining
logging.basicConfig(
@ -55,39 +54,30 @@ def main():
# Upload the training files to W&B
wandb.save(__file__)
wandb.save(Path(__file__).parent / 'config.json')
wandb.save(current_dir / 'config.json')
if config['extra_embeddings']:
wandb.save(current_dir / 'tree_starencoder.py')
if 'CodeSearchNet' in config['data_dir']:
dataset = DatasetDict({
'train': load_from_disk(data_dir / 'train'),
'valid': load_from_disk(data_dir / 'valid'),
'test': load_from_disk(data_dir / 'test')
})
else:
dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
seed=config['seed']
)
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
dataset = load_from_disk(data_dir)
if config['num_samples'] > 0:
dataset = dataset.select(range(config['num_samples']))
train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
test_valid = train_testvalid['test'].train_test_split(
test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
seed=config['seed']
)
dataset = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train'],
})
# Continue with the rest of processing
columns_to_remove = dataset['train'].column_names
columns_to_remove.remove('input_ids')
columns_to_remove.remove('attention_mask')
columns_to_remove = [col for col in columns_to_remove if col not in ['input_ids', 'attention_mask']]
if config['extra_embeddings']:
columns_to_remove.remove('depths')
columns_to_remove.remove('sibling_idxs')
columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
dataset = dataset.remove_columns(columns_to_remove)
logger.info(f'Loaded dataset:\n{dataset}')
@ -102,11 +92,19 @@ def main():
logger.info("Set padding token to be the same as the EOS token.")
model_config = AutoConfig.from_pretrained('bigcode/starencoder')
if config['extra_embeddings']:
model = TreeStarEncoderForPreTraining(model_config)
else:
model = AutoModelForMaskedLM.from_config(model_config)
model = TreeStarEncoderForPreTraining(model_config) if config['extra_embeddings'] else AutoModelForMaskedLM.from_config(model_config)
logger.info(f'Loaded model: {model.__class__.__name__}')
# Load checkpoint if provided
if config['checkpoint'] is not None:
checkpoint_path = Path(config['checkpoint']) / 'model.safetensors'
logger.info(f'Loading checkpoint from {checkpoint_path}')
state_dict = load_file(checkpoint_path)
model.load_state_dict(state_dict, strict=False)
model.tie_weights()
config['warmup_steps'] = 0
config['learning_rate'] = 4.8701e-7
logger.info('Checkpoint loaded successfully.')
# Setup training arguments
training_args = TrainingArguments(

View File

@ -13,12 +13,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
super().__init__(config)
self.config = config
# self.fusion_layer = nn.Sequential(
# nn.Linear(config.hidden_size * 4, config.hidden_size),
# nn.GELU(),
# nn.Dropout(config.hidden_dropout_prob),
# nn.LayerNorm(config.hidden_size)
# )
self.fusion_layer = nn.Sequential(
nn.Linear(config.hidden_size * 3, config.hidden_size),
nn.GELU(),
nn.Dropout(config.hidden_dropout_prob),
nn.LayerNorm(config.hidden_size)
)
# Override config to set max_seq_length
config.max_position_embeddings = max_seq_length
@ -31,13 +31,13 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
# # Initialize sequential position embeddings with sinusoidal pattern
# position = torch.arange(max_seq_length).unsqueeze(1)
# div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
# pe = torch.zeros(max_seq_length, config.hidden_size)
# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# self.seq_pos_embeddings.weight.data.copy_(pe)
# Initialize sequential position embeddings with sinusoidal pattern
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
pe = torch.zeros(max_seq_length, config.hidden_size)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.seq_pos_embeddings.weight.data.copy_(pe)
# New node type embeddings
self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
@ -72,10 +72,11 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
# node_type_embeddings = self.node_type_embeddings(node_types)
# combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
# combined_embeddings = self.fusion_layer(combined)
combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings], dim=-1)
combined_embeddings = self.fusion_layer(combined)
# Add the embeddings instead of concatenating
combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
# combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
combined_embeddings = self.norm(combined_embeddings)