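"""Pre-train a RoBERTa masked language model (CodeBERT-style) from scratch on Python code.

Loads one of several source-code datasets, optionally strips docstrings and comments with
tree-sitter, tokenizes with the CodeBERT tokenizer, and runs MLM training, logging metrics
to Weights & Biases and saving the best checkpoint by validation loss.
"""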
import wandb

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import os
import random
import datetime
import numpy as np
from datasets import load_dataset, load_from_disk, disable_caching, DatasetDict
from tree_sitter import Language, Parser
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling
from tqdm import tqdm

from utils import remove_docstrings_and_comments_from_code

# Disable caching for datasets
disable_caching()

############################### CONFIG ###############################
dataset_name = 'the-stack-tokenized'  # 'the-stack' or 'code-search-net' or 'the-stack-tokenized'
remove_comments = False
######################################################################

# Initialize Weights & Biases and output directory
curr_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
wandb.init(project='codebert-training', name=curr_time)
output_dir = f'/home/s452638/magisterka/output/{curr_time}/'

# Save this file to Weights & Biases
wandb.save('train_codebert_mlm.py')

# Create the output directory if it does not exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Set the seed for reproducibility
SEED = 42
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

# Set the device for PyTorch (use GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
print('*' * 10, 'Device', '*' * 10)
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'Device name: {torch.cuda.get_device_name()}')
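
# NOTE: torch.set_default_device(device) makes tensor factory calls allocate on `device` by default;
# the DataLoader generators below are likewise created on that device.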

# Load the dataset
if dataset_name == 'the-stack-tokenized':
    train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
    valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
    test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
else:
    if dataset_name == 'the-stack':
        train_data = load_dataset("/work/s452638/datasets/the-stack-python", split="train")
        train_data = train_data.rename_column('content', 'code')
    elif dataset_name == 'code-search-net':
        train_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/train.jsonl')['train']

    # A single JSON file always loads under the 'train' split, whatever the file is called
    valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['train']
    test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['train']

dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data})
print('\n\n', '*' * 10, 'Dataset', '*' * 10)
print(dataset)

if remove_comments:
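    # NOTE: this uses the pre-0.22 py-tree-sitter API (Language.build_library / Parser.set_language);
    # newer releases drop these in favor of per-language packages such as tree-sitter-python.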
    # Build the language library if not already built
    Language.build_library('/home/s452638/magisterka/build/my-languages.so', ['/home/s452638/magisterka/vendor/tree-sitter-python'])

    # Load the language
    PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')

    # Initialize the parser
    parser = Parser()
    parser.set_language(PYTHON_LANGUAGE)

    # Remove docstrings and comments from the code
    dataset = dataset.map(lambda x: {'code': remove_docstrings_and_comments_from_code(x['code'], parser)}, batched=False, desc='Removing docstrings and comments')

# Load the tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
print('\n\n', '*' * 10, 'Tokenizer', '*' * 10)
print(tokenizer)

if dataset_name != 'the-stack-tokenized':
    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')

    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')
    print('\n\n', '*' * 10, 'Tokenized dataset', '*' * 10)
    print(tokenized_datasets)
else:
    tokenized_datasets = dataset

# Set data collator for MLM
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
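# The collator masks ~15% of tokens dynamically each time a batch is drawn; every other position
# gets the label -100 and is ignored by the MLM loss.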

# Create DataLoaders
batch_size = 64
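# NOTE: shuffle=False keeps the stored order of each split; this assumes any shuffling was already
# done when the splits were prepared.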
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))

# Initialize a model with random weights based on the configuration for RoBERTa (CodeBERT is based on RoBERTa)
config = RobertaConfig.from_pretrained('roberta-base')
model = RobertaForMaskedLM(config)
model = torch.compile(model)
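# NOTE: torch.compile wraps the model, so the checkpoints saved below carry an '_orig_mod.' prefix
# in their state_dict keys; they are reloaded into the same compiled model, so the keys still match.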
wandb.watch(model)
print('\n\n', '*' * 10, 'Model', '*' * 10)
print(config)

# Log the model configuration to wandb
wandb.config.update({'model_config': config.to_dict()})

# Set the optimizer and scaler
learning_rate = 5e-4
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scaler = torch.amp.GradScaler()
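# NOTE: gradient scaling is only strictly needed for float16 autocast; with the bfloat16 autocast
# used below it is redundant, but harmless.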

# Training settings
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
eval_every = 10_000

# Log training settings to wandb
wandb.config.update({
    'training_settings': {
        'num_epochs': num_epochs,
        'num_training_steps': num_training_steps,
        'eval_every': eval_every,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
    }
})

# Initialize variables to track validation loss, accuracy, and the best validation loss so far
valid_acc = 0.0
valid_loss = 0.0
best_valid_loss = float('inf')

# Train the model
print('\n\n', '*' * 10, 'Training', '*' * 10)
model.train()
with tqdm(total=num_training_steps, desc='Training') as pbar:
    for epoch_idx in range(num_epochs):
        for train_idx, train_batch in enumerate(train_dataloader):

            # Forward pass with mixed precision
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                outputs = model(**train_batch)

            train_loss = outputs.loss
            scaler.scale(train_loss).backward()

            # Unscale the gradients so clipping is applied to their true magnitudes
            scaler.unscale_(optimizer)

            # Gradient clipping to prevent exploding gradients
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            pbar.update(1)
            pbar.set_postfix({'norm': norm.item(), 'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})

            # Log metrics to Weights & Biases
            wandb.log({
                'step': train_idx + len(train_dataloader) * epoch_idx,
                'train_loss': train_loss.item(),
                'gradient_norm': norm.item(),
                'learning_rate': optimizer.param_groups[0]['lr'],
            })

            # Evaluate the model
            if train_idx != 0 and train_idx % eval_every == 0:
                model.eval()
                valid_loss = 0.0
                valid_acc = 0.0
                valid_total = 0

                with tqdm(total=len(valid_dataloader), desc='Validation') as pbar_valid:
                    with torch.no_grad():
                        for valid_idx, valid_batch in enumerate(valid_dataloader):

                            # Forward pass with mixed precision for validation
                            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                                outputs = model(**valid_batch)

                            # Accumulate validation loss and masked-token accuracy (labels are -100 at unmasked positions)
                            valid_loss += outputs.loss.item()
                            mask = valid_batch['labels'] != -100
                            valid_acc += outputs.logits.argmax(dim=-1)[mask].eq(valid_batch['labels'][mask]).sum().item()
                            valid_total += mask.sum().item()
                            pbar_valid.update(1)

                # Compute average validation loss and masked-token accuracy
                valid_loss /= len(valid_dataloader)
                valid_acc /= valid_total
                model.train()

                # Log validation metrics to Weights & Biases
                wandb.log({
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'step': train_idx + len(train_dataloader) * epoch_idx,
                })

                # Update best model if current validation loss is lower
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), output_dir + 'best_model.pt')

print('\n\n', '*' * 10, 'Training results', '*' * 10)
print(f'Best validation loss: {best_valid_loss}')

# Load the best model and evaluate on the test set
print('\n\n', '*' * 10, 'Testing', '*' * 10)
model.load_state_dict(torch.load(output_dir + 'best_model.pt', weights_only=True, map_location=device))
model.eval()
test_loss = 0.0
test_acc = 0.0
test_total = 0

with tqdm(total=len(test_dataloader), desc='Testing') as pbar_test:
    with torch.no_grad():
        for test_idx, test_batch in enumerate(test_dataloader):

            # Forward pass with mixed precision for testing
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                outputs = model(**test_batch)

            # Accumulate test loss and masked-token accuracy (labels are -100 at unmasked positions)
            test_loss += outputs.loss.item()
            mask = test_batch['labels'] != -100
            test_acc += outputs.logits.argmax(dim=-1)[mask].eq(test_batch['labels'][mask]).sum().item()
            test_total += mask.sum().item()
            pbar_test.update(1)

# Compute average test loss and masked-token accuracy
test_loss /= len(test_dataloader)
test_acc /= test_total

# Log test metrics to Weights & Biases
wandb.log({
    'test_loss': test_loss,
    'test_acc': test_acc,
})

print('\n\n', '*' * 10, 'Test results', '*' * 10)
print(f'Test loss: {test_loss}')
print(f'Test accuracy: {test_acc}')