tree-starencoder working (linear layer instead of alpha beta gamma)
This commit is contained in:
parent
4006b1fdfd
commit
27369bc561
@ -40,3 +40,4 @@ distribution = true
|
||||
[tool.pdm.scripts]
|
||||
parse_dataset = {cmd = "src/parse_dataset.py"}
|
||||
train = {cmd = "src/training.py"}
|
||||
eval = {cmd = "src/eval_model.py"}
|
||||
|
9
code/src/eval_config.json
Normal file
9
code/src/eval_config.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"extra_embeddings": false,
|
||||
"data_dir": "./data/CodeSearchNet-parsed-starencoder/python",
|
||||
"model_dir": "./outputs/original-starencoder",
|
||||
"seed": 420,
|
||||
"mlm_probability": 0.15,
|
||||
"batch_size": 32,
|
||||
"fp16": true
|
||||
}
|
188
code/src/eval_model.py
Normal file
188
code/src/eval_model.py
Normal file
@ -0,0 +1,188 @@
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from datasets import load_from_disk
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, DataCollatorForLanguageModeling
|
||||
from tqdm import tqdm
|
||||
from typing import Dict
|
||||
|
||||
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_config(config_path: Path) -> dict:
|
||||
with open(config_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def compute_metrics(predictions: torch.Tensor, labels: torch.Tensor, mask_positions: torch.Tensor) -> Dict:
|
||||
"""Compute MLM metrics including accuracy and top-k accuracy."""
|
||||
# Get predictions only for masked tokens
|
||||
masked_predictions = predictions[mask_positions] # Shape: [num_masked_tokens, vocab_size]
|
||||
masked_labels = labels[mask_positions] # Shape: [num_masked_tokens]
|
||||
|
||||
# Calculate top-k accuracy using raw logits (before softmax)
|
||||
top1_acc = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean().item()
|
||||
top5_acc = (masked_predictions.topk(k=5, dim=-1).indices == masked_labels.unsqueeze(-1)).any(dim=-1).float().mean().item()
|
||||
|
||||
# Calculate per-token loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
# Don't apply softmax here - CrossEntropyLoss expects raw logits
|
||||
token_losses = loss_fct(masked_predictions, masked_labels)
|
||||
|
||||
return {
|
||||
'top1_accuracy': top1_acc * 100, # Convert to percentage
|
||||
'top5_accuracy': top5_acc * 100, # Convert to percentage
|
||||
'mean_token_loss': token_losses.mean().item(),
|
||||
'max_token_loss': token_losses.max().item(),
|
||||
'min_token_loss': token_losses.min().item(),
|
||||
'std_token_loss': token_losses.std().item(),
|
||||
}
|
||||
|
||||
def evaluate_model(model, dataset, tokenizer, device, batch_size=8, mlm_probability=0.15):
|
||||
model.eval()
|
||||
|
||||
# Set seed for reproducible masking
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True,
|
||||
mlm_probability=mlm_probability,
|
||||
)
|
||||
|
||||
all_metrics = []
|
||||
total_loss = 0
|
||||
total_samples = 0
|
||||
|
||||
# Create a DataLoader that applies masking on the fly
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=data_collator,
|
||||
generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
|
||||
inputs = {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs['loss']
|
||||
logits = outputs['logits']
|
||||
|
||||
masked_positions = inputs['input_ids'] == tokenizer.mask_token_id
|
||||
masked_count = masked_positions.sum().item()
|
||||
|
||||
# Add debugging for first batch
|
||||
if batch_idx == 0:
|
||||
logger.info(f"Number of masked tokens in first batch: {masked_count}")
|
||||
logger.info(f"Logits shape: {logits.shape}")
|
||||
logger.info(f"Sample logits min/max: {logits[0,0].min().item():.4f}/{logits[0,0].max().item():.4f}")
|
||||
|
||||
if masked_count == 0:
|
||||
continue
|
||||
|
||||
if torch.isnan(logits).any() or torch.isinf(logits).any():
|
||||
logger.warning(f"Found NaN or Inf in logits in batch {batch_idx}!")
|
||||
continue
|
||||
|
||||
batch_metrics = compute_metrics(logits, inputs['labels'], masked_positions)
|
||||
all_metrics.append(batch_metrics)
|
||||
|
||||
total_loss += loss.item() * inputs['input_ids'].size(0)
|
||||
total_samples += inputs['input_ids'].size(0)
|
||||
|
||||
# Calculate average metrics
|
||||
avg_metrics = {
|
||||
k: np.mean([m[k] for m in all_metrics])
|
||||
for k in all_metrics[0].keys()
|
||||
}
|
||||
|
||||
# Calculate perplexity
|
||||
avg_loss = total_loss / total_samples
|
||||
perplexity = torch.exp(torch.tensor(avg_loss)).item()
|
||||
|
||||
avg_metrics['perplexity'] = perplexity
|
||||
avg_metrics['loss'] = avg_loss
|
||||
|
||||
return avg_metrics
|
||||
|
||||
def main():
|
||||
# Setup paths
|
||||
current_dir = Path(__file__).parent
|
||||
config = load_config(current_dir / 'eval_config.json')
|
||||
model_dir = Path(config['model_dir']) / 'final-model'
|
||||
data_dir = Path(config['data_dir'])
|
||||
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Load model and tokenizer
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Initialize model from config
|
||||
model_config = AutoConfig.from_pretrained(model_dir / 'config.json')
|
||||
model_config.max_position_embeddings = 1024
|
||||
|
||||
if config['extra_embeddings']:
|
||||
model = TreeStarEncoderForPreTraining(config=model_config, log=False)
|
||||
else:
|
||||
model = AutoModelForMaskedLM.from_config(model_config)
|
||||
|
||||
# Load weights from safetensors
|
||||
state_dict = load_file(model_dir / 'model.safetensors')
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
print("Missing keys:", missing_keys)
|
||||
print("Unexpected keys:", unexpected_keys)
|
||||
|
||||
# Re-tie the word embeddings and decoder:
|
||||
model.tie_weights()
|
||||
|
||||
model = model.to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
|
||||
# Load dataset
|
||||
dataset = load_from_disk(data_dir / 'test')
|
||||
|
||||
# Remove unnecessary columns
|
||||
columns_to_keep = ['input_ids', 'attention_mask']
|
||||
if config['extra_embeddings']:
|
||||
columns_to_keep.extend(['depths', 'sibling_idxs'])
|
||||
dataset = dataset.remove_columns(
|
||||
[col for col in dataset.column_names if col not in columns_to_keep]
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
logger.info('Starting evaluation...')
|
||||
metrics = evaluate_model(
|
||||
model=model,
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
batch_size=config['batch_size'],
|
||||
mlm_probability=config['mlm_probability'],
|
||||
)
|
||||
|
||||
# Log results
|
||||
logger.info('Evaluation Results:')
|
||||
for metric_name, value in metrics.items():
|
||||
logger.info(f'{metric_name}: {value:.4f}')
|
||||
|
||||
# Save results to JSON
|
||||
results_file = results_dir / 'metrics.json'
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(metrics, f, indent=4)
|
||||
logger.info(f'Results saved to {results_file}')
|
||||
|
||||
logger.info('Evaluation completed!')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -9,9 +9,10 @@ from transformers import AutoConfig, BertForMaskedLM
|
||||
from tree_codebert import TreePositionalEmbedding
|
||||
|
||||
class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
||||
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
|
||||
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512, log: bool = True):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.log = log
|
||||
|
||||
self.fusion_layer = nn.Sequential(
|
||||
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
||||
@ -93,11 +94,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
||||
labels.view(-1)
|
||||
)
|
||||
|
||||
wandb.log({
|
||||
"embeddings/token": self.alpha.item(),
|
||||
"embeddings/tree": self.beta.item(),
|
||||
"embeddings/seq": self.gamma.item()
|
||||
})
|
||||
if self.log:
|
||||
wandb.log({
|
||||
"embeddings/token": self.alpha.item(),
|
||||
"embeddings/tree": self.beta.item(),
|
||||
"embeddings/seq": self.gamma.item()
|
||||
})
|
||||
|
||||
return {
|
||||
"loss": masked_lm_loss,
|
||||
|
Loading…
Reference in New Issue
Block a user