import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim


def main():
    torch.manual_seed(42)

    # input_size = 49152
    # hidden_sizes = [64, 128]
    # output_size = 10

    classes = os.listdir('./train_dataset')
    print(classes)

    mean = [0.6908, 0.6612, 0.6218]
    std = [0.1947, 0.1926, 0.2086]

    training_dataset_path = './train_dataset'
    training_transforms = transforms.Compose([Resize((128, 128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
    train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

    testing_dataset_path = './test_dataset'
    testing_transforms = transforms.Compose([Resize((128, 128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
    test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
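
    # Note: ImageFolder assigns label indices from the *sorted* class-folder names
    # (available as train_dataset.classes), which may differ from the unsorted
    # os.listdir() ordering of `classes` printed above.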

    # Mean and Standard Deviation approximations
    def get_mean_and_std(loader):
        mean = 0.
        std = 0.
        total_images_count = 0
        for images, _ in loader:
            image_count_in_a_batch = images.size(0)
            # print(images.shape)
            images = images.view(image_count_in_a_batch, images.size(1), -1)
            # print(images.shape)
            mean += images.mean(2).sum(0)
            std += images.std(2).sum(0)
            total_images_count += image_count_in_a_batch
        mean /= total_images_count
        std /= total_images_count
        return mean, std

    print(get_mean_and_std(train_loader))
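
    # Because train_loader already normalizes with the hard-coded mean/std, the call
    # above should print values close to 0 and 1. The hard-coded values were presumably
    # computed beforehand from an un-normalized loader, e.g. (sketch):
    #   raw_dataset = torchvision.datasets.ImageFolder(
    #       root=training_dataset_path,
    #       transform=transforms.Compose([Resize((128, 128)), ToTensor()]))
    #   print(get_mean_and_std(torch.utils.data.DataLoader(raw_dataset, batch_size=32)))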

    # Show images with applied transformations
    def show_transformed_images(dataset):
        loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
        batch = next(iter(loader))
        images, labels = batch

        grid = torchvision.utils.make_grid(images, nrow=3)
        plt.figure(figsize=(11, 11))
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        print('labels: ', labels)
        plt.show()

    show_transformed_images(train_dataset)
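
    # The displayed grid is still normalized, so colors will look washed out. Inside
    # show_transformed_images one could optionally undo the normalization just for
    # display, e.g. (sketch):
    #   unnorm = grid * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
    #   plt.imshow(unnorm.permute(1, 2, 0).clamp(0, 1).numpy())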

    # Neural network training:
    def set_device():
        if torch.cuda.is_available():
            dev = "cuda:0"
        else:
            dev = "cpu"
        return torch.device(dev)

    def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
        device = set_device()
        best_acc = 0

        for epoch in range(n_epochs):
            print("Epoch number %d" % (epoch + 1))
            model.train()
            running_loss = 0.0
            running_correct = 0.0
            total = 0

            for data in train_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                total += labels.size(0)

                # Forward pass and back-propagation
                optimizer.zero_grad()
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                running_correct += (labels == predicted).sum().item()

            epoch_loss = running_loss / len(train_loader)
            epoch_acc = 100.00 * running_correct / total

            print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f"
                  % (running_correct, total, epoch_acc, epoch_loss))
            test_dataset_acc = evaluate_model_on_test_set(model, test_loader)

            if test_dataset_acc > best_acc:
                best_acc = test_dataset_acc
                save_checkpoint(model, epoch, optimizer, best_acc)

        print("Finished")
        return model

    def evaluate_model_on_test_set(model, test_loader):
        model.eval()
        predicted_correctly_on_epoch = 0
        total = 0
        device = set_device()

        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                total += labels.size(0)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                predicted_correctly_on_epoch += (predicted == labels).sum().item()

        epoch_acc = 100.0 * predicted_correctly_on_epoch / total
        print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)"
              % (predicted_correctly_on_epoch, total, epoch_acc))

        return epoch_acc

    # Saving the checkpoint:
    def save_checkpoint(model, epoch, optimizer, best_acc):
        state = {
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_accuracy': best_acc,
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, 'model_best_checkpoint.pth.zip')
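
    # The saved 'epoch' and 'optimizer' fields would also allow resuming training,
    # e.g. (sketch, not used elsewhere in this script):
    #   checkpoint = torch.load('model_best_checkpoint.pth.zip')
    #   model.load_state_dict(checkpoint['model'])
    #   optimizer.load_state_dict(checkpoint['optimizer'])
    #   start_epoch = checkpoint['epoch']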

    resnet18_model = models.resnet18(pretrained=True)  # Increase n_epochs if False
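    # Note: torchvision >= 0.13 prefers the weights API, e.g.
    #   models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # pretrained=True still works there but emits a deprecation warning.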
    num_features = resnet18_model.fc.in_features
    number_of_classes = 4
    resnet18_model.fc = nn.Linear(num_features, number_of_classes)
    device = set_device()
    resnet18_model = resnet18_model.to(device)
    loss_fn = nn.CrossEntropyLoss()  # criterion

    optimizer = optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
    train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, 5)

    # Load the best checkpoint and export the full trained model:
    checkpoint = torch.load('model_best_checkpoint.pth.zip')

    resnet18_model = models.resnet18()
    num_features = resnet18_model.fc.in_features
    number_of_classes = 4
    resnet18_model.fc = nn.Linear(num_features, number_of_classes)
    resnet18_model.load_state_dict(checkpoint['model'])

    torch.save(resnet18_model, 'garbage_model.pth')
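
    # A minimal inference sketch for the exported model (the image path is hypothetical;
    # depending on the PyTorch version, torch.load may need weights_only=False to
    # unpickle a full nn.Module):
    #   from PIL import Image
    #   model = torch.load('garbage_model.pth')
    #   model.eval()
    #   img = testing_transforms(Image.open('some_image.jpg')).unsqueeze(0)
    #   with torch.no_grad():
    #       print(train_dataset.classes[model(img).argmax(1).item()])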


if __name__ == "__main__":
    main()