"""Fine-tune a torchvision ResNet-18 on an image-folder dataset.

The script trains on ``./train_dataset``, evaluates on ``./test_dataset``
after every epoch, checkpoints the best-scoring weights, and finally
re-loads that checkpoint and exports the full model object.
"""

import os

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

# Single source of truth for the checkpoint location. The original code saved
# to one file name and loaded from a different one, so the final export step
# could never find the checkpoint written during training.
CHECKPOINT_PATH = 'model_best_checkpoint.zip'


def get_mean_and_std(loader):
    """Approximate the per-channel mean and standard deviation of a dataset.

    Each batch is flattened to ``(batch, channels, -1)``; the per-image channel
    means and standard deviations are summed and then divided by the number of
    images seen. The returned std is therefore the average of per-image
    standard deviations, not the standard deviation of the pooled pixels.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches where ``images``
            is a tensor whose first two dimensions are batch and channel.

    Returns:
        A ``(mean, std)`` tuple of per-channel tensors.
    """
    mean = 0.
    std = 0.
    total_images_count = 0
    for images, _ in loader:
        image_count_in_a_batch = images.size(0)
        # Collapse all spatial dimensions so statistics are taken per channel.
        images = images.view(image_count_in_a_batch, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += image_count_in_a_batch

    mean /= total_images_count
    std /= total_images_count
    return mean, std


def show_transformed_images(dataset):
    """Display a random batch of six images from ``dataset`` in a 3-wide grid.

    The images are shown exactly as the dataset yields them (i.e. after any
    transforms attached to the dataset), and their labels are printed.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
    images, labels = next(iter(loader))
    grid = torchvision.utils.make_grid(images, nrow=3)
    plt.figure(figsize=(11, 11))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    print('labels: ', labels)
    plt.show()


def set_device():
    """Return the first CUDA device when available, otherwise the CPU."""
    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(dev)


def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train ``model`` and checkpoint it whenever test accuracy improves.

    The model is expected to already live on the device returned by
    ``set_device()``; only the batches are moved here.

    Args:
        model: Network to optimise.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        test_loader: Loader used for the per-epoch evaluation.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        n_epochs: Number of passes over ``train_loader``.

    Returns:
        The trained model (the same object that was passed in).
    """
    device = set_device()
    # Start below any reachable accuracy so the first epoch always produces a
    # checkpoint, even if that epoch scores 0%. With an initial value of 0 and
    # a strict '>' comparison no checkpoint would ever be written in that case.
    best_acc = float('-inf')

    for epoch in range(n_epochs):
        print("Epoch number %d " % (epoch + 1))
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0

        for data in train_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)

            # Back propagation
            optimizer.zero_grad()
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()

        # Loss is averaged over batches, accuracy over individual images.
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100.00 * running_correct / total
        print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f"
              % (running_correct, total, epoch_acc, epoch_loss))

        test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
        if test_dataset_acc > best_acc:
            best_acc = test_dataset_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")
    return model


def evaluate_model_on_test_set(model, test_loader):
    """Return the accuracy (in percent) of ``model`` on ``test_loader``.

    Puts the model in eval mode and runs without gradient tracking. The
    result is also printed.
    """
    model.eval()
    predicted_correctly_on_epoch = 0
    total = 0
    device = set_device()

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()

    epoch_acc = 100.0 * predicted_correctly_on_epoch / total
    print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)"
          % (predicted_correctly_on_epoch, total, epoch_acc))
    return epoch_acc


def save_checkpoint(model, epoch, optimizer, best_acc):
    """Write model/optimizer state and bookkeeping to ``CHECKPOINT_PATH``.

    ``epoch`` is the zero-based loop index; it is stored one-based.
    """
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best_accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, CHECKPOINT_PATH)


def _build_resnet18(number_of_classes, pretrained=False):
    """Create a ResNet-18 whose final layer outputs ``number_of_classes`` logits."""
    # Only pass the flag when requesting pretrained weights so the plain
    # constructor call is used for the export model, as in the original script.
    model = models.resnet18(pretrained=True) if pretrained else models.resnet18()
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, number_of_classes)
    return model


def main():
    """Run the full pipeline: load data, train, checkpoint, export."""
    torch.manual_seed(42)

    classes = os.listdir('./train_dataset')
    print(classes)

    mean = [0.6908, 0.6612, 0.6218]
    std = [0.1947, 0.1926, 0.2086]

    training_dataset_path = './train_dataset'
    training_transforms = transforms.Compose([
        Resize((128, 128)),
        ToTensor(),
        Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ])
    train_dataset = torchvision.datasets.ImageFolder(
        root=training_dataset_path, transform=training_transforms)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=32, shuffle=True)

    testing_dataset_path = './test_dataset'
    testing_transforms = transforms.Compose([
        Resize((128, 128)),
        ToTensor(),
        Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ])
    test_dataset = torchvision.datasets.ImageFolder(
        root=testing_dataset_path, transform=testing_transforms)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=32, shuffle=False)

    # Mean and Standard Deviation approximations.
    # The loader already applies Normalize, so these are the statistics of the
    # normalized images rather than of the raw pixels.
    print(get_mean_and_std(train_loader))

    # Show images with applied transformations
    show_transformed_images(train_dataset)

    # Neural network training:
    # NOTE(review): the head size is hard-coded; confirm it matches the number
    # of class folders in ./train_dataset.
    number_of_classes = 4
    # Increase n_epochs if pretrained is False
    resnet18_model = _build_resnet18(number_of_classes, pretrained=True)
    device = set_device()
    resnet18_model = resnet18_model.to(device)
    loss_fn = nn.CrossEntropyLoss()  # criterion

    optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9,
                          weight_decay=0.003)

    train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 5)

    # Saving the model: rebuild the architecture, restore the best weights from
    # the checkpoint written during training, and export the whole model object.
    checkpoint = torch.load(CHECKPOINT_PATH)
    export_model = _build_resnet18(number_of_classes)
    export_model.load_state_dict(checkpoint['model'])

    torch.save(export_model, 'garbage_model.pth')


if __name__ == "__main__":
    main()