fixed evaluate function
This commit is contained in:
parent
240c16b495
commit
3f116b583d
@ -208,6 +208,7 @@ def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, flo
|
|||||||
|
|
||||||
avg_loss: float = total_loss / len(dataloader)
|
avg_loss: float = total_loss / len(dataloader)
|
||||||
avg_acc: float = total_acc / len(dataloader.dataset)
|
avg_acc: float = total_acc / len(dataloader.dataset)
|
||||||
|
model.train()
|
||||||
return avg_loss, avg_acc
|
return avg_loss, avg_acc
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user