redesigned code, added functions and granularity
This commit is contained in:
parent
3e2f9c7711
commit
240c16b495
4
code/.gitignore
vendored
4
code/.gitignore
vendored
@ -160,3 +160,7 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Weights & Biases
|
||||
wandb/
|
||||
outputs/
|
||||
|
2
code/data/.gitignore
vendored
Normal file
2
code/data/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
2
code/models/.gitignore
vendored
Normal file
2
code/models/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
@ -5,7 +5,7 @@
|
||||
groups = ["default"]
|
||||
strategy = ["inherit_metadata"]
|
||||
lock_version = "4.5.0"
|
||||
content_hash = "sha256:ac6621f3bd9193d786ab94f80f8b1711100fe418959f2e131ae03afeab616788"
|
||||
content_hash = "sha256:bf0a0ea826769cf12a84888d394edd8c3c5599c4d369b8b19b75c2fa5e16f5f0"
|
||||
|
||||
[[metadata.targets]]
|
||||
requires_python = "==3.11.*"
|
||||
|
@ -6,13 +6,13 @@ authors = [
|
||||
{name = "Patryk Bartkowiak", email = "patbar15@st.amu.edu.pl"},
|
||||
]
|
||||
dependencies = [
|
||||
"wandb>=0.18.5",
|
||||
"torch>=2.5.0",
|
||||
"tqdm>=4.66.5",
|
||||
"tree-sitter>=0.23.1",
|
||||
"transformers>=4.45.2",
|
||||
"datasets>=3.0.1",
|
||||
"huggingface-hub>=0.26.0",
|
||||
"wandb==0.18.5",
|
||||
"torch==2.5.0",
|
||||
"tqdm==4.66.5",
|
||||
"tree-sitter==0.23.1",
|
||||
"transformers==4.45.2",
|
||||
"datasets==3.0.1",
|
||||
"huggingface-hub==0.26.0",
|
||||
]
|
||||
requires-python = "==3.11.*"
|
||||
readme = "README.md"
|
||||
|
11
code/src/config.json
Normal file
11
code/src/config.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"seed": 42,
|
||||
"mlm_probability": 0.15,
|
||||
"batch": 32,
|
||||
"epochs": 1,
|
||||
"eval_every": 10000,
|
||||
"learning_rate": 5e-4,
|
||||
"weight_decay": 0.01,
|
||||
"max_grad_norm": 1.0,
|
||||
"warmup_steps": 10000
|
||||
}
|
File diff suppressed because one or more lines are too long
@ -1,58 +0,0 @@
|
||||
from datasets import load_dataset, disable_caching
|
||||
from transformers import RobertaTokenizer
|
||||
|
||||
disable_caching()
|
||||
|
||||
|
||||
def visible_print(text):
|
||||
print('\n\n')
|
||||
print('=' * 100)
|
||||
print(text)
|
||||
print('=' * 100)
|
||||
print('\n\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load the dataset
|
||||
train_data = load_dataset('/work/s452638/datasets/the-stack-python', split='train')
|
||||
valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['train']
|
||||
test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['train']
|
||||
|
||||
visible_print('Loaded data')
|
||||
|
||||
# Rename the columns
|
||||
train_data = train_data.rename_column('content', 'code')
|
||||
|
||||
# Remove all the columns except the code
|
||||
train_columns = train_data.column_names
|
||||
valid_columns = valid_data.column_names
|
||||
test_columns = test_data.column_names
|
||||
|
||||
train_columns.remove('code')
|
||||
valid_columns.remove('code')
|
||||
test_columns.remove('code')
|
||||
|
||||
train_data = train_data.remove_columns(train_columns)
|
||||
valid_data = valid_data.remove_columns(valid_columns)
|
||||
test_data = test_data.remove_columns(test_columns)
|
||||
|
||||
visible_print('Removed unnecessary columns')
|
||||
|
||||
# Tokenize the data
|
||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
|
||||
|
||||
train_data = train_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Train] Running tokenizer', num_proc=8)
|
||||
valid_data = valid_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Valid] Running tokenizer', num_proc=8)
|
||||
test_data = test_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Test] Running tokenizer', num_proc=8)
|
||||
|
||||
visible_print('Tokenized data')
|
||||
|
||||
# Save the tokenized data
|
||||
train_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
|
||||
valid_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
|
||||
test_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
|
||||
|
||||
visible_print('Saved tokenized data')
|
File diff suppressed because one or more lines are too long
@ -1,254 +1,245 @@
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk, disable_caching, DatasetDict
|
||||
from tree_sitter import Language, Parser
|
||||
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from datasets import load_dataset, disable_caching, DatasetDict
|
||||
from huggingface_hub import list_repo_files, hf_hub_download
|
||||
from transformers import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
DataCollatorForLanguageModeling,
|
||||
get_linear_schedule_with_warmup,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedModel
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils import remove_docstrings_and_comments_from_code
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Disable caching for datasets
|
||||
disable_caching()
|
||||
class OnTheFlyTokenizationDataset(Dataset):
|
||||
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
|
||||
############################### CONFIG ###############################
|
||||
dataset_name = 'the-stack-tokenized' # 'the-stack' or 'code-search-net' or 'the-stack-tokenized
|
||||
remove_comments = False
|
||||
######################################################################
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
# Initialize Weights & Biases and output directory
|
||||
curr_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
wandb.init(project='codebert-training', name=curr_time)
|
||||
output_dir = f'/home/s452638/magisterka/output/{curr_time}/'
|
||||
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
|
||||
content: str = self.dataset[idx]['content']
|
||||
tokenized = self.tokenizer(
|
||||
content,
|
||||
truncation=True,
|
||||
padding='max_length',
|
||||
max_length=self.max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
return {
|
||||
'input_ids': tokenized['input_ids'].squeeze(0),
|
||||
'attention_mask': tokenized['attention_mask'].squeeze(0),
|
||||
'labels': tokenized['input_ids'].squeeze(0)
|
||||
}
|
||||
|
||||
# Save this file to Weights & Biases
|
||||
wandb.save('train_codebert_mlm.py')
|
||||
|
||||
# Create the output directory if it does not exist
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Set the seed for reproducibility
|
||||
SEED = 42
|
||||
def set_seed(seed):
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
set_seed(SEED)
|
||||
def setup_wandb(config: Dict[str, Any]) -> None:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
wandb.init(project='codebert-training', name=curr_time, config=config)
|
||||
wandb.save('train_codebert_mlm.py')
|
||||
|
||||
# Set the device for PyTorch (use GPU if available, otherwise CPU)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
def setup_directories(current_dir: Path) -> Path:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
def load_config(config_file: Path) -> Dict[str, Any]:
|
||||
with open(config_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def setup_device() -> torch.device:
|
||||
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
torch.set_default_device(device)
|
||||
print('*' * 10, 'Device', '*' * 10)
|
||||
print(f'Using device: {device}')
|
||||
logger.info(f'Using device: {device}')
|
||||
if device.type == 'cuda':
|
||||
print(f'Device name: {torch.cuda.get_device_name()}')
|
||||
logger.info(f'Device name: {torch.cuda.get_device_name()}')
|
||||
torch.set_float32_matmul_precision('high')
|
||||
return device
|
||||
|
||||
# Load the dataset
|
||||
if dataset_name == 'the-stack-tokenized':
|
||||
train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
|
||||
valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
|
||||
test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
|
||||
else:
|
||||
if dataset_name == 'the-stack':
|
||||
train_data = load_dataset("/work/s452638/datasets/the-stack-python", split="train")
|
||||
train_data = train_data.rename_column_('content', 'code')
|
||||
elif dataset_name == 'code-search-net':
|
||||
train_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/train.jsonl')['train']
|
||||
def download_dataset(dataset_dir: Path) -> None:
|
||||
if not dataset_dir.exists():
|
||||
logger.info("Downloading the dataset...")
|
||||
dataset_dir.mkdir(parents=True, exist_ok=True)
|
||||
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
|
||||
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
|
||||
for file_name in files_to_download:
|
||||
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
|
||||
logger.info("Dataset downloaded successfully.")
|
||||
|
||||
valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['valid']
|
||||
test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['test']
|
||||
def load_and_prepare_dataset(dataset_dir: Path, seed: int) -> DatasetDict:
|
||||
dataset: DatasetDict = load_dataset(str(dataset_dir), split='train')
|
||||
dataset = dataset.train_test_split(test_size=0.01, seed=seed)
|
||||
logger.info(f'Dataset loaded: {dataset}')
|
||||
return dataset
|
||||
|
||||
dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data})
|
||||
print('\n\n', '*' * 10, 'Dataset', '*' * 10)
|
||||
print(dataset)
|
||||
def create_dataloaders(
|
||||
dataset: DatasetDict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: Dict[str, Any],
|
||||
device: torch.device
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
dataset['train'] = OnTheFlyTokenizationDataset(dataset['train'], tokenizer, max_length=512)
|
||||
dataset['test'] = OnTheFlyTokenizationDataset(dataset['test'], tokenizer, max_length=512)
|
||||
|
||||
if remove_comments:
|
||||
# Build the language library if not already built
|
||||
Language.build_library('/home/s452638/magisterka/build/my-languages.so', ['/home/s452638/magisterka/vendor/tree-sitter-python'])
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=config['mlm_probability'])
|
||||
|
||||
# Load the language
|
||||
PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')
|
||||
train_dataloader = DataLoader(
|
||||
dataset['train'],
|
||||
batch_size=config['batch'],
|
||||
shuffle=False,
|
||||
collate_fn=data_collator,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
valid_dataloader = DataLoader(
|
||||
dataset['test'],
|
||||
batch_size=config['batch'],
|
||||
shuffle=False,
|
||||
collate_fn=data_collator,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
return train_dataloader, valid_dataloader
|
||||
|
||||
# Initialize the parser
|
||||
parser = Parser()
|
||||
parser.set_language(PYTHON_LANGUAGE)
|
||||
|
||||
# Remove docstrings and comments from the code
|
||||
dataset = dataset.map(lambda x: {'code': remove_docstrings_and_comments_from_code(x['code'], parser)}, batched=False, desc='Removing docstrings and comments')
|
||||
|
||||
# Load the tokenizer
|
||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
|
||||
print('\n\n', '*' * 10, 'Tokenizer', '*' * 10)
|
||||
print(tokenizer)
|
||||
|
||||
if dataset_name != 'the-stack-tokenized':
|
||||
# Tokenize the dataset
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')
|
||||
print('\n\n', '*' * 10, 'Tokenized dataset', '*' * 10)
|
||||
print(tokenized_datasets)
|
||||
else:
|
||||
tokenized_datasets = dataset
|
||||
|
||||
# Set data collator for MLM
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
||||
|
||||
# Create DataLoaders
|
||||
batch_size = 64
|
||||
train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
|
||||
valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
|
||||
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
|
||||
|
||||
# Initialize a model with random weights based on the configuration for RoBERTa (CodeBERT is based on RoBERTa)
|
||||
config = RobertaConfig.from_pretrained('roberta-base')
|
||||
model = RobertaForMaskedLM(config)
|
||||
def setup_model_and_optimizer(
|
||||
config: Dict[str, Any],
|
||||
current_dir: Path
|
||||
) -> Tuple[PreTrainedModel, AdamW]:
|
||||
os.environ['HF_HOME'] = str(current_dir.parent / 'models')
|
||||
model_config = RobertaConfig.from_pretrained('roberta-base')
|
||||
model: PreTrainedModel = RobertaForMaskedLM(model_config)
|
||||
model = torch.compile(model)
|
||||
wandb.watch(model)
|
||||
print('\n\n', '*' * 10, 'Model', '*' * 10)
|
||||
print(config)
|
||||
logger.info(f'Model config: {model_config}')
|
||||
wandb.config.update({'model_config': model_config.to_dict()})
|
||||
|
||||
# Log the model configuration to wandb
|
||||
wandb.config.update({'model_config': config.to_dict()})
|
||||
optimizer: AdamW = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
|
||||
return model, optimizer
|
||||
|
||||
# Set the optimizer and scaler
|
||||
learning_rate = 5e-4
|
||||
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
|
||||
scaler = torch.amp.GradScaler()
|
||||
def train_and_evaluate(
|
||||
model: PreTrainedModel,
|
||||
train_dataloader: DataLoader,
|
||||
valid_dataloader: DataLoader,
|
||||
optimizer: AdamW,
|
||||
scheduler: Any,
|
||||
config: Dict[str, Any],
|
||||
output_dir: Path
|
||||
) -> None:
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
best_valid_loss: float = float('inf')
|
||||
|
||||
# Training settings
|
||||
num_epochs = 1
|
||||
num_training_steps = num_epochs * len(train_dataloader)
|
||||
eval_every = 10_000
|
||||
|
||||
# Log training settings to wandb
|
||||
wandb.config.update({
|
||||
'training_settings': {
|
||||
'num_epochs': num_epochs,
|
||||
'num_training_steps': num_training_steps,
|
||||
'eval_every': eval_every,
|
||||
'batch_size': batch_size,
|
||||
'learning_rate': learning_rate,
|
||||
}
|
||||
})
|
||||
|
||||
# Initialize variables to track validation loss, accuracy, and best model path
|
||||
valid_acc = 0.0
|
||||
valid_loss = 0.0
|
||||
best_valid_loss = float('inf')
|
||||
|
||||
# Train the model
|
||||
print('\n\n', '*' * 10, 'Training', '*' * 10)
|
||||
model.train()
|
||||
with tqdm(total=num_training_steps, desc='Training') as pbar:
|
||||
for epoch_idx in range(num_epochs):
|
||||
for epoch_idx in range(config['epochs']):
|
||||
model.train()
|
||||
for train_idx, train_batch in enumerate(train_dataloader):
|
||||
|
||||
# Forward pass with mixed precision
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
||||
outputs = model(**train_batch)
|
||||
train_loss: Tensor = outputs.loss
|
||||
train_loss.backward()
|
||||
|
||||
train_loss = outputs.loss
|
||||
scaler.scale(train_loss).backward()
|
||||
norm: Tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
|
||||
|
||||
# Gradient clipping to prevent exploding gradients
|
||||
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({'norm': norm.item(), 'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
|
||||
pbar.set_postfix({'train_loss': train_loss.item()})
|
||||
|
||||
# Log metrics to Weights & Biases
|
||||
wandb.log({
|
||||
'step': train_idx + len(train_dataloader) * epoch_idx,
|
||||
'train_loss': train_loss.item(),
|
||||
'gradient_norm': norm.item(),
|
||||
'learning_rate': optimizer.param_groups[0]['lr'],
|
||||
'learning_rate': scheduler.get_last_lr()[0],
|
||||
})
|
||||
|
||||
# Evaluate the model
|
||||
if train_idx != 0 and train_idx % eval_every == 0:
|
||||
model.eval()
|
||||
valid_loss = 0.0
|
||||
valid_acc = 0.0
|
||||
if train_idx != 0 and train_idx % config['eval_every'] == 0:
|
||||
valid_loss, valid_acc = evaluate(model, valid_dataloader)
|
||||
pbar.set_postfix({'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
|
||||
|
||||
with tqdm(total=len(valid_dataloader), desc='Validation') as pbar_valid:
|
||||
with torch.no_grad():
|
||||
for valid_idx, valid_batch in enumerate(valid_dataloader):
|
||||
|
||||
# Forward pass with mixed precision for validation
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
||||
outputs = model(**valid_batch)
|
||||
|
||||
# Accumulate validation loss and accuracy
|
||||
valid_loss += outputs.loss.item()
|
||||
valid_acc += outputs.logits.argmax(dim=-1).eq(valid_batch['labels']).sum().item()
|
||||
pbar_valid.update(1)
|
||||
|
||||
# Compute average validation loss and accuracy
|
||||
valid_loss /= len(valid_dataloader)
|
||||
valid_acc /= len(valid_dataloader.dataset)
|
||||
model.train()
|
||||
|
||||
# Log validation metrics to Weights & Biases
|
||||
wandb.log({
|
||||
'valid_loss': valid_loss,
|
||||
'valid_acc': valid_acc,
|
||||
'step': train_idx + len(train_dataloader) * epoch_idx,
|
||||
})
|
||||
|
||||
# Update best model if current validation loss is lower
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), output_dir + f'best_model.pt')
|
||||
torch.save(model.state_dict(), output_dir / 'best_model.pt')
|
||||
|
||||
print('\n\n', '*' * 10, 'Training results', '*' * 10)
|
||||
print(f'Best validation loss: {best_valid_loss}')
|
||||
logger.info(f'Best validation loss: {best_valid_loss}')
|
||||
|
||||
# Load the best model and evaluate on the test set
|
||||
print('\n\n', '*' * 10, 'Testing', '*' * 10)
|
||||
model.load_state_dict(torch.load(output_dir + f'best_model.pt', weights_only=True, map_location=device))
|
||||
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
|
||||
model.eval()
|
||||
test_loss = 0.0
|
||||
test_acc = 0.0
|
||||
total_loss: float = 0.0
|
||||
total_acc: float = 0.0
|
||||
|
||||
with tqdm(total=len(test_dataloader), desc='Testing') as pbar_test:
|
||||
with torch.no_grad():
|
||||
for test_idx, test_batch in enumerate(test_dataloader):
|
||||
for batch in tqdm(dataloader, desc='Validation'):
|
||||
outputs = model(**batch)
|
||||
total_loss += outputs.loss.item()
|
||||
total_acc += outputs.logits.argmax(dim=-1).eq(batch['labels']).sum().item()
|
||||
|
||||
# Forward pass with mixed precision for testing
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
||||
outputs = model(**test_batch)
|
||||
avg_loss: float = total_loss / len(dataloader)
|
||||
avg_acc: float = total_acc / len(dataloader.dataset)
|
||||
return avg_loss, avg_acc
|
||||
|
||||
# Accumulate test loss and accuracy
|
||||
test_loss += outputs.loss.item()
|
||||
test_acc += outputs.logits.argmax(dim=-1).eq(test_batch['labels']).sum().item()
|
||||
pbar_test.update(1)
|
||||
def main() -> None:
|
||||
disable_caching()
|
||||
|
||||
# Compute average test loss and accuracy
|
||||
test_loss /= len(test_dataloader)
|
||||
test_acc /= len(test_dataloader.dataset)
|
||||
current_dir: Path = Path(__file__).parent
|
||||
output_dir: Path = setup_directories(current_dir)
|
||||
config: Dict[str, Any] = load_config(current_dir / 'config.json')
|
||||
|
||||
# Log test metrics to Weights & Biases
|
||||
wandb.log({
|
||||
'test_loss': test_loss,
|
||||
'test_acc': test_acc,
|
||||
})
|
||||
setup_wandb(config)
|
||||
set_seed(config['seed'])
|
||||
device: torch.device = setup_device()
|
||||
|
||||
print('\n\n', '*' * 10, 'Test results', '*' * 10)
|
||||
print(f'Test loss: {test_loss}')
|
||||
print(f'Test accuracy: {test_acc}')
|
||||
dataset_dir: Path = current_dir.parent / 'data' / 'the-stack-python'
|
||||
download_dataset(dataset_dir)
|
||||
dataset: DatasetDict = load_and_prepare_dataset(dataset_dir, config['seed'])
|
||||
|
||||
tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
|
||||
logger.info(f'Tokenizer loaded: {tokenizer}')
|
||||
|
||||
train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
|
||||
|
||||
model, optimizer = setup_model_and_optimizer(config, current_dir)
|
||||
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=config['warmup_steps'],
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user