From 50292376e7356e30ddbec3b1258d795fc4deb75b Mon Sep 17 00:00:00 2001 From: Aliaksei Brown Date: Mon, 5 Jun 2023 01:30:39 +0200 Subject: [PATCH] added: trained_model option --- neural_network/nn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/neural_network/nn.py b/neural_network/nn.py index a6a0ca4e..0423165c 100644 --- a/neural_network/nn.py +++ b/neural_network/nn.py @@ -49,6 +49,9 @@ def main(): # Train the model def train_model(model, criterion, optimizer, num_epochs=2): + best_model_wts = None # Initialize the variable + best_acc = 0.0 + for epoch in range(num_epochs): print(f"Epoch {epoch+1}/{num_epochs}") print("-" * 10) @@ -85,7 +88,7 @@ def main(): print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}") - if phase == "val" and epoch_acc > best_acc: + if phase == "validation" and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = model.state_dict()