diff --git a/neural_network/nn.py b/neural_network/nn.py index a6a0ca4e..0423165c 100644 --- a/neural_network/nn.py +++ b/neural_network/nn.py @@ -49,6 +49,9 @@ def main(): # Train the model def train_model(model, criterion, optimizer, num_epochs=2): + best_model_wts = None # Initialize the variable + best_acc = 0.0 + for epoch in range(num_epochs): print(f"Epoch {epoch+1}/{num_epochs}") print("-" * 10) @@ -85,7 +88,7 @@ def main(): print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}") - if phase == "val" and epoch_acc > best_acc: + if phase == "validation" and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = model.state_dict()