tree-starencoder working (linear layer instead of alpha beta gamma)

Patryk Bartkowiak 2024-12-18 19:48:06 +00:00
parent 4006b1fdfd
commit 27369bc561
4 changed files with 206 additions and 6 deletions
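
What the commit title means, in brief: instead of mixing the token, tree, and sequential ("seq") embedding streams with learned scalar weights (alpha, beta, gamma), the three streams are concatenated and projected back down to hidden_size through the fusion_layer whose first nn.Linear(config.hidden_size * 3, config.hidden_size) is visible in the tree_starencoder.py diff below. A minimal sketch of the two approaches; only that linear projection is taken from the diff, everything else (names, the illustrative hidden size, the single-layer Sequential) is assumed:

import torch
import torch.nn as nn

hidden_size = 768  # illustrative value, not taken from the repo

# Old approach (sketch): one learned scalar weight per embedding stream
alpha = nn.Parameter(torch.tensor(1.0))
beta = nn.Parameter(torch.tensor(1.0))
gamma = nn.Parameter(torch.tensor(1.0))

def combine_weighted(token_emb, tree_emb, seq_emb):
    # weighted sum keeps the hidden size unchanged
    return alpha * token_emb + beta * tree_emb + gamma * seq_emb

# New approach (sketch): concatenate the three streams and let a linear
# projection learn how to mix them back down to hidden_size
fusion_layer = nn.Sequential(
    nn.Linear(hidden_size * 3, hidden_size),
)

def combine_fused(token_emb, tree_emb, seq_emb):
    # [batch, seq_len, 3 * hidden_size] -> [batch, seq_len, hidden_size]
    return fusion_layer(torch.cat([token_emb, tree_emb, seq_emb], dim=-1))

# shape check on dummy inputs
token_emb = tree_emb = seq_emb = torch.randn(2, 16, hidden_size)
assert combine_fused(token_emb, tree_emb, seq_emb).shape == (2, 16, hidden_size)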

code/pyproject.toml

@@ -40,3 +40,4 @@ distribution = true
 [tool.pdm.scripts]
 parse_dataset = {cmd = "src/parse_dataset.py"}
 train = {cmd = "src/training.py"}
+eval = {cmd = "src/eval_model.py"}

code/src/eval_config.json (new file, 9 lines)

@@ -0,0 +1,9 @@
{
    "extra_embeddings": false,
    "data_dir": "./data/CodeSearchNet-parsed-starencoder/python",
    "model_dir": "./outputs/original-starencoder",
    "seed": 420,
    "mlm_probability": 0.15,
    "batch_size": 32,
    "fp16": true
}
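
With "extra_embeddings" set to false, the evaluation script builds a plain AutoModelForMaskedLM from the saved config, i.e. the original StarEncoder baseline; setting it to true instead instantiates TreeStarEncoderForPreTraining and keeps the extra depths and sibling_idxs columns of the dataset.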

code/src/eval_model.py (new file, 188 lines)

@@ -0,0 +1,188 @@
import torch
import json
import logging
import numpy as np
from pathlib import Path
from datasets import load_from_disk
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, DataCollatorForLanguageModeling
from tqdm import tqdm
from typing import Dict

from tree_starencoder import TreeStarEncoderForPreTraining

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def load_config(config_path: Path) -> dict:
    with open(config_path, 'r') as f:
        return json.load(f)

def compute_metrics(predictions: torch.Tensor, labels: torch.Tensor, mask_positions: torch.Tensor) -> Dict:
    """Compute MLM metrics including accuracy and top-k accuracy."""
    # Get predictions only for masked tokens
    masked_predictions = predictions[mask_positions]  # Shape: [num_masked_tokens, vocab_size]
    masked_labels = labels[mask_positions]  # Shape: [num_masked_tokens]

    # Calculate top-k accuracy using raw logits (before softmax)
    top1_acc = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean().item()
    top5_acc = (masked_predictions.topk(k=5, dim=-1).indices == masked_labels.unsqueeze(-1)).any(dim=-1).float().mean().item()

    # Calculate per-token loss
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    # Don't apply softmax here - CrossEntropyLoss expects raw logits
    token_losses = loss_fct(masked_predictions, masked_labels)

    return {
        'top1_accuracy': top1_acc * 100,  # Convert to percentage
        'top5_accuracy': top5_acc * 100,  # Convert to percentage
        'mean_token_loss': token_losses.mean().item(),
        'max_token_loss': token_losses.max().item(),
        'min_token_loss': token_losses.min().item(),
        'std_token_loss': token_losses.std().item(),
    }

def evaluate_model(model, dataset, tokenizer, device, batch_size=8, mlm_probability=0.15):
    model.eval()

    # Set seed for reproducible masking
    torch.manual_seed(42)
    np.random.seed(42)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=mlm_probability,
    )

    all_metrics = []
    total_loss = 0
    total_samples = 0

    # Create a DataLoader that applies masking on the fly
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        generator=torch.Generator().manual_seed(42)
    )

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs['loss']
            logits = outputs['logits']

            masked_positions = inputs['input_ids'] == tokenizer.mask_token_id
            masked_count = masked_positions.sum().item()

            # Add debugging for first batch
            if batch_idx == 0:
                logger.info(f"Number of masked tokens in first batch: {masked_count}")
                logger.info(f"Logits shape: {logits.shape}")
                logger.info(f"Sample logits min/max: {logits[0,0].min().item():.4f}/{logits[0,0].max().item():.4f}")

            if masked_count == 0:
                continue

            if torch.isnan(logits).any() or torch.isinf(logits).any():
                logger.warning(f"Found NaN or Inf in logits in batch {batch_idx}!")
                continue

            batch_metrics = compute_metrics(logits, inputs['labels'], masked_positions)
            all_metrics.append(batch_metrics)
            total_loss += loss.item() * inputs['input_ids'].size(0)
            total_samples += inputs['input_ids'].size(0)

    # Calculate average metrics
    avg_metrics = {
        k: np.mean([m[k] for m in all_metrics])
        for k in all_metrics[0].keys()
    }

    # Calculate perplexity
    avg_loss = total_loss / total_samples
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_metrics['perplexity'] = perplexity
    avg_metrics['loss'] = avg_loss

    return avg_metrics

def main():
    # Setup paths
    current_dir = Path(__file__).parent
    config = load_config(current_dir / 'eval_config.json')

    model_dir = Path(config['model_dir']) / 'final-model'
    data_dir = Path(config['data_dir'])
    results_dir = Path(config['model_dir']) / 'evaluation_results'
    results_dir.mkdir(exist_ok=True)

    # Load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize model from config
    model_config = AutoConfig.from_pretrained(model_dir / 'config.json')
    model_config.max_position_embeddings = 1024
    if config['extra_embeddings']:
        model = TreeStarEncoderForPreTraining(config=model_config, log=False)
    else:
        model = AutoModelForMaskedLM.from_config(model_config)

    # Load weights from safetensors
    state_dict = load_file(model_dir / 'model.safetensors')
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    # Re-tie the word embeddings and decoder
    model.tie_weights()
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Load dataset
    dataset = load_from_disk(data_dir / 'test')

    # Remove unnecessary columns
    columns_to_keep = ['input_ids', 'attention_mask']
    if config['extra_embeddings']:
        columns_to_keep.extend(['depths', 'sibling_idxs'])
    dataset = dataset.remove_columns(
        [col for col in dataset.column_names if col not in columns_to_keep]
    )

    # Evaluate
    logger.info('Starting evaluation...')
    metrics = evaluate_model(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        device=device,
        batch_size=config['batch_size'],
        mlm_probability=config['mlm_probability'],
    )

    # Log results
    logger.info('Evaluation Results:')
    for metric_name, value in metrics.items():
        logger.info(f'{metric_name}: {value:.4f}')

    # Save results to JSON
    results_file = results_dir / 'metrics.json'
    with open(results_file, 'w') as f:
        json.dump(metrics, f, indent=4)

    logger.info(f'Results saved to {results_file}')
    logger.info('Evaluation completed!')


if __name__ == '__main__':
    main()
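
With the new `eval` entry under `[tool.pdm.scripts]`, this script can presumably be launched from the `code/` directory with `pdm run eval`; it expects `eval_config.json` next to it and a trained checkpoint under `<model_dir>/final-model/` containing `config.json` and `model.safetensors`.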

code/src/tree_starencoder.py

@@ -9,9 +9,10 @@ from transformers import AutoConfig, BertForMaskedLM
 from tree_codebert import TreePositionalEmbedding

 class TreeStarEncoderForPreTraining(BertForMaskedLM):
-    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
+    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512, log: bool = True):
         super().__init__(config)
         self.config = config
+        self.log = log

         self.fusion_layer = nn.Sequential(
             nn.Linear(config.hidden_size * 3, config.hidden_size),
@@ -93,11 +94,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
                 labels.view(-1)
             )

-        wandb.log({
-            "embeddings/token": self.alpha.item(),
-            "embeddings/tree": self.beta.item(),
-            "embeddings/seq": self.gamma.item()
-        })
+        if self.log:
+            wandb.log({
+                "embeddings/token": self.alpha.item(),
+                "embeddings/tree": self.beta.item(),
+                "embeddings/seq": self.gamma.item()
+            })

         return {
             "loss": masked_lm_loss,