From 27369bc561ee3c4b5bfe0cd21e8f966bda424b36 Mon Sep 17 00:00:00 2001 From: Patryk Bartkowiak Date: Wed, 18 Dec 2024 19:48:06 +0000 Subject: [PATCH] tree-starencoder working (linear layer instead of alpha beta gamma) --- code/pyproject.toml | 1 + code/src/eval_config.json | 9 ++ code/src/eval_model.py | 188 +++++++++++++++++++++++++++++++++++ code/src/tree_starencoder.py | 14 +-- 4 files changed, 206 insertions(+), 6 deletions(-) create mode 100644 code/src/eval_config.json create mode 100644 code/src/eval_model.py diff --git a/code/pyproject.toml b/code/pyproject.toml index 293bf8a..e7bbeef 100644 --- a/code/pyproject.toml +++ b/code/pyproject.toml @@ -40,3 +40,4 @@ distribution = true [tool.pdm.scripts] parse_dataset = {cmd = "src/parse_dataset.py"} train = {cmd = "src/training.py"} +eval = {cmd = "src/eval_model.py"} diff --git a/code/src/eval_config.json b/code/src/eval_config.json new file mode 100644 index 0000000..646e06b --- /dev/null +++ b/code/src/eval_config.json @@ -0,0 +1,9 @@ +{ + "extra_embeddings": false, + "data_dir": "./data/CodeSearchNet-parsed-starencoder/python", + "model_dir": "./outputs/original-starencoder", + "seed": 420, + "mlm_probability": 0.15, + "batch_size": 32, + "fp16": true +} \ No newline at end of file diff --git a/code/src/eval_model.py b/code/src/eval_model.py new file mode 100644 index 0000000..dbdc822 --- /dev/null +++ b/code/src/eval_model.py @@ -0,0 +1,188 @@ +import torch +import json +import logging +import numpy as np +from pathlib import Path +from datasets import load_from_disk +from safetensors.torch import load_file +from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, DataCollatorForLanguageModeling +from tqdm import tqdm +from typing import Dict + +from tree_starencoder import TreeStarEncoderForPreTraining + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +def load_config(config_path: Path) -> dict: + with open(config_path, 'r') as f: + return json.load(f) + +def compute_metrics(predictions: torch.Tensor, labels: torch.Tensor, mask_positions: torch.Tensor) -> Dict: + """Compute MLM metrics including accuracy and top-k accuracy.""" + # Get predictions only for masked tokens + masked_predictions = predictions[mask_positions] # Shape: [num_masked_tokens, vocab_size] + masked_labels = labels[mask_positions] # Shape: [num_masked_tokens] + + # Calculate top-k accuracy using raw logits (before softmax) + top1_acc = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean().item() + top5_acc = (masked_predictions.topk(k=5, dim=-1).indices == masked_labels.unsqueeze(-1)).any(dim=-1).float().mean().item() + + # Calculate per-token loss + loss_fct = torch.nn.CrossEntropyLoss(reduction='none') + # Don't apply softmax here - CrossEntropyLoss expects raw logits + token_losses = loss_fct(masked_predictions, masked_labels) + + return { + 'top1_accuracy': top1_acc * 100, # Convert to percentage + 'top5_accuracy': top5_acc * 100, # Convert to percentage + 'mean_token_loss': token_losses.mean().item(), + 'max_token_loss': token_losses.max().item(), + 'min_token_loss': token_losses.min().item(), + 'std_token_loss': token_losses.std().item(), + } + +def evaluate_model(model, dataset, tokenizer, device, batch_size=8, mlm_probability=0.15): + model.eval() + + # Set seed for reproducible masking + torch.manual_seed(42) + np.random.seed(42) + + data_collator = DataCollatorForLanguageModeling( + 
tokenizer=tokenizer, + mlm=True, + mlm_probability=mlm_probability, + ) + + all_metrics = [] + total_loss = 0 + total_samples = 0 + + # Create a DataLoader that applies masking on the fly + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + collate_fn=data_collator, + generator=torch.Generator().manual_seed(42) + ) + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluating")): + inputs = {k: v.to(device) for k, v in batch.items()} + + outputs = model(**inputs) + loss = outputs['loss'] + logits = outputs['logits'] + + masked_positions = inputs['input_ids'] == tokenizer.mask_token_id + masked_count = masked_positions.sum().item() + + # Add debugging for first batch + if batch_idx == 0: + logger.info(f"Number of masked tokens in first batch: {masked_count}") + logger.info(f"Logits shape: {logits.shape}") + logger.info(f"Sample logits min/max: {logits[0,0].min().item():.4f}/{logits[0,0].max().item():.4f}") + + if masked_count == 0: + continue + + if torch.isnan(logits).any() or torch.isinf(logits).any(): + logger.warning(f"Found NaN or Inf in logits in batch {batch_idx}!") + continue + + batch_metrics = compute_metrics(logits, inputs['labels'], masked_positions) + all_metrics.append(batch_metrics) + + total_loss += loss.item() * inputs['input_ids'].size(0) + total_samples += inputs['input_ids'].size(0) + + # Calculate average metrics + avg_metrics = { + k: np.mean([m[k] for m in all_metrics]) + for k in all_metrics[0].keys() + } + + # Calculate perplexity + avg_loss = total_loss / total_samples + perplexity = torch.exp(torch.tensor(avg_loss)).item() + + avg_metrics['perplexity'] = perplexity + avg_metrics['loss'] = avg_loss + + return avg_metrics + +def main(): + # Setup paths + current_dir = Path(__file__).parent + config = load_config(current_dir / 'eval_config.json') + model_dir = Path(config['model_dir']) / 'final-model' + data_dir = Path(config['data_dir']) + results_dir = Path(config['model_dir']) / 'evaluation_results' + results_dir.mkdir(exist_ok=True) + + # Load model and tokenizer + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Initialize model from config + model_config = AutoConfig.from_pretrained(model_dir / 'config.json') + model_config.max_position_embeddings = 1024 + + if config['extra_embeddings']: + model = TreeStarEncoderForPreTraining(config=model_config, log=False) + else: + model = AutoModelForMaskedLM.from_config(model_config) + + # Load weights from safetensors + state_dict = load_file(model_dir / 'model.safetensors') + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + # Re-tie the word embeddings and decoder: + model.tie_weights() + + model = model.to(device) + tokenizer = AutoTokenizer.from_pretrained(model_dir) + + # Load dataset + dataset = load_from_disk(data_dir / 'test') + + # Remove unnecessary columns + columns_to_keep = ['input_ids', 'attention_mask'] + if config['extra_embeddings']: + columns_to_keep.extend(['depths', 'sibling_idxs']) + dataset = dataset.remove_columns( + [col for col in dataset.column_names if col not in columns_to_keep] + ) + + # Evaluate + logger.info('Starting evaluation...') + metrics = evaluate_model( + model=model, + dataset=dataset, + tokenizer=tokenizer, + device=device, + batch_size=config['batch_size'], + mlm_probability=config['mlm_probability'], + ) + + # Log results + logger.info('Evaluation Results:') + for 
metric_name, value in metrics.items(): + logger.info(f'{metric_name}: {value:.4f}') + + # Save results to JSON + results_file = results_dir / 'metrics.json' + with open(results_file, 'w') as f: + json.dump(metrics, f, indent=4) + logger.info(f'Results saved to {results_file}') + + logger.info('Evaluation completed!') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/code/src/tree_starencoder.py b/code/src/tree_starencoder.py index b20a2bf..7aaaeda 100644 --- a/code/src/tree_starencoder.py +++ b/code/src/tree_starencoder.py @@ -9,9 +9,10 @@ from transformers import AutoConfig, BertForMaskedLM from tree_codebert import TreePositionalEmbedding class TreeStarEncoderForPreTraining(BertForMaskedLM): - def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512): + def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512, log: bool = True): super().__init__(config) self.config = config + self.log = log self.fusion_layer = nn.Sequential( nn.Linear(config.hidden_size * 3, config.hidden_size), @@ -93,11 +94,12 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM): labels.view(-1) ) - wandb.log({ - "embeddings/token": self.alpha.item(), - "embeddings/tree": self.beta.item(), - "embeddings/seq": self.gamma.item() - }) + if self.log: + wandb.log({ + "embeddings/token": self.alpha.item(), + "embeddings/tree": self.beta.item(), + "embeddings/seq": self.gamma.item() + }) return { "loss": masked_lm_loss,
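
Note (not part of the patch): the new evaluation entry point is presumably invoked through the pdm script added above (pdm run eval, mirroring the existing parse_dataset and train scripts), with eval_config.json supplying the data, model, and masking settings. Below is a minimal self-contained sketch of the masked-token metric computation that compute_metrics() in eval_model.py performs; the toy batch size, sequence length, and 100-token vocabulary are illustrative assumptions, not values taken from the patch.

import torch

# Toy MLM evaluation mirroring compute_metrics() in eval_model.py.
# Assumed sizes: batch of 2 sequences, length 4, vocabulary of 100 tokens.
vocab_size = 100
logits = torch.randn(2, 4, vocab_size)            # raw model outputs (no softmax)
labels = torch.randint(0, vocab_size, (2, 4))     # ground-truth token ids
mask_positions = torch.zeros(2, 4, dtype=torch.bool)
mask_positions[0, 1] = True                       # pretend these two positions
mask_positions[1, 3] = True                       # were replaced by [MASK]

masked_logits = logits[mask_positions]            # shape: [num_masked, vocab_size]
masked_labels = labels[mask_positions]            # shape: [num_masked]

top1 = (masked_logits.argmax(dim=-1) == masked_labels).float().mean().item()
top5 = (masked_logits.topk(k=5, dim=-1).indices
        == masked_labels.unsqueeze(-1)).any(dim=-1).float().mean().item()

# CrossEntropyLoss expects raw logits, as noted in compute_metrics().
token_losses = torch.nn.CrossEntropyLoss(reduction='none')(masked_logits, masked_labels)
print(f"top-1: {top1:.2%}  top-5: {top5:.2%}  mean token loss: {token_losses.mean().item():.4f}")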