fixed evaluate function
This commit is contained in:
parent
240c16b495
commit
3f116b583d
@ -208,6 +208,7 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo
|
||||
|
||||
avg_loss: float = total_loss / len(dataloader)
|
||||
avg_acc: float = total_acc / len(dataloader.dataset)
|
||||
model.train()
|
||||
return avg_loss, avg_acc
|
||||
|
||||
def main() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user