added tree enhanced embeddings into training
This commit is contained in:
parent
d4ccfaac44
commit
9299c56bb1
1402
code/pdm.lock
1402
code/pdm.lock
File diff suppressed because it is too large
Load Diff
@ -1,11 +1,12 @@
|
||||
{
|
||||
"seed": 42,
|
||||
"mlm_probability": 0.15,
|
||||
"batch": 32,
|
||||
"batch": 16,
|
||||
"epochs": 1,
|
||||
"eval_every": 10000,
|
||||
"eval_every": 20000,
|
||||
"learning_rate": 5e-4,
|
||||
"weight_decay": 0.01,
|
||||
"max_grad_norm": 1.0,
|
||||
"warmup_steps": 10000
|
||||
"warmup_steps": 20000,
|
||||
"pause_on_instability": false
|
||||
}
|
@ -13,7 +13,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from datasets import load_dataset, disable_caching, DatasetDict
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from huggingface_hub import list_repo_files, hf_hub_download
|
||||
from transformers import (
|
||||
RobertaForMaskedLM,
|
||||
@ -61,14 +61,14 @@ def set_seed(seed: int) -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def setup_wandb(config: Dict[str, Any]) -> None:
|
||||
def setup_wandb(config: Dict[str, Any], exec_file: str = 'train_codebert_mlm.py') -> None:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
wandb.init(project='codebert-training', name=curr_time, config=config)
|
||||
wandb.init(project='codebert-training', name=f'tree_{curr_time}', config=config)
|
||||
wandb.save('train_codebert_mlm.py')
|
||||
|
||||
def setup_directories(current_dir: Path) -> Path:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / f'tree_{curr_time}'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
@ -130,10 +130,11 @@ def create_dataloaders(
|
||||
|
||||
def setup_model_and_optimizer(
|
||||
config: Dict[str, Any],
|
||||
current_dir: Path
|
||||
models_dir: Path
|
||||
) -> Tuple[PreTrainedModel, AdamW]:
|
||||
os.environ['HF_HOME'] = str(current_dir.parent / 'models')
|
||||
model_config = RobertaConfig.from_pretrained('roberta-base')
|
||||
if not models_dir.exists():
|
||||
logger.info("Downloading the model...")
|
||||
model_config = RobertaConfig.from_pretrained('roberta-base', cache_dir=models_dir)
|
||||
model: PreTrainedModel = RobertaForMaskedLM(model_config)
|
||||
model = torch.compile(model)
|
||||
wandb.watch(model)
|
||||
@ -212,8 +213,6 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo
|
||||
return avg_loss, avg_acc
|
||||
|
||||
def main() -> None:
|
||||
disable_caching()
|
||||
|
||||
current_dir: Path = Path(__file__).parent
|
||||
output_dir: Path = setup_directories(current_dir)
|
||||
config: Dict[str, Any] = load_config(current_dir / 'config.json')
|
||||
@ -229,9 +228,15 @@ def main() -> None:
|
||||
tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
|
||||
logger.info(f'Tokenizer loaded: {tokenizer}')
|
||||
|
||||
######################## Reproducing last training here ########################
|
||||
# Remove first 186513 batches
|
||||
dataset['train'] = dataset['train'].select(range(186_513 * 32, len(dataset['train'])))
|
||||
################################################################################
|
||||
|
||||
train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
|
||||
|
||||
model, optimizer = setup_model_and_optimizer(config, current_dir)
|
||||
models_dir: Path = current_dir.parent / 'models' / 'roberta-base'
|
||||
model, optimizer = setup_model_and_optimizer(config, models_dir)
|
||||
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
@ -240,6 +245,16 @@ def main() -> None:
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
######################## Reproducing last training here ########################
|
||||
# Change optimizer learning rate to 0.00021575814536340852
|
||||
optimizer = AdamW(model.parameters(), lr=0.00021575814536340852, weight_decay=config['weight_decay'])
|
||||
# Set warmup_steps to 0
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
|
||||
# Load the best model weights
|
||||
state_dict = torch.load('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt', weights_only=True, map_location=device)
|
||||
model.load_state_dict(state_dict)
|
||||
################################################################################
|
||||
|
||||
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
494
code/src/train_tree_codebert_mlm.py
Normal file
494
code/src/train_tree_codebert_mlm.py
Normal file
@ -0,0 +1,494 @@
|
||||
import wandb
|
||||
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.optim import AdamW
|
||||
from dataclasses import dataclass
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import ast
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import DatasetDict
|
||||
from transformers import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedModel,
|
||||
)
|
||||
|
||||
sys.setrecursionlimit(3000) # Increase recursion limit
|
||||
|
||||
# Import existing training functionality
|
||||
from train_codebert_mlm import (
|
||||
set_seed, setup_wandb, setup_directories, load_config,
|
||||
setup_device, download_dataset, load_and_prepare_dataset, logger
|
||||
)
|
||||
|
||||
@dataclass
class ASTNodeInfo:
    """Structural record describing one node of a parsed Python AST.

    Instances are plain value objects: they carry the node's kind, the span
    it covers, and where it sits in the tree.
    """

    # Class name of the underlying ``ast`` node (e.g. ``"FunctionDef"``).
    node_type: str
    # Start / end of the span covered by the node.
    # NOTE(review): the producer in this file fills these from ``lineno`` /
    # ``end_lineno`` (source line numbers), while the consumer compares them
    # against token positions -- confirm which unit is intended.
    start_token_idx: int
    end_token_idx: int
    # Distance from the AST root (the producer clamps this to 31).
    depth: int
    # Index of the node among its parent's children (also clamped to 31).
    sibling_pos: int
    # Index of the parent in the collected node list; -1 for the root call.
    parent_idx: int
|
||||
|
||||
def ast_collate_fn(batch):
    """Collate dataset samples into a batch, dropping unusable samples.

    A sample is kept only when its attention mask has at least one non-zero
    entry and it carries at least one AST node. The tensor fields of the
    surviving samples are stacked along a new leading dimension; the AST node
    lists are gathered into a plain Python list (one entry per sample).

    If no sample survives, a single all-zero placeholder row of length 512 is
    returned together with one empty AST node list, so callers always receive
    a batch with the same keys.
    """
    tensor_fields = ('input_ids', 'attention_mask', 'labels')

    # Keep only samples that were tokenized and parsed successfully.
    usable = []
    for sample in batch:
        if sample['attention_mask'].sum() > 0 and len(sample['ast_nodes']) > 0:
            usable.append(sample)

    if len(usable) == 0:
        # Nothing usable: hand back a minimal placeholder batch.
        placeholder = {
            field: torch.zeros((1, 512), dtype=torch.long)
            for field in tensor_fields
        }
        placeholder['ast_nodes'] = [[]]
        return placeholder

    # Stack every tensor field; AST nodes stay as a list of lists.
    collated = {
        field: torch.stack([sample[field] for sample in usable])
        for field in tensor_fields
    }
    collated['ast_nodes'] = [sample['ast_nodes'] for sample in usable]
    return collated
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
    """Generates tree-aware positional embeddings for code tokens.

    Each sequence position is assigned the deepest AST node whose
    ``[start_token_idx, end_token_idx)`` span contains it. The node's depth
    and sibling position (both clamped to ``max_depth - 1``) are looked up in
    two embedding tables and fused by a linear layer. Positions that are not
    covered by any node receive an all-zero vector.

    Args:
        d_model: Size of the produced embedding vectors.
        max_depth: Number of rows in the depth and sibling embedding tables.
    """

    def __init__(self, d_model: int = 768, max_depth: int = 32):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)
        self.combine = nn.Linear(d_model * 2, d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        """Fill both embedding tables with the same sinusoidal pattern."""
        position = torch.arange(self.max_depth).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float()
            * (-math.log(10000.0) / self.d_model)
        )
        angles = position * div_term

        pe = torch.zeros(self.max_depth, self.d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles)[:, : self.d_model // 2]

        with torch.no_grad():
            self.depth_embedding.weight.copy_(pe)
            self.sibling_embedding.weight.copy_(pe)

    def _deepest_node_indices(self, ast_nodes, seq_len: int):
        """Return per-position (depth ids, sibling ids, covered flags).

        Walks every node once and claims the positions inside its span,
        instead of scanning all nodes for every position. A node replaces an
        earlier claim only when it is strictly deeper, so ties keep the first
        node in list order.
        """
        last_index = self.max_depth - 1
        winning_depth = [None] * seq_len  # unclamped depth of the chosen node
        depth_ids = [0] * seq_len
        sibling_ids = [0] * seq_len

        for node in ast_nodes:
            first = max(node.start_token_idx, 0)
            stop = min(node.end_token_idx, seq_len)
            for pos in range(first, stop):
                current = winning_depth[pos]
                if current is None or node.depth > current:
                    winning_depth[pos] = node.depth
                    depth_ids[pos] = min(node.depth, last_index)
                    sibling_ids[pos] = min(node.sibling_pos, last_index)

        covered = [depth is not None for depth in winning_depth]
        return depth_ids, sibling_ids, covered

    def forward(
        self,
        input_ids: torch.Tensor,
        ast_nodes_batch: List[List["ASTNodeInfo"]],
    ) -> torch.Tensor:
        """Compute tree embeddings for a batch.

        Args:
            input_ids: Tensor whose shape is unpacked as
                ``(batch_size, seq_len)``; only its shape and device are used.
            ast_nodes_batch: One list of AST node records per batch item.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.

        Raises:
            IndexError: If ``ast_nodes_batch`` has fewer entries than the
                batch size.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        depth_rows, sibling_rows, covered_rows = [], [], []
        for batch_idx in range(batch_size):
            depth_ids, sibling_ids, covered = self._deepest_node_indices(
                ast_nodes_batch[batch_idx], seq_len
            )
            depth_rows.append(depth_ids)
            sibling_rows.append(sibling_ids)
            covered_rows.append(covered)

        # reshape keeps the (batch, seq) layout even when one dimension is 0.
        depth_index = torch.tensor(
            depth_rows, dtype=torch.long, device=device
        ).reshape(batch_size, seq_len)
        sibling_index = torch.tensor(
            sibling_rows, dtype=torch.long, device=device
        ).reshape(batch_size, seq_len)
        covered_mask = torch.tensor(
            covered_rows, dtype=torch.bool, device=device
        ).reshape(batch_size, seq_len)

        # One batched lookup + projection instead of one per token.
        fused = self.combine(
            torch.cat(
                [
                    self.depth_embedding(depth_index),
                    self.sibling_embedding(sibling_index),
                ],
                dim=-1,
            )
        )
        # Positions outside every node span stay exactly zero.
        return fused.masked_fill(~covered_mask.unsqueeze(-1), 0.0)
|
||||
|
||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """RoBERTa masked-LM model fed with a learned blend of three embeddings.

    The input to the transformer is a softmax-weighted sum of word
    embeddings, tree positional embeddings and sinusoidally initialised
    sequential position embeddings, followed by LayerNorm, dropout and a
    second LayerNorm. The blend is then passed to ``self.roberta`` via
    ``inputs_embeds``.
    """

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
        super().__init__(config)
        hidden = config.hidden_size

        self.tree_pos_embeddings = TreePositionalEmbedding(d_model=hidden, max_depth=max_depth)
        self.seq_pos_embeddings = nn.Embedding(max_seq_length, hidden)

        # Overwrite the sequential table with a sinusoidal pattern.
        positions = torch.arange(max_seq_length).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, hidden, 2) * (-math.log(10000.0) / hidden))
        angles = positions * frequencies
        table = torch.zeros(max_seq_length, hidden)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.seq_pos_embeddings.weight.data.copy_(table)

        # Raw (pre-softmax) mixing logits, drawn close to zero so the three
        # components start out almost equally weighted.
        self.alpha = nn.Parameter(torch.randn(1) * 0.02)
        self.beta = nn.Parameter(torch.randn(1) * 0.02)
        self.gamma = nn.Parameter(torch.randn(1) * 0.02)

        self.embedding_combination_layer_norm = nn.LayerNorm(hidden)
        self.final_layer_norm = nn.LayerNorm(hidden)

        # Regularisation applied between the two layer norms.
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

    def get_normalized_weights(self) -> torch.Tensor:
        """Return the softmax over the stacked ``[alpha, beta, gamma]`` logits.

        Each logit has shape ``(1,)``, so the stacked result has shape
        ``(3, 1)``; the softmax runs over the first dimension.
        """
        logits = torch.stack([self.alpha, self.beta, self.gamma])
        return torch.softmax(logits, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the model and return a dict of loss, logits and diagnostics.

        Args:
            input_ids: Token ids; dimension 0 is used as the batch size and
                dimension 1 as the sequence length.
            attention_mask: Forwarded unchanged to the RoBERTa encoder.
            labels: When given, a cross-entropy loss over the vocabulary is
                computed against the flattened labels.
            ast_nodes: Per-sample AST node lists. When ``None`` the tree
                component is left out and the token/sequential weights are
                re-normalised between themselves.
            output_attentions: Forwarded to the RoBERTa encoder.
            **kwargs: Extra keyword arguments forwarded to the encoder.

        Returns:
            Dict with keys ``loss``, ``logits``, ``hidden_states``,
            ``attentions`` and ``embedding_weights`` (the three-way softmax
            weights as Python floats, reported even when ``ast_nodes`` is
            ``None``).
        """
        device = input_ids.device

        # Individual embedding components.
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
        position_ids = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = (
            self.seq_pos_embeddings(position_ids)
            .unsqueeze(0)
            .expand(input_ids.size(0), -1, -1)
        )

        mix = self.get_normalized_weights()

        if ast_nodes is None:
            # No tree information: share the weight between the two
            # remaining components only.
            pair_mix = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
            blended = (
                pair_mix[0] * token_embeddings +
                pair_mix[1] * seq_embeddings
            )
        else:
            tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
            blended = (
                mix[0] * token_embeddings +
                mix[1] * tree_embeddings +
                mix[2] * seq_embeddings
            )

        # LayerNorm -> dropout -> LayerNorm before entering the encoder.
        blended = self.embedding_combination_layer_norm(blended)
        blended = self.embedding_dropout(blended)
        blended = self.final_layer_norm(blended)

        outputs = self.roberta(
            inputs_embeds=blended,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        prediction_scores = self.lm_head(outputs[0])

        if labels is None:
            masked_lm_loss = None
        else:
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = nn.CrossEntropyLoss()(flat_scores, labels.view(-1))

        # Detached copies of the mixing weights, for logging only.
        token_w, tree_w, seq_w = (w.item() for w in mix.detach().cpu())

        return {
            "loss": masked_lm_loss,
            "logits": prediction_scores,
            "hidden_states": getattr(outputs, "hidden_states", None),
            "attentions": outputs.attentions,
            "embedding_weights": {
                "token": token_w,
                "tree": tree_w,
                "sequential": seq_w
            }
        }
|
||||
|
||||
class TreeEnhancedDataset(Dataset):
    """Dataset that turns source code into token tensors plus AST node records.

    Every item is a dict with ``input_ids``, ``attention_mask``, ``labels``
    and ``ast_nodes``. Samples that are empty, too long, contain NUL bytes,
    fail to parse, or produce too many AST nodes are replaced by a dummy item
    (all-padding ids, all-zero attention mask, no AST nodes).

    NOTE(review): ``labels`` are a copy of ``input_ids``; no token masking is
    applied anywhere in this class -- confirm that masking happens elsewhere
    if this is meant to be a masked-LM objective.

    Args:
        dataset: Indexable collection whose items expose a ``'content'`` string.
        tokenizer: Callable tokenizer; must provide ``pad_token_id``.
        max_length: Length every sample is padded / truncated to.
    """

    def __init__(self, dataset: Dataset, tokenizer: "PreTrainedTokenizer", max_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_ast_size = 1000  # Limit maximum AST size to prevent memory issues

    def __len__(self) -> int:
        return len(self.dataset)

    def create_dummy_data(self) -> Dict[str, Any]:
        """Create dummy data for invalid/problematic code samples.

        The attention mask is all zeros, which is what the collate function
        in this file uses to recognise and drop such samples.
        """
        pad_token_id = self.tokenizer.pad_token_id or 0
        dummy_tensor = torch.full((self.max_length,), pad_token_id, dtype=torch.long)
        return {
            'input_ids': dummy_tensor,
            'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
            'labels': dummy_tensor.clone(),
            'ast_nodes': []
        }

    def _collect_ast_nodes(self, content: str) -> list:
        """Parse *content* and return node records in depth-first order.

        Collection stops once ``max_ast_size`` records exist. Only nodes that
        carry a ``lineno`` attribute are recorded.

        NOTE(review): the span fields are filled with source *line numbers*
        (``lineno`` / ``end_lineno``), not token indices, and ``end_lineno``
        is inclusive -- confirm against the consumer, which treats the span
        as a half-open range of token positions.
        """
        tree = ast.parse(content)
        nodes_info: list = []

        def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
            if len(nodes_info) >= self.max_ast_size:
                return

            current_idx = len(nodes_info)
            if hasattr(node, 'lineno'):
                nodes_info.append(ASTNodeInfo(
                    node_type=type(node).__name__,
                    start_token_idx=node.lineno,
                    end_token_idx=getattr(node, 'end_lineno', node.lineno),
                    depth=min(depth, 31),  # Limit depth to prevent issues
                    sibling_pos=min(sibling_pos, 31),
                    parent_idx=parent_idx
                ))

            # Only process children if we haven't hit the limit
            if len(nodes_info) < self.max_ast_size:
                for i, child in enumerate(ast.iter_child_nodes(node)):
                    visit_node(child, depth + 1, current_idx, i)

        visit_node(tree, 0, -1, 0)
        return nodes_info

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the processed sample at *idx*, or dummy data on any failure."""
        try:
            content: str = self.dataset[idx]['content']

            # Skip extremely long files
            if len(content) > 50000:  # ~1000 lines
                return self.create_dummy_data()

            # Basic code validation
            if not content.strip() or '\0' in content:
                return self.create_dummy_data()

            # Tokenize first
            encoding = self.tokenizer(
                content,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            try:
                # Parse the AST (no timeout is enforced here).
                nodes_info = self._collect_ast_nodes(content)
            except (SyntaxError, ValueError, RecursionError, TimeoutError, MemoryError):
                # Expected for code that is not valid for this interpreter.
                return self.create_dummy_data()

            # If we hit the AST size limit, use dummy data
            if len(nodes_info) >= self.max_ast_size:
                return self.create_dummy_data()

            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)

            # Padding positions must not contribute to the loss: -100 is the
            # default ignore_index of torch.nn.CrossEntropyLoss.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'ast_nodes': nodes_info
            }

        except Exception as exc:
            # Best effort: one bad sample must not stop training, but leave
            # a trace instead of discarding the error silently.
            logger.debug(f'Falling back to dummy data for sample {idx}: {exc!r}')
            return self.create_dummy_data()
|
||||
|
||||
def create_tree_dataloaders(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizer,
    config: Dict[str, Any],
    device: torch.device
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation dataloaders over tree-enhanced data.

    The ``'train'`` split is shuffled, the ``'test'`` split is not. Both use
    sequences of length 512, the batch size from ``config['batch']``, the
    AST-aware collate function and a fresh ``torch.Generator`` on *device*.

    Returns:
        ``(train_dataloader, valid_dataloader)``.
    """

    def build_loader(split: str, shuffle: bool) -> DataLoader:
        # Same construction for both splits; only the split and shuffling differ.
        return DataLoader(
            TreeEnhancedDataset(dataset[split], tokenizer, max_length=512),
            batch_size=config['batch'],
            shuffle=shuffle,
            collate_fn=ast_collate_fn,
            generator=torch.Generator(device=device)
        )

    train_loader = build_loader('train', True)
    valid_loader = build_loader('test', False)
    return train_loader, valid_loader
|
||||
|
||||
def train_and_evaluate(
    model: PreTrainedModel,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: AdamW,
    scheduler: Any,
    config: Dict[str, Any],
    output_dir: Path
) -> None:
    """Train the model, logging the loss and the three embedding weights.

    For each batch the model is called with the batch as keyword arguments.
    When a loss is returned, one optimisation step is taken (backward,
    gradient clipping to ``config['max_grad_norm']``, optimizer and scheduler
    step). Metrics go to the progress bar and to wandb. Every
    ``config['eval_every']`` batches (never on batch 0) the model is
    evaluated and the state dict is written to ``output_dir / 'best_model.pt'``
    whenever the validation loss improves.
    """
    total_steps: int = config['epochs'] * len(train_dataloader)
    best_valid_loss: float = float('inf')

    with tqdm(total=total_steps, desc='Training') as progress:
        for epoch_idx in range(config['epochs']):
            model.train()

            for train_idx, train_batch in enumerate(train_dataloader):
                outputs = model(**train_batch)
                train_loss = outputs["loss"]

                # No loss for this batch: only advance the progress bar.
                if train_loss is None:
                    progress.update(1)
                    progress.set_postfix({'loss': 'N/A'})
                    continue

                # One optimisation step.
                train_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                loss_value = train_loss.item()
                mix = outputs["embedding_weights"]

                # Show the loss and all three mixing weights.
                progress.update(1)
                progress.set_postfix({
                    'loss': f"{loss_value:.3f}",
                    'α': f"{mix['token']:.2f}",
                    'β': f"{mix['tree']:.2f}",
                    'γ': f"{mix['sequential']:.2f}"
                })

                # Global step across epochs.
                step = train_idx + len(train_dataloader) * epoch_idx
                wandb.log({
                    'loss': loss_value,
                    'token_weight': mix['token'],
                    'tree_weight': mix['tree'],
                    'sequential_weight': mix['sequential'],
                    'step': step,
                })

                # Periodic evaluation (skipped on the very first batch).
                evaluation_due = train_idx > 0 and train_idx % config['eval_every'] == 0
                if not evaluation_due:
                    continue

                valid_loss, valid_acc = evaluate(model, valid_dataloader)
                wandb.log({
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'step': step,
                })

                # Keep only the checkpoint with the lowest validation loss.
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), output_dir / 'best_model.pt')

    logger.info(f'Best validation loss: {best_valid_loss}')
|
||||
|
||||
def evaluate(model: "PreTrainedModel", dataloader: DataLoader) -> Tuple[float, float]:
    """Compute the mean loss and mean token accuracy over a dataloader.

    The model is switched to eval mode for the pass and back to train mode
    before returning. Batches whose output has no loss are skipped entirely.
    Accuracy is measured only on positions whose label is not ``-100`` (the
    default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``), so ignored
    positions cannot distort the metric; a batch without any scored position
    contributes to the loss average but not to the accuracy average.

    Args:
        model: Callable accepting the batch as keyword arguments and
            returning a mapping with ``"loss"`` and ``"logits"``.
        dataloader: Iterable of batches; each batch must contain ``'labels'``.

    Returns:
        ``(avg_loss, avg_acc)`` as per-batch means; both are ``0.0`` when no
        batch contributed.
    """
    model.eval()
    total_loss: float = 0.0
    total_acc: float = 0.0
    loss_batches: int = 0
    acc_batches: int = 0

    with torch.no_grad():
        for batch in dataloader:
            outputs = model(**batch)
            loss = outputs["loss"]
            if loss is None:
                continue

            total_loss += loss.item()
            loss_batches += 1

            logits = outputs["logits"]
            labels = batch['labels'].to(logits.device)
            scored = labels != -100  # positions that count towards the loss
            if scored.any():
                hits = (logits.argmax(dim=-1) == labels)[scored]
                total_acc += hits.float().mean().item()
                acc_batches += 1

    avg_loss: float = total_loss / max(loss_batches, 1)
    avg_acc: float = total_acc / max(acc_batches, 1)

    model.train()
    return avg_loss, avg_acc
|
||||
|
||||
def main():
    """Entry point: set up the run, build data and model, then train.

    Directory, config, wandb, seeding, device and dataset handling reuse the
    helpers imported from ``train_codebert_mlm``; the dataloaders, model and
    training loop are the tree-enhanced versions defined in this module.
    """
    # Run bookkeeping.
    script_dir = Path(__file__).parent
    output_dir = setup_directories(script_dir)
    config = load_config(script_dir / 'config.json')

    setup_wandb(config, exec_file='src/train_tree_codebert_mlm.py')
    set_seed(config['seed'])
    device = setup_device()

    # Data.
    dataset_dir = script_dir.parent / 'data' / 'the-stack-python'
    download_dataset(dataset_dir)
    dataset = load_and_prepare_dataset(dataset_dir, config['seed'])

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    train_dataloader, valid_dataloader = create_tree_dataloaders(
        dataset, tokenizer, config, device
    )

    # Model: built from the CodeBERT configuration and moved to the device
    # without torch.compile.
    model = TreeCodeBERTForPreTraining(
        RobertaConfig.from_pretrained('microsoft/codebert-base')
    )
    model = model.to(device)

    # AdamW with a linear warmup/decay schedule over all training steps.
    optimizer = AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    num_training_steps = config['epochs'] * len(train_dataloader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=num_training_steps
    )

    # Training loop (the train_and_evaluate defined in this module).
    train_and_evaluate(
        model, train_dataloader, valid_dataloader,
        optimizer, scheduler, config, output_dir
    )


if __name__ == "__main__":
    main()
|
Loading…
Reference in New Issue
Block a user