added: trained_model option
This commit is contained in:
parent
ffe1ff9565
commit
50292376e7
@ -49,6 +49,9 @@ def main():
|
||||
|
||||
# Train the model
|
||||
def train_model(model, criterion, optimizer, num_epochs=2):
|
||||
best_model_wts = None # Initialize the variable
|
||||
best_acc = 0.0
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||
print("-" * 10)
|
||||
@ -85,7 +88,7 @@ def main():
|
||||
|
||||
print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
|
||||
|
||||
if phase == "val" and epoch_acc > best_acc:
|
||||
if phase == "validation" and epoch_acc > best_acc:
|
||||
best_acc = epoch_acc
|
||||
best_model_wts = model.state_dict()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user