# Male_zoo_Projekt_SI/model/model.py
# (source-viewer metadata: 129 lines, 4.1 KiB, Python)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
def set_device():
    """Pick the torch device to run on.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.
    """
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Dataset locations, read below with torchvision's ImageFolder.
train_dataset_path = './data/train'
test_dataset_path = './data/val'
number_of_classes = 7
# Side length, in pixels, that every image is resized to.
SIZE = 224
# Per-channel statistics passed to transforms.Normalize
# (presumably measured on this dataset — re-check if the data changes).
mean = [0.5164, 0.5147, 0.4746]
std = [0.2180, 0.2126, 0.2172]

# Training pipeline: resize, random horizontal flip, random rotation within
# +/-10 degrees, convert to tensor, normalize.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(SIZE, SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.Tensor(mean), std=torch.Tensor(std)),
    ]
)
# Evaluation pipeline: same resize and normalization, no random augmentation.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.Tensor(mean), std=torch.Tensor(std)),
    ]
)

train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms)
# Training batches are shuffled; evaluation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# ResNet-18 without pretrained weights; its final fully-connected layer is
# replaced so it produces one output per class.
resnet18_model = models.resnet18(weights=None)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(in_features=num_ftrs, out_features=number_of_classes)
device = set_device()
resnet18_model = resnet18_model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=resnet18_model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0.003,
)
def save_checkpoint(model, epoch, optimizer, best_acc, path='model_best_checkpoint.pth.tar'):
    """Serialize a training checkpoint with ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'model'``, ``'best accuracy'``
    and ``'optimizer'``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model'``.
        epoch: index of the epoch that just finished, as produced by
            ``range(n_epochs)`` in ``train_nn``. It is stored as ``epoch + 1``,
            which is the number of completed epochs.
        optimizer: optimizer whose ``state_dict()`` is stored under
            ``'optimizer'``.
        best_acc: best test accuracy so far, stored under ``'best accuracy'``.
        path: destination file. It defaults to the filename this module
            reloads after training, so existing callers are unaffected.
    """
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)
def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over *train_loader* and return raw statistics.

    Returns:
        Tuple ``(running_loss, running_correct, total, num_batches)``:
        - the sum of the per-batch ``loss.item()`` values,
        - the number of correctly classified samples (an ``int``),
        - the number of samples seen,
        - the number of batches seen.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0  # int, so the report reads "57 out of 64", not "57.0"
    total = 0
    num_batches = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # argmax yields integer indices with no gradient, so reaching into
        # ``outputs.data`` is unnecessary.
        predicted = outputs.argmax(dim=1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_correct += (predicted == labels).sum().item()
        total += labels.size(0)
        num_batches += 1
    return running_loss, running_correct, total, num_batches


def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train *model* for *n_epochs*, checkpointing whenever test accuracy improves.

    After every epoch the model is scored with ``evaluate_model_on_test_set``.
    Whenever that accuracy beats the best seen so far, ``save_checkpoint`` is
    called.

    Args:
        model: the network to train, updated in place.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        The same *model* object. It holds the weights from the LAST epoch; the
        best-scoring weights exist only in the saved checkpoint.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    # Send batches to wherever the model's weights live, so inputs and weights
    # cannot end up on different devices. This assumes the model has at least
    # one parameter, which holds for anything being optimized.
    device = next(model.parameters()).device
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; this module loads the checkpoint file unconditionally after
    # training.
    best_acc = float('-inf')
    for epoch in range(n_epochs):
        print(f"Epoch number {epoch + 1} ")
        running_loss, running_correct, total, num_batches = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        if total == 0:
            raise ValueError("train_loader yielded no samples")
        epoch_loss = running_loss / num_batches  # mean of the per-batch losses
        epoch_acc = 100 * running_correct / total
        print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")
        test_data_acc = evaluate_model_on_test_set(model, test_loader)
        if test_data_acc > best_acc:
            best_acc = test_data_acc
            save_checkpoint(model, epoch, optimizer, best_acc)
    print("Finished")
    return model
def evaluate_model_on_test_set(model, test_loader):
    """Compute and print the classification accuracy of *model* on *test_loader*.

    The model is switched to eval mode and left there. Autograd is disabled
    for the whole pass.

    Args:
        model: classifier whose outputs hold one score per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a percentage (0-100).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    # Evaluate on whichever device the model's weights live on, so inputs and
    # weights always match. A parameterless model falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    predicted_correctly_on_epoch = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        # Fail loudly instead of with a bare ZeroDivisionError.
        raise ValueError("test_loader yielded no samples")
    epoch_acc = 100 * predicted_correctly_on_epoch / total
    print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
    return epoch_acc
# Script entry point. Train for 30 epochs; train_nn writes a checkpoint to
# 'model_best_checkpoint.pth.tar' each time test accuracy improves. Then
# restore the weights from that checkpoint and pickle the whole module object
# to 'best_model.pth'.
train_nn(
    resnet18_model,
    train_loader,
    test_loader,
    loss_fn,
    optimizer,
    n_epochs=30,
)
checkpoint = torch.load('model_best_checkpoint.pth.tar')
resnet18_model.load_state_dict(state_dict=checkpoint['model'])
torch.save(resnet18_model, 'best_model.pth')