# Male_zoo_Projekt_SI/model/model.py
#
# 129 lines
# 4.1 KiB
# Python
# (repository viewer links: Raw / Normal View / History)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
def set_device():
    """Pick the device used for the model and its batches.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` reports a
        usable GPU, otherwise ``cpu``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device)
# Root directories handed to ImageFolder further down in this file.
train_dataset_path = './data/train'
test_dataset_path = './data/val'

# Width of the replacement classifier head built below.
number_of_classes = 7

# Every image is resized to SIZE x SIZE before being turned into a tensor.
SIZE = 224

# Three values each, passed to transforms.Normalize.
# NOTE(review): presumably per-channel statistics of the training images - confirm.
mean = [0.5164, 0.5147, 0.4746]
std = [0.2180, 0.2126, 0.2172]

# Tail shared by both pipelines: tensor conversion followed by normalisation.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
]

# Training pipeline: resize, random flip / small rotation, then the shared tail.
train_transforms = transforms.Compose(
    [
        transforms.Resize((SIZE, SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        *_tensor_steps,
    ]
)

# Evaluation pipeline: same resize and shared tail, no random augmentation.
test_transforms = transforms.Compose(
    [
        transforms.Resize((SIZE, SIZE)),
        *_tensor_steps,
    ]
)
# Image-folder datasets for the two splits, each with its own transform pipeline.
train_dataset = datasets.ImageFolder(train_dataset_path, transform=train_transforms)
test_dataset = datasets.ImageFolder(test_dataset_path, transform=test_transforms)

# Batches of 64; only the training split is reshuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
# ResNet-18 created with weights=None, i.e. no pretrained weights are loaded.
resnet18_model = models.resnet18(weights=None)

# Swap the final fully connected layer for one with number_of_classes outputs.
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(in_features=num_ftrs, out_features=number_of_classes)

# Move the finished network onto the device chosen by set_device().
device = set_device()
resnet18_model = resnet18_model.to(device)

# Cross-entropy loss, optimised with SGD (momentum and weight decay enabled).
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet18_model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0.003,
)
def save_checkpoint(model, epoch, optimizer, best_acc, path='model_best_checkpoint.pth.tar'):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        epoch: Zero-based epoch index; stored as ``epoch + 1`` under ``'epoch'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: Accuracy value stored under ``'best accuracy'``.
        path: Destination file. Defaults to the file name that was previously
            hard-coded, so existing four-argument calls behave as before.
    """
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train ``model`` for ``n_epochs`` epochs, checkpointing the best epoch.

    After every epoch the model is scored with ``evaluate_model_on_test_set``;
    whenever that score beats the best one seen so far, ``save_checkpoint`` is
    called with the current model and optimizer state.

    Args:
        model: Module to optimise.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        test_loader: Iterable of ``(images, labels)`` batches handed to
            ``evaluate_model_on_test_set``.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the parameters of ``model``.
        n_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Send batches to wherever the model's parameters live, so a CPU model is
    # not fed CUDA tensors just because a GPU happens to be present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Start below any reachable accuracy: the first evaluated epoch always
    # writes a checkpoint, even when its accuracy is exactly 0.
    best_acc = float('-inf')

    for epoch in range(n_epochs):
        print("Epoch number %d " % (epoch + 1))
        # Evaluation at the end of the previous epoch switched to eval mode.
        model.train()
        running_loss = 0.0
        running_correct = 0  # integer count, printed without a trailing ".0"
        total = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # detach() replaces the deprecated .data access; the predictions
            # are only used for bookkeeping and need no gradient.
            _, predicted = torch.max(outputs.detach(), 1)
            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()

        # An empty loader reports zeros instead of raising ZeroDivisionError.
        num_batches = len(train_loader)
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100 * running_correct / total if total else 0.0
        print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")

        test_data_acc = evaluate_model_on_test_set(model, test_loader)
        if test_data_acc > best_acc:
            best_acc = test_data_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")
    return model
def evaluate_model_on_test_set(model, test_loader):
    """Compute the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and is left in that mode on return.
    Gradients are disabled for the whole pass.

    Args:
        model: Module producing one score per class for each input row.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        float: Percentage (0-100) of samples whose highest-scoring class
        equals the label; ``0.0`` when the loader yields no samples.
    """
    model.eval()

    # Send batches to wherever the model's parameters live, so a CPU model is
    # not fed CUDA tensors just because a GPU happens to be present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    predicted_correctly_on_epoch = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)
            outputs = model(images)
            # No .data needed: autograd is already off inside no_grad().
            _, predicted = torch.max(outputs, 1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()

    # An empty loader reports 0.0 instead of raising ZeroDivisionError.
    epoch_acc = 100 * predicted_correctly_on_epoch / total if total else 0.0
    print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
    return epoch_acc
# Run the full 30-epoch training; the best epoch is checkpointed along the way.
train_nn(
    model=resnet18_model,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=loss_fn,
    optimizer=optimizer,
    n_epochs=30,
)

# Restore the weights stored in the best checkpoint, then export the whole
# model object (not just its state dict) for later use.
checkpoint = torch.load('model_best_checkpoint.pth.tar')
best_weights = checkpoint['model']
resnet18_model.load_state_dict(best_weights)
torch.save(resnet18_model, 'best_model.pth')