From 3f116b583dbe74c9c21bdb8edbc3697ef1cd8413 Mon Sep 17 00:00:00 2001 From: Patryk Bartkowiak Date: Mon, 21 Oct 2024 20:12:20 +0000 Subject: [PATCH] fixed evaluate function --- code/src/train_codebert_mlm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/code/src/train_codebert_mlm.py b/code/src/train_codebert_mlm.py index ae84076..41b1912 100644 --- a/code/src/train_codebert_mlm.py +++ b/code/src/train_codebert_mlm.py @@ -208,6 +208,7 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo avg_loss: float = total_loss / len(dataloader) avg_acc: float = total_acc / len(dataloader.dataset) + model.train() return avg_loss, avg_acc def main() -> None: