import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader


def set_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    return torch.device(device)


train_dataset_path = './data/train'
test_dataset_path = './data/val'
number_of_classes = 7
SIZE = 224

# Per-channel (RGB) normalisation statistics. These are not the ImageNet values;
# presumably they were computed from this project's training images -- TODO confirm.
mean = [0.5164, 0.5147, 0.4746]
std = [0.2180, 0.2126, 0.2172]

# Training pipeline: resize, light random augmentation (flip + small rotation), normalise.
train_transforms = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

# Evaluation pipeline: same resize/normalise, no random augmentation.
test_transforms = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms)


def _check_class_layout(train_ds, test_ds, n_classes):
    """Fail fast when the folder layouts cannot be used with an ``n_classes``-way head.

    ImageFolder derives label indices from each root's sub-folder names
    independently. A validation root with a missing or extra class folder would
    therefore map its labels to different indices than training, which silently
    corrupts the reported accuracy.

    Raises:
        ValueError: if the training root has more classes than the head has
            outputs, or if any validation class has a different (or no) index
            in the training mapping.
    """
    if len(train_ds.classes) > n_classes:
        raise ValueError(
            f"found {len(train_ds.classes)} class folders under the training root, "
            f"but the classifier head has only {n_classes} outputs"
        )
    mismatched = sorted(
        name for name, idx in test_ds.class_to_idx.items()
        if train_ds.class_to_idx.get(name) != idx
    )
    if mismatched:
        raise ValueError(
            f"validation classes do not line up with training label indices: {mismatched}"
        )


_check_class_layout(train_dataset, test_dataset, number_of_classes)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# ResNet-18 with no pretrained weights (weights=None); the final fully-connected
# layer is replaced so it emits one logit per class.
resnet18_model = models.resnet18(weights=None)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(num_ftrs, number_of_classes)
device = set_device()
resnet18_model = resnet18_model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.003)


def save_checkpoint(model, epoch, optimizer, best_acc, path='model_best_checkpoint.pth.tar'):
    """Write a training checkpoint to *path* with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model'``.
        epoch: zero-based epoch index; stored as ``epoch + 1`` under ``'epoch'``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: test accuracy (percent) stored under ``'best accuracy'``.
        path: destination file; defaults to the filename the script reloads.
    """
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)


def _model_device(model):
    """Return the device of *model*'s first parameter (``set_device()`` if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return set_device()


def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train *model* for *n_epochs*, checkpointing whenever test accuracy improves.

    Each epoch runs one pass over *train_loader* in train mode and prints the
    training accuracy and mean batch loss. It then evaluates on *test_loader*
    via ``evaluate_model_on_test_set`` and calls ``save_checkpoint`` when the
    test accuracy beats the best seen so far.

    Batches are moved to the device the model's parameters live on, so the
    model may sit on any device.

    Returns:
        The same *model* object, holding the weights from the LAST epoch (not
        necessarily the best one; the best weights are in the checkpoint file).

    Raises:
        ValueError: if *train_loader* yields no samples in an epoch.
    """
    device = _model_device(model)
    # Start below any possible accuracy so the first epoch always writes a
    # checkpoint, even at 0% test accuracy; with 0 and a strict '>' no file
    # would be written and the caller's torch.load would fail.
    best_acc = float('-inf')
    for epoch in range(n_epochs):
        print(f"Epoch number {epoch + 1} ")
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0
        n_batches = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)
            n_batches += 1

            optimizer.zero_grad()
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        epoch_loss = running_loss / n_batches
        epoch_acc = 100 * running_correct / total
        print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")

        test_data_acc = evaluate_model_on_test_set(model, test_loader)
        if test_data_acc > best_acc:
            best_acc = test_data_acc
            save_checkpoint(model, epoch, optimizer, best_acc)
    print("Finished")
    return model


def evaluate_model_on_test_set(model, test_loader):
    """Return *model*'s accuracy (percent) over every batch of *test_loader*.

    Switches the model to eval mode and leaves it there; ``train_nn`` switches
    back to train mode at the start of each epoch. Gradients are disabled
    during the pass. The prediction is the arg-max logit per sample.

    Raises:
        ValueError: if *test_loader* yields no samples.
    """
    model.eval()
    device = _model_device(model)
    predicted_correctly_on_epoch = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    epoch_acc = 100 * predicted_correctly_on_epoch / total
    print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
    return epoch_acc


# NOTE(review): everything from the dataset construction down runs at import
# time; consider moving it into a main() behind `if __name__ == "__main__":`.
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, n_epochs=30)

# The in-memory model holds the last epoch's weights; restore the best-epoch
# weights from the checkpoint before exporting.
checkpoint = torch.load('model_best_checkpoint.pth.tar', map_location=device)
resnet18_model.load_state_dict(checkpoint['model'])
# Pickles the whole nn.Module (not just a state_dict). Loading it requires the
# model class to be importable and, on recent PyTorch versions where torch.load
# defaults to weights_only=True, an explicit weights_only=False.
torch.save(resnet18_model, 'best_model.pth')