model trained and saved
This commit is contained in:
parent
50c585eb16
commit
d2ad851cab
BIN
src/veggies_recognition/best_model.pth
Normal file
BIN
src/veggies_recognition/best_model.pth
Normal file
Binary file not shown.
BIN
src/veggies_recognition/model_best_checkpoint.pth.tar
Normal file
BIN
src/veggies_recognition/model_best_checkpoint.pth.tar
Normal file
Binary file not shown.
@ -37,6 +37,8 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|||||||
print(device)
|
print(device)
|
||||||
|
|
||||||
def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs):
|
def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs):
|
||||||
|
best_acc = 0
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
print('Epoch number %d' % (epoch+1))
|
print('Epoch number %d' % (epoch+1))
|
||||||
model.train()
|
model.train()
|
||||||
@ -64,10 +66,24 @@ def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs)
|
|||||||
|
|
||||||
print(' ---Training--- Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f' % (running_correct, total, epoch_acc, epoch_loss))
|
print(' ---Training--- Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f' % (running_correct, total, epoch_acc, epoch_loss))
|
||||||
|
|
||||||
evaluate_model_on_test_set(model, test_loader)
|
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
|
||||||
|
|
||||||
|
if(test_dataset_acc>best_acc):
|
||||||
|
best_acc = test_dataset_acc
|
||||||
|
save_checkpoint(model, epoch, optimizer, best_acc)
|
||||||
|
|
||||||
print('Finished Training')
|
print('Finished Training')
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def save_checkpoint(model, epoch, optimizer, best_acc):
|
||||||
|
state = {
|
||||||
|
'epoch': epoch +1,
|
||||||
|
'model': model.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'best_acc': best_acc
|
||||||
|
}
|
||||||
|
torch.save(state, 'model_best_checkpoint.pth.tar')
|
||||||
|
|
||||||
def evaluate_model_on_test_set(model, test_loader):
|
def evaluate_model_on_test_set(model, test_loader):
|
||||||
model.eval()
|
model.eval()
|
||||||
predicted_correctly_on_epoch = 0
|
predicted_correctly_on_epoch = 0
|
||||||
@ -84,6 +100,7 @@ def evaluate_model_on_test_set(model, test_loader):
|
|||||||
predicted_correctly_on_epoch += (predicted==labels).sum().item()
|
predicted_correctly_on_epoch += (predicted==labels).sum().item()
|
||||||
epoch_acc = predicted_correctly_on_epoch / total *100
|
epoch_acc = predicted_correctly_on_epoch / total *100
|
||||||
print(' ---Testing--- Got %d out of %d images correctly (%.3f%%).' % (predicted_correctly_on_epoch, total, epoch_acc))
|
print(' ---Testing--- Got %d out of %d images correctly (%.3f%%).' % (predicted_correctly_on_epoch, total, epoch_acc))
|
||||||
|
return epoch_acc
|
||||||
|
|
||||||
|
|
||||||
resnet18_model = models.resnet18(pretrained=False)
|
resnet18_model = models.resnet18(pretrained=False)
|
||||||
@ -95,4 +112,14 @@ loss_fn = nn.CrossEntropyLoss()
|
|||||||
|
|
||||||
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
||||||
|
|
||||||
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 20)
|
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 30)
|
||||||
|
|
||||||
|
checkpoint = torch.load('model_best_checkpoint.pth.tar')
|
||||||
|
|
||||||
|
resnet18_model = models.resnet18()
|
||||||
|
num_features = resnet18_model.fc.in_features
|
||||||
|
number_of_classes = 25
|
||||||
|
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||||
|
resnet18_model.load_state_dict(checkpoint['model'])
|
||||||
|
|
||||||
|
torch.save(resnet18_model, 'best_model.pth')
|
Loading…
Reference in New Issue
Block a user