diff --git a/code/src/config.json b/code/src/config.json
index f2bea33..790d5dd 100644
--- a/code/src/config.json
+++ b/code/src/config.json
@@ -1,12 +1,13 @@
 {
-    "extra_embeddings": true,
-    "run_name": "no-sinusoidal",
+    "extra_embeddings": false,
+    "run_name": "original-continued",
     "data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
-    "output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
+    "output_dir": "./outputs/no-comments-starencoder-original-2",
+    "checkpoint": "./outputs/no-comments-starencoder-original/1_epoch_ckpt/",
     "seed": 420,
     "mlm_probability": 0.15,
     "batch_size": 32,
-    "epochs": 3,
+    "epochs": 2,
     "eval_every": 10000,
     "learning_rate": 5e-4,
     "weight_decay": 0.1,
diff --git a/code/src/training.py b/code/src/training.py
index e985b04..fd6bac5 100644
--- a/code/src/training.py
+++ b/code/src/training.py
@@ -1,23 +1,22 @@
 import wandb
+
 import json
+import torch
+import random
 import logging
+import numpy as np
 from pathlib import Path
+from safetensors.torch import load_file
 from datasets import load_from_disk, DatasetDict
 from transformers import (
-    RobertaConfig,
     AutoConfig,
-    RobertaForMaskedLM,
     AutoTokenizer,
     TrainingArguments,
     Trainer,
     DataCollatorForLanguageModeling,
     AutoModelForMaskedLM
 )
-import random
-import numpy as np
-import torch
-from tree_codebert import TreeCodeBERTForPreTraining
 from tree_starencoder import TreeStarEncoderForPreTraining
 
 
 logging.basicConfig(
@@ -55,39 +54,30 @@ def main():
 
     # Upload the training files to W&B
     wandb.save(__file__)
-    wandb.save(Path(__file__).parent / 'config.json')
+    wandb.save(current_dir / 'config.json')
     if config['extra_embeddings']:
         wandb.save(current_dir / 'tree_starencoder.py')
-
-    if 'CodeSearchNet' in config['data_dir']:
-        dataset = DatasetDict({
-            'train': load_from_disk(data_dir / 'train'),
-            'valid': load_from_disk(data_dir / 'valid'),
-            'test': load_from_disk(data_dir / 'test')
-        })
-    else:
-        dataset = load_from_disk(data_dir)
-        if config['num_samples'] > 0:
-            dataset = dataset.select(range(config['num_samples']))
-        train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
-        test_valid = train_testvalid['test'].train_test_split(
-            test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
-            seed=config['seed']
-        )
-        dataset = DatasetDict({
-            'train': train_testvalid['train'],
-            'test': test_valid['test'],
-            'valid': test_valid['train'],
-        })
+
+    dataset = load_from_disk(data_dir)
+    if config['num_samples'] > 0:
+        dataset = dataset.select(range(config['num_samples']))
+    train_testvalid = dataset.train_test_split(test_size=config['test_size'] + config['valid_size'])
+    test_valid = train_testvalid['test'].train_test_split(
+        test_size=config['valid_size'] / (config['test_size'] + config['valid_size']),
+        seed=config['seed']
+    )
+    dataset = DatasetDict({
+        'train': train_testvalid['train'],
+        'test': test_valid['test'],
+        'valid': test_valid['train'],
+    })
 
 
     # Continue with the rest of processing
     columns_to_remove = dataset['train'].column_names
-    columns_to_remove.remove('input_ids')
-    columns_to_remove.remove('attention_mask')
+    columns_to_remove = [col for col in columns_to_remove if col not in ['input_ids', 'attention_mask']]
     if config['extra_embeddings']:
-        columns_to_remove.remove('depths')
-        columns_to_remove.remove('sibling_idxs')
+        columns_to_remove = [col for col in columns_to_remove if col not in ['depths', 'sibling_idxs']]
     dataset = dataset.remove_columns(columns_to_remove)
     logger.info(f'Loaded dataset:\n{dataset}')
 
@@ -102,11 +92,19 @@ def main():
 
     logger.info("Set padding token to be the same as the EOS token.")
     model_config = AutoConfig.from_pretrained('bigcode/starencoder')
-    if config['extra_embeddings']:
-        model = TreeStarEncoderForPreTraining(model_config)
-    else:
-        model = AutoModelForMaskedLM.from_config(model_config)
+    model = TreeStarEncoderForPreTraining(model_config) if config['extra_embeddings'] else AutoModelForMaskedLM.from_config(model_config)
     logger.info(f'Loaded model: {model.__class__.__name__}')
+
+    # Load checkpoint if provided
+    if config['checkpoint'] is not None:
+        checkpoint_path = Path(config['checkpoint']) / 'model.safetensors'
+        logger.info(f'Loading checkpoint from {checkpoint_path}')
+        state_dict = load_file(checkpoint_path)
+        model.load_state_dict(state_dict, strict=False)
+        model.tie_weights()
+        config['warmup_steps'] = 0
+        config['learning_rate'] = 4.8701e-7
+        logger.info('Checkpoint loaded successfully.')
 
     # Setup training arguments
     training_args = TrainingArguments(
diff --git a/code/src/tree_starencoder.py b/code/src/tree_starencoder.py
index 7271b42..943c606 100644
--- a/code/src/tree_starencoder.py
+++ b/code/src/tree_starencoder.py
@@ -13,12 +13,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         super().__init__(config)
         self.config = config
 
-        # self.fusion_layer = nn.Sequential(
-        #     nn.Linear(config.hidden_size * 4, config.hidden_size),
-        #     nn.GELU(),
-        #     nn.Dropout(config.hidden_dropout_prob),
-        #     nn.LayerNorm(config.hidden_size)
-        # )
+        self.fusion_layer = nn.Sequential(
+            nn.Linear(config.hidden_size * 3, config.hidden_size),
+            nn.GELU(),
+            nn.Dropout(config.hidden_dropout_prob),
+            nn.LayerNorm(config.hidden_size)
+        )
 
         # Override config to set max_seq_length
         config.max_position_embeddings = max_seq_length
@@ -31,13 +31,13 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
 
         self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
 
-        # # Initialize sequential position embeddings with sinusoidal pattern
-        # position = torch.arange(max_seq_length).unsqueeze(1)
-        # div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
-        # pe = torch.zeros(max_seq_length, config.hidden_size)
-        # pe[:, 0::2] = torch.sin(position * div_term)
-        # pe[:, 1::2] = torch.cos(position * div_term)
-        # self.seq_pos_embeddings.weight.data.copy_(pe)
+        # Initialize sequential position embeddings with sinusoidal pattern
+        position = torch.arange(max_seq_length).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
+        pe = torch.zeros(max_seq_length, config.hidden_size)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.seq_pos_embeddings.weight.data.copy_(pe)
 
         # New node type embeddings
         self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
@@ -72,10 +72,11 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
 
         # node_type_embeddings = self.node_type_embeddings(node_types)
         # combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
-        # combined_embeddings = self.fusion_layer(combined)
+        combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings], dim=-1)
+        combined_embeddings = self.fusion_layer(combined)
 
         # Add the embeddings instead of concatenating
-        combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
+        # combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
 
         combined_embeddings = self.norm(combined_embeddings)
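The checkpoint-resume path added to training.py can be exercised on its own. The snippet below is a minimal standalone sketch, not the project's training script: it assumes hub access to bigcode/starencoder, a model.safetensors file at the checkpoint path from config.json, and re-uses the warmup/learning-rate overrides the patch applies when a checkpoint is set.

    from pathlib import Path
    from safetensors.torch import load_file
    from transformers import AutoConfig, AutoModelForMaskedLM, TrainingArguments

    # Rebuild the same architecture the checkpoint was trained with
    model_config = AutoConfig.from_pretrained('bigcode/starencoder')
    model = AutoModelForMaskedLM.from_config(model_config)

    # Load weights saved by the previous run; strict=False tolerates keys
    # (e.g. tied or re-created heads) that are not stored in the file
    checkpoint_path = Path('./outputs/no-comments-starencoder-original/1_epoch_ckpt/') / 'model.safetensors'
    state_dict = load_file(checkpoint_path)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    model.tie_weights()  # re-tie input embeddings and LM head after loading

    # Continue training with no warmup and a near-final constant learning rate,
    # mirroring the values the patch writes into config when resuming
    training_args = TrainingArguments(
        output_dir='./outputs/no-comments-starencoder-original-2',
        warmup_steps=0,
        learning_rate=4.8701e-7,
        num_train_epochs=2,
    )

Passing these arguments to a Trainer together with the usual DataCollatorForLanguageModeling then continues masked-LM pre-training from the loaded weights rather than from random initialization.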