diff --git a/code/src/train_codebert_mlm.py b/code/src/train_codebert_mlm.py index ae84076..41b1912 100644 --- a/code/src/train_codebert_mlm.py +++ b/code/src/train_codebert_mlm.py @@ -208,6 +208,7 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo avg_loss: float = total_loss / len(dataloader) avg_acc: float = total_acc / len(dataloader.dataset) + model.train() return avg_loss, avg_acc def main() -> None: