tree-starencoder working (linear layer instead of alpha beta gamma)
This commit is contained in:
parent
4006b1fdfd
commit
27369bc561
@ -40,3 +40,4 @@ distribution = true
|
|||||||
[tool.pdm.scripts]
|
[tool.pdm.scripts]
|
||||||
parse_dataset = {cmd = "src/parse_dataset.py"}
|
parse_dataset = {cmd = "src/parse_dataset.py"}
|
||||||
train = {cmd = "src/training.py"}
|
train = {cmd = "src/training.py"}
|
||||||
|
eval = {cmd = "src/eval_model.py"}
|
||||||
|
9
code/src/eval_config.json
Normal file
9
code/src/eval_config.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"extra_embeddings": false,
|
||||||
|
"data_dir": "./data/CodeSearchNet-parsed-starencoder/python",
|
||||||
|
"model_dir": "./outputs/original-starencoder",
|
||||||
|
"seed": 420,
|
||||||
|
"mlm_probability": 0.15,
|
||||||
|
"batch_size": 32,
|
||||||
|
"fp16": true
|
||||||
|
}
|
188
code/src/eval_model.py
Normal file
188
code/src/eval_model.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from datasets import load_from_disk
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, DataCollatorForLanguageModeling
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_config(config_path: Path) -> dict:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def compute_metrics(predictions: torch.Tensor, labels: torch.Tensor, mask_positions: torch.Tensor) -> Dict:
|
||||||
|
"""Compute MLM metrics including accuracy and top-k accuracy."""
|
||||||
|
# Get predictions only for masked tokens
|
||||||
|
masked_predictions = predictions[mask_positions] # Shape: [num_masked_tokens, vocab_size]
|
||||||
|
masked_labels = labels[mask_positions] # Shape: [num_masked_tokens]
|
||||||
|
|
||||||
|
# Calculate top-k accuracy using raw logits (before softmax)
|
||||||
|
top1_acc = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean().item()
|
||||||
|
top5_acc = (masked_predictions.topk(k=5, dim=-1).indices == masked_labels.unsqueeze(-1)).any(dim=-1).float().mean().item()
|
||||||
|
|
||||||
|
# Calculate per-token loss
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
|
# Don't apply softmax here - CrossEntropyLoss expects raw logits
|
||||||
|
token_losses = loss_fct(masked_predictions, masked_labels)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'top1_accuracy': top1_acc * 100, # Convert to percentage
|
||||||
|
'top5_accuracy': top5_acc * 100, # Convert to percentage
|
||||||
|
'mean_token_loss': token_losses.mean().item(),
|
||||||
|
'max_token_loss': token_losses.max().item(),
|
||||||
|
'min_token_loss': token_losses.min().item(),
|
||||||
|
'std_token_loss': token_losses.std().item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def evaluate_model(model, dataset, tokenizer, device, batch_size=8, mlm_probability=0.15):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Set seed for reproducible masking
|
||||||
|
torch.manual_seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
mlm=True,
|
||||||
|
mlm_probability=mlm_probability,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_metrics = []
|
||||||
|
total_loss = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
# Create a DataLoader that applies masking on the fly
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
collate_fn=data_collator,
|
||||||
|
generator=torch.Generator().manual_seed(42)
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
|
||||||
|
inputs = {k: v.to(device) for k, v in batch.items()}
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
loss = outputs['loss']
|
||||||
|
logits = outputs['logits']
|
||||||
|
|
||||||
|
masked_positions = inputs['input_ids'] == tokenizer.mask_token_id
|
||||||
|
masked_count = masked_positions.sum().item()
|
||||||
|
|
||||||
|
# Add debugging for first batch
|
||||||
|
if batch_idx == 0:
|
||||||
|
logger.info(f"Number of masked tokens in first batch: {masked_count}")
|
||||||
|
logger.info(f"Logits shape: {logits.shape}")
|
||||||
|
logger.info(f"Sample logits min/max: {logits[0,0].min().item():.4f}/{logits[0,0].max().item():.4f}")
|
||||||
|
|
||||||
|
if masked_count == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if torch.isnan(logits).any() or torch.isinf(logits).any():
|
||||||
|
logger.warning(f"Found NaN or Inf in logits in batch {batch_idx}!")
|
||||||
|
continue
|
||||||
|
|
||||||
|
batch_metrics = compute_metrics(logits, inputs['labels'], masked_positions)
|
||||||
|
all_metrics.append(batch_metrics)
|
||||||
|
|
||||||
|
total_loss += loss.item() * inputs['input_ids'].size(0)
|
||||||
|
total_samples += inputs['input_ids'].size(0)
|
||||||
|
|
||||||
|
# Calculate average metrics
|
||||||
|
avg_metrics = {
|
||||||
|
k: np.mean([m[k] for m in all_metrics])
|
||||||
|
for k in all_metrics[0].keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate perplexity
|
||||||
|
avg_loss = total_loss / total_samples
|
||||||
|
perplexity = torch.exp(torch.tensor(avg_loss)).item()
|
||||||
|
|
||||||
|
avg_metrics['perplexity'] = perplexity
|
||||||
|
avg_metrics['loss'] = avg_loss
|
||||||
|
|
||||||
|
return avg_metrics
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Setup paths
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
config = load_config(current_dir / 'eval_config.json')
|
||||||
|
model_dir = Path(config['model_dir']) / 'final-model'
|
||||||
|
data_dir = Path(config['data_dir'])
|
||||||
|
results_dir = Path(config['model_dir']) / 'evaluation_results'
|
||||||
|
results_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Load model and tokenizer
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Initialize model from config
|
||||||
|
model_config = AutoConfig.from_pretrained(model_dir / 'config.json')
|
||||||
|
model_config.max_position_embeddings = 1024
|
||||||
|
|
||||||
|
if config['extra_embeddings']:
|
||||||
|
model = TreeStarEncoderForPreTraining(config=model_config, log=False)
|
||||||
|
else:
|
||||||
|
model = AutoModelForMaskedLM.from_config(model_config)
|
||||||
|
|
||||||
|
# Load weights from safetensors
|
||||||
|
state_dict = load_file(model_dir / 'model.safetensors')
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||||
|
print("Missing keys:", missing_keys)
|
||||||
|
print("Unexpected keys:", unexpected_keys)
|
||||||
|
|
||||||
|
# Re-tie the word embeddings and decoder:
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
dataset = load_from_disk(data_dir / 'test')
|
||||||
|
|
||||||
|
# Remove unnecessary columns
|
||||||
|
columns_to_keep = ['input_ids', 'attention_mask']
|
||||||
|
if config['extra_embeddings']:
|
||||||
|
columns_to_keep.extend(['depths', 'sibling_idxs'])
|
||||||
|
dataset = dataset.remove_columns(
|
||||||
|
[col for col in dataset.column_names if col not in columns_to_keep]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
logger.info('Starting evaluation...')
|
||||||
|
metrics = evaluate_model(
|
||||||
|
model=model,
|
||||||
|
dataset=dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
device=device,
|
||||||
|
batch_size=config['batch_size'],
|
||||||
|
mlm_probability=config['mlm_probability'],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log results
|
||||||
|
logger.info('Evaluation Results:')
|
||||||
|
for metric_name, value in metrics.items():
|
||||||
|
logger.info(f'{metric_name}: {value:.4f}')
|
||||||
|
|
||||||
|
# Save results to JSON
|
||||||
|
results_file = results_dir / 'metrics.json'
|
||||||
|
with open(results_file, 'w') as f:
|
||||||
|
json.dump(metrics, f, indent=4)
|
||||||
|
logger.info(f'Results saved to {results_file}')
|
||||||
|
|
||||||
|
logger.info('Evaluation completed!')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -9,9 +9,10 @@ from transformers import AutoConfig, BertForMaskedLM
|
|||||||
from tree_codebert import TreePositionalEmbedding
|
from tree_codebert import TreePositionalEmbedding
|
||||||
|
|
||||||
class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
||||||
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
|
def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512, log: bool = True):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.log = log
|
||||||
|
|
||||||
self.fusion_layer = nn.Sequential(
|
self.fusion_layer = nn.Sequential(
|
||||||
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
nn.Linear(config.hidden_size * 3, config.hidden_size),
|
||||||
@ -93,11 +94,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
|
|||||||
labels.view(-1)
|
labels.view(-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
wandb.log({
|
if self.log:
|
||||||
"embeddings/token": self.alpha.item(),
|
wandb.log({
|
||||||
"embeddings/tree": self.beta.item(),
|
"embeddings/token": self.alpha.item(),
|
||||||
"embeddings/seq": self.gamma.item()
|
"embeddings/tree": self.beta.item(),
|
||||||
})
|
"embeddings/seq": self.gamma.item()
|
||||||
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"loss": masked_lm_loss,
|
"loss": masked_lm_loss,
|
||||||
|
Loading…
Reference in New Issue
Block a user