diff --git a/src/veggies_recognition/best_model.pth b/src/veggies_recognition/best_model.pth new file mode 100644 index 00000000..6fe3a9a2 Binary files /dev/null and b/src/veggies_recognition/best_model.pth differ diff --git a/src/veggies_recognition/model_best_checkpoint.pth.tar b/src/veggies_recognition/model_best_checkpoint.pth.tar new file mode 100644 index 00000000..11236dcd Binary files /dev/null and b/src/veggies_recognition/model_best_checkpoint.pth.tar differ diff --git a/src/veggies_recognition/train.py b/src/veggies_recognition/train.py index da3b1179..849d4b88 100644 --- a/src/veggies_recognition/train.py +++ b/src/veggies_recognition/train.py @@ -37,6 +37,8 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device) def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs): + best_acc = 0 + for epoch in range(num_epochs): print('Epoch number %d' % (epoch+1)) model.train() @@ -64,10 +66,24 @@ def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs) print(' ---Training--- Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f' % (running_correct, total, epoch_acc, epoch_loss)) - evaluate_model_on_test_set(model, test_loader) + test_dataset_acc = evaluate_model_on_test_set(model, test_loader) + + if(test_dataset_acc>best_acc): + best_acc = test_dataset_acc + save_checkpoint(model, epoch, optimizer, best_acc) + print('Finished Training') return model +def save_checkpoint(model, epoch, optimizer, best_acc): + state = { + 'epoch': epoch +1, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'best_acc': best_acc + } + torch.save(state, 'model_best_checkpoint.pth.tar') + def evaluate_model_on_test_set(model, test_loader): model.eval() predicted_correctly_on_epoch = 0 @@ -84,6 +100,7 @@ def evaluate_model_on_test_set(model, test_loader): predicted_correctly_on_epoch += (predicted==labels).sum().item() epoch_acc = predicted_correctly_on_epoch / total *100 print(' ---Testing--- Got %d out of %d images correctly (%.3f%%).' % (predicted_correctly_on_epoch, total, epoch_acc)) + return epoch_acc resnet18_model = models.resnet18(pretrained=False) @@ -95,4 +112,14 @@ loss_fn = nn.CrossEntropyLoss() optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003) -train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 20) \ No newline at end of file +train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 30) + +checkpoint = torch.load('model_best_checkpoint.pth.tar') + +resnet18_model = models.resnet18() +num_features = resnet18_model.fc.in_features +number_of_classes = 25 +resnet18_model.fc = nn.Linear(num_features, number_of_classes) +resnet18_model.load_state_dict(checkpoint['model']) + +torch.save(resnet18_model, 'best_model.pth') \ No newline at end of file