model trained and saved

This commit is contained in:
Hubert Westerlich 2024-05-26 02:49:54 +02:00
parent 50c585eb16
commit d2ad851cab
3 changed files with 29 additions and 2 deletions

Binary file not shown.

Binary file not shown.

View File

@ -37,6 +37,8 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs):
best_acc = 0
for epoch in range(num_epochs):
print('Epoch number %d' % (epoch+1))
model.train()
@ -64,10 +66,24 @@ def train_nn(model, train_loader, test_loader, criterion, optimizer, num_epochs)
print(' ---Training--- Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f' % (running_correct, total, epoch_acc, epoch_loss))
evaluate_model_on_test_set(model, test_loader)
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
if(test_dataset_acc>best_acc):
best_acc = test_dataset_acc
save_checkpoint(model, epoch, optimizer, best_acc)
print('Finished Training')
return model
def save_checkpoint(model, epoch, optimizer, best_acc):
state = {
'epoch': epoch +1,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'best_acc': best_acc
}
torch.save(state, 'model_best_checkpoint.pth.tar')
def evaluate_model_on_test_set(model, test_loader):
model.eval()
predicted_correctly_on_epoch = 0
@ -84,6 +100,7 @@ def evaluate_model_on_test_set(model, test_loader):
predicted_correctly_on_epoch += (predicted==labels).sum().item()
epoch_acc = predicted_correctly_on_epoch / total *100
print(' ---Testing--- Got %d out of %d images correctly (%.3f%%).' % (predicted_correctly_on_epoch, total, epoch_acc))
return epoch_acc
resnet18_model = models.resnet18(pretrained=False)
@ -95,4 +112,14 @@ loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 20)
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 30)
checkpoint = torch.load('model_best_checkpoint.pth.tar')
resnet18_model = models.resnet18()
num_features = resnet18_model.fc.in_features
number_of_classes = 25
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
resnet18_model.load_state_dict(checkpoint['model'])
torch.save(resnet18_model, 'best_model.pth')