# msc-patryk-bartkowiak/code/train_codebert_mlm.py

import wandb
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import os
import random
import datetime
import numpy as np
from datasets import load_dataset, load_from_disk, disable_caching, DatasetDict
from tree_sitter import Language, Parser
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling
from tqdm import tqdm
from utils import remove_docstrings_and_comments_from_code
# Disable caching for datasets
disable_caching()
############################### CONFIG ###############################
dataset_name = 'the-stack-tokenized' # 'the-stack', 'code-search-net', or 'the-stack-tokenized'
remove_comments = False
######################################################################
# Initialize Weights & Biases and output directory
curr_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
wandb.init(project='codebert-training', name=curr_time)
output_dir = f'/home/s452638/magisterka/output/{curr_time}/'
# Save this file to Weights & Biases
wandb.save('train_codebert_mlm.py')
# Create the output directory if it does not exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
# Set the seed for reproducibility
SEED = 42
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(SEED)
# Set the device for PyTorch (use GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
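# With the default device set to CUDA, tensors created without an explicit device are allocated on the GPU,
# which is why the DataLoader generators further below are built with torch.Generator(device=device).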
print('*' * 10, 'Device', '*' * 10)
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'Device name: {torch.cuda.get_device_name()}')
# Load the dataset
if dataset_name == 'the-stack-tokenized':
    train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
    valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
    test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
else:
    if dataset_name == 'the-stack':
        train_data = load_dataset("/work/s452638/datasets/the-stack-python", split="train")
        train_data = train_data.rename_column('content', 'code')
    elif dataset_name == 'code-search-net':
        # load_dataset with a plain data_files path always exposes a single 'train' split
        train_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/train.jsonl')['train']
        valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['train']
        test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['train']
dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data})
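# The pre-tokenized variant is used as-is (tokenization is skipped further down), while the raw-text
# variants expose a 'code' column that is run through the CodeBERT tokenizer below.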
print('\n\n', '*' * 10, 'Dataset', '*' * 10)
print(dataset)
if remove_comments:
    # Build the language library if not already built
    Language.build_library('/home/s452638/magisterka/build/my-languages.so', ['/home/s452638/magisterka/vendor/tree-sitter-python'])
    # Load the language
    PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')
    # Initialize the parser
    parser = Parser()
    parser.set_language(PYTHON_LANGUAGE)
    # Remove docstrings and comments from the code
    dataset = dataset.map(lambda x: {'code': remove_docstrings_and_comments_from_code(x['code'], parser)}, batched=False, desc='Removing docstrings and comments')
# Load the tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
print('\n\n', '*' * 10, 'Tokenizer', '*' * 10)
print(tokenizer)
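# Only the tokenizer (the roberta-base BPE vocabulary of 50,265 tokens used by CodeBERT) is reused;
# the model weights below are initialized from scratch.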
if dataset_name != 'the-stack-tokenized':
    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')
    print('\n\n', '*' * 10, 'Tokenized dataset', '*' * 10)
    print(tokenized_datasets)
else:
    tokenized_datasets = dataset
# Set data collator for MLM
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
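# The collator re-samples the mask for every batch: 15% of tokens are selected for prediction
# (80% replaced by <mask>, 10% by a random token, 10% left unchanged) and all other label
# positions are set to -100 so the loss and accuracy only cover the selected tokens.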
# Create DataLoaders
batch_size = 64
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
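# shuffle=False keeps the stored example order on every pass; only the random masking applied by the collator changes.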
# Initialize a model with random weights based on the configuration for RoBERTa (CodeBERT is based on RoBERTa)
config = RobertaConfig.from_pretrained('roberta-base')
model = RobertaForMaskedLM(config)
model = torch.compile(model)
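# torch.compile (PyTorch 2.x) optimizes the model lazily, so the very first forward/backward pass is slower while kernels are compiled.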
wandb.watch(model)
print('\n\n', '*' * 10, 'Model', '*' * 10)
print(config)
# Log the model configuration to wandb
wandb.config.update({'model_config': config.to_dict()})
# Set the optimizer and scaler
learning_rate = 5e-4
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scaler = torch.amp.GradScaler()
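# GradScaler scales the loss before backward() so small gradients stay representable in reduced precision;
# the gradients are unscaled again before clipping and the optimizer step in the training loop below.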
# Training settings
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
eval_every = 10_000
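# With batch_size = 64, validating every 10,000 steps means roughly 640,000 training examples between validation runs.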
# Log training settings to wandb
wandb.config.update({
    'training_settings': {
        'num_epochs': num_epochs,
        'num_training_steps': num_training_steps,
        'eval_every': eval_every,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
    }
})
# Initialize variables to track validation loss, validation accuracy, and the best validation loss
valid_acc = 0.0
valid_loss = 0.0
best_valid_loss = float('inf')
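# valid_loss and valid_acc start at 0.0 only so the progress bar has values to display before the first validation pass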
# Train the model
print('\n\n', '*' * 10, 'Training', '*' * 10)
model.train()
with tqdm(total=num_training_steps, desc='Training') as pbar:
    for epoch_idx in range(num_epochs):
        for train_idx, train_batch in enumerate(train_dataloader):
            # Forward pass with mixed precision
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                outputs = model(**train_batch)
                train_loss = outputs.loss
            scaler.scale(train_loss).backward()
            # Unscale the gradients so clipping operates on their true values
            scaler.unscale_(optimizer)
            # Gradient clipping to prevent exploding gradients
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pbar.update(1)
            pbar.set_postfix({'norm': norm.item(), 'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
            # Log metrics to Weights & Biases
            wandb.log({
                'step': train_idx + len(train_dataloader) * epoch_idx,
                'train_loss': train_loss.item(),
                'gradient_norm': norm.item(),
                'learning_rate': optimizer.param_groups[0]['lr'],
            })
            # Evaluate the model
            if train_idx != 0 and train_idx % eval_every == 0:
                model.eval()
                valid_loss = 0.0
                valid_correct = 0
                valid_masked = 0
                with tqdm(total=len(valid_dataloader), desc='Validation') as pbar_valid:
                    with torch.no_grad():
                        for valid_idx, valid_batch in enumerate(valid_dataloader):
                            # Forward pass with mixed precision for validation
                            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                                outputs = model(**valid_batch)
                            # Accumulate validation loss and masked-token accuracy (labels are -100 outside masked positions)
                            valid_loss += outputs.loss.item()
                            mask = valid_batch['labels'] != -100
                            valid_correct += outputs.logits.argmax(dim=-1).eq(valid_batch['labels'])[mask].sum().item()
                            valid_masked += mask.sum().item()
                            pbar_valid.update(1)
                # Compute average validation loss and masked-token accuracy
                valid_loss /= len(valid_dataloader)
                valid_acc = valid_correct / valid_masked
                model.train()
                # Log validation metrics to Weights & Biases
                wandb.log({
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'step': train_idx + len(train_dataloader) * epoch_idx,
                })
                # Update best model if current validation loss is lower
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), output_dir + 'best_model.pt')
print('\n\n', '*' * 10, 'Training results', '*' * 10)
print(f'Best validation loss: {best_valid_loss}')
# Load the best model and evaluate on the test set
print('\n\n', '*' * 10, 'Testing', '*' * 10)
model.load_state_dict(torch.load(output_dir + 'best_model.pt', weights_only=True, map_location=device))
model.eval()
test_loss = 0.0
test_correct = 0
test_masked = 0
with tqdm(total=len(test_dataloader), desc='Testing') as pbar_test:
    with torch.no_grad():
        for test_idx, test_batch in enumerate(test_dataloader):
            # Forward pass with mixed precision for testing
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                outputs = model(**test_batch)
            # Accumulate test loss and masked-token accuracy (labels are -100 outside masked positions)
            test_loss += outputs.loss.item()
            mask = test_batch['labels'] != -100
            test_correct += outputs.logits.argmax(dim=-1).eq(test_batch['labels'])[mask].sum().item()
            test_masked += mask.sum().item()
            pbar_test.update(1)
# Compute average test loss and masked-token accuracy
test_loss /= len(test_dataloader)
test_acc = test_correct / test_masked
# Log test metrics to Weights & Biases
wandb.log({
    'test_loss': test_loss,
    'test_acc': test_acc,
})
print('\n\n', '*' * 10, 'Test results', '*' * 10)
print(f'Test loss: {test_loss}')
print(f'Test accuracy: {test_acc}')