added: trained_model option

This commit is contained in:
Aliaksei Brown 2023-06-05 01:30:39 +02:00
parent ffe1ff9565
commit 50292376e7

View File

@ -49,6 +49,9 @@ def main():
# Train the model
def train_model(model, criterion, optimizer, num_epochs=2):
best_model_wts = None # Initialize the variable
best_acc = 0.0
for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
print("-" * 10)
@ -85,7 +88,7 @@ def main():
print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
if phase == "val" and epoch_acc > best_acc:
if phase == "validation" and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = model.state_dict()