129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torchvision.datasets
|
|
from torchvision import datasets, transforms, models
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def set_device():
    """Return the preferred ``torch.device``: CUDA when it is available, else CPU."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)
|
|
|
|
|
|
# Root directories handed to torchvision's ImageFolder for training / validation.
train_dataset_path, test_dataset_path = './data/train', './data/val'

# Width of the classifier head's output layer.
# NOTE(review): must equal the number of classes ImageFolder discovers under
# the dataset roots — confirm whenever the data changes.
number_of_classes = 7

# Every image is resized to a SIZE x SIZE square before tensor conversion.
SIZE = 224

# Per-channel normalization statistics (3 channels) fed to transforms.Normalize.
# NOTE(review): presumably measured on the training images — verify.
mean = [0.5164, 0.5147, 0.4746]
std = [0.2180, 0.2126, 0.2172]
|
|
|
|
def _to_normalized_tensor_steps():
    """Return fresh ToTensor + per-channel Normalize steps that end every pipeline."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ]


# Training pipeline: resize to SIZE x SIZE, apply light random augmentation
# (horizontal flip, rotation within +/-10 degrees), then tensor + normalize.
train_transforms = transforms.Compose(
    [
        transforms.Resize((SIZE, SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        *_to_normalized_tensor_steps(),
    ]
)

# Evaluation pipeline: the same deterministic resize/tensor/normalize steps,
# with no augmentation.
test_transforms = transforms.Compose(
    [transforms.Resize((SIZE, SIZE)), *_to_normalized_tensor_steps()]
)
|
|
|
|
# Image datasets read from disk; each one applies its own preprocessing pipeline.
train_dataset = torchvision.datasets.ImageFolder(
    root=train_dataset_path,
    transform=train_transforms,
)
test_dataset = torchvision.datasets.ImageFolder(
    root=test_dataset_path,
    transform=test_transforms,
)

# Mini-batch loaders: training order is shuffled, evaluation order is fixed.
_BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=_BATCH_SIZE, shuffle=False)
|
|
|
|
# Resolve the compute device up front so the model is moved there once it is built.
device = set_device()

# ResNet-18 with randomly initialized weights (weights=None, no pretrained
# checkpoint). Its final fully connected layer is replaced by one producing
# `number_of_classes` outputs.
resnet18_model = models.resnet18(weights=None)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(num_ftrs, number_of_classes)
resnet18_model = resnet18_model.to(device)

# Objective and optimizer: cross-entropy loss, SGD with momentum and weight decay.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet18_model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0.003,
)
|
|
|
|
|
|
def save_checkpoint(model, epoch, optimizer, best_acc, path='model_best_checkpoint.pth.tar'):
    """Serialize a training checkpoint to ``path`` with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        epoch: Zero-based index of the epoch just finished; stored as
            ``epoch + 1`` under ``'epoch'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: Best test accuracy so far, stored under ``'best accuracy'``.
        path: Destination file. Defaults to the original hard-coded name, so
            existing callers keep writing to the same place. An existing file
            at ``path`` is overwritten.
    """
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)
|
|
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train ``model`` for ``n_epochs`` and checkpoint the best test accuracy.

    Each epoch makes one pass over ``train_loader`` (forward, loss, backward,
    optimizer step per batch) and prints the mean batch loss and training
    accuracy. It then scores the model with ``evaluate_model_on_test_set``.
    Whenever that score beats the best seen so far, ``save_checkpoint``
    writes the model and optimizer state.

    Args:
        model: Module to train in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        n_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, after training.
    """
    # Send batches to wherever the model's weights live, so inputs and
    # parameters always share a device. A model with no parameters runs on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint. With 0 and a strict '>', a run whose test accuracy never rose
    # above 0 saved nothing, and reloading the checkpoint afterwards failed.
    best_acc = float('-inf')

    for epoch in range(n_epochs):
        print("Epoch number %d " % (epoch + 1))
        model.train()
        running_loss = 0.0
        running_correct = 0  # int, so the log reads "Got 123" rather than "Got 123.0"
        total = 0
        num_batches = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Predicted class = index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)
            running_loss += loss.item()
            running_correct += (predicted == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

        # Guard both averages so an empty loader cannot raise ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100 * running_correct / total if total else 0.0
        print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")

        test_data_acc = evaluate_model_on_test_set(model, test_loader)

        if test_data_acc > best_acc:
            best_acc = test_data_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")
    return model
|
|
def evaluate_model_on_test_set(model, test_loader):
    """Return the classification accuracy (percent) of ``model`` on ``test_loader``.

    Switches the model to eval mode and runs every batch without gradient
    tracking. Prints a summary line and returns the accuracy.

    Args:
        model: Classifier whose output holds class scores along dim 1; the
            prediction is the argmax over that dim.
        test_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``, or ``0.0`` when the loader
        yields no samples. Previously an empty loader raised ZeroDivisionError.
    """
    model.eval()
    predicted_correctly_on_epoch = 0
    total = 0

    # Move inputs to the device the model's weights are on, so this works
    # wherever the caller placed the model. A model with no parameters runs on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            predicted_correctly_on_epoch += (predicted == labels).sum().item()
            total += labels.size(0)

    epoch_acc = 100 * predicted_correctly_on_epoch / total if total else 0.0
    print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
    return epoch_acc
|
|
|
|
|
|
# Train, then restore the weights from the epoch with the best test accuracy.
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, n_epochs=30)

# map_location loads the saved tensors straight onto the training device
# instead of relying on the device recorded in the checkpoint.
checkpoint = torch.load('model_best_checkpoint.pth.tar', map_location=device)
resnet18_model.load_state_dict(checkpoint['model'])

# NOTE(review): this pickles the whole module object. Reloading it requires the
# model classes to be importable and, on recent PyTorch releases,
# torch.load(..., weights_only=False). Saving resnet18_model.state_dict() would
# be more portable; it is kept as-is so the output file format does not change.
torch.save(resnet18_model, 'best_model.pth')