fixed evaluate function

This commit is contained in:
Patryk Bartkowiak 2024-10-21 20:12:20 +00:00
parent 240c16b495
commit 3f116b583d

View File

@ -208,6 +208,7 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo
avg_loss: float = total_loss / len(dataloader)
avg_acc: float = total_acc / len(dataloader.dataset)
model.train()
return avg_loss, avg_acc
def main() -> None: