model trained and saved
This commit is contained in:
parent
50c585eb16
commit
d2ad851cab
BIN
src/veggies_recognition/best_model.pth
Normal file
BIN
src/veggies_recognition/best_model.pth
Normal file
Binary file not shown.
BIN
src/veggies_recognition/model_best_checkpoint.pth.tar
Normal file
BIN
src/veggies_recognition/model_best_checkpoint.pth.tar
Normal file
Binary file not shown.
@ -37,6 +37,8 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(device)
|
||||
|
||||
def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs):
|
||||
best_acc = 0
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print('Epoch number %d' % (epoch+1))
|
||||
model.train()
|
||||
@ -64,10 +66,24 @@ def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs)
|
||||
|
||||
print(' ---Training--- Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f' % (running_correct, total, epoch_acc, epoch_loss))
|
||||
|
||||
evaluate_model_on_test_set(model, test_loader)
|
||||
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
|
||||
|
||||
if(test_dataset_acc>best_acc):
|
||||
best_acc = test_dataset_acc
|
||||
save_checkpoint(model, epoch, optimizer, best_acc)
|
||||
|
||||
print('Finished Training')
|
||||
return model
|
||||
|
||||
def save_checkpoint(model, epoch, optimizer, best_acc):
|
||||
state = {
|
||||
'epoch': epoch +1,
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'best_acc': best_acc
|
||||
}
|
||||
torch.save(state, 'model_best_checkpoint.pth.tar')
|
||||
|
||||
def evaluate_model_on_test_set(model, test_loader):
|
||||
model.eval()
|
||||
predicted_correctly_on_epoch = 0
|
||||
@ -84,6 +100,7 @@ def evaluate_model_on_test_set(model, test_loader):
|
||||
predicted_correctly_on_epoch += (predicted==labels).sum().item()
|
||||
epoch_acc = predicted_correctly_on_epoch / total *100
|
||||
print(' ---Testing--- Got %d out of %d images correctly (%.3f%%).' % (predicted_correctly_on_epoch, total, epoch_acc))
|
||||
return epoch_acc
|
||||
|
||||
|
||||
resnet18_model = models.resnet18(pretrained=False)
|
||||
@ -95,4 +112,14 @@ loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
||||
|
||||
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 20)
|
||||
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 30)
|
||||
|
||||
checkpoint = torch.load('model_best_checkpoint.pth.tar')
|
||||
|
||||
resnet18_model = models.resnet18()
|
||||
num_features = resnet18_model.fc.in_features
|
||||
number_of_classes = 25
|
||||
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||
resnet18_model.load_state_dict(checkpoint['model'])
|
||||
|
||||
torch.save(resnet18_model, 'best_model.pth')
|
Loading…
Reference in New Issue
Block a user