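# File: Inteligentna_smieciarka/model_training/main.py
# Size: 178 lines, 6.1 KiB (Python) -- metadata captured from the repository
# file view (links: Raw | Normal View | History).
# Last modified: 2023-06-01 23:44:09 +02:00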
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
# Locations of the ImageFolder-style datasets (one sub-folder per class).
TRAIN_DATASET_PATH = './train_dataset'
TEST_DATASET_PATH = './test_dataset'
# Best-epoch checkpoint written by train_nn() and read back by main(); a single
# constant keeps the writer and the reader in sync (they used to disagree).
CHECKPOINT_PATH = 'model_best_checkpoint.zip'
# Exported model (the whole pickled nn.Module, as before).
MODEL_PATH = 'garbage_model.pth'
# Per-channel (R, G, B) normalization statistics. Presumably produced by
# get_mean_and_std() on the un-normalized training images -- TODO confirm.
# Tuples (not lists) so they are safe to use as default arguments.
MEAN = (0.6908, 0.6612, 0.6218)
STD = (0.1947, 0.1926, 0.2086)
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
N_EPOCHS = 5


def set_device():
    """Return the first CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _build_transforms(normalize=True):
    """Resize to IMAGE_SIZE and convert to a tensor, optionally normalizing."""
    steps = [Resize(IMAGE_SIZE), ToTensor()]
    if normalize:
        steps.append(Normalize(MEAN, STD))
    return transforms.Compose(steps)


def _build_resnet18(num_classes, pretrained=False):
    """Create a ResNet-18 whose final layer outputs *num_classes* logits."""
    # Only pass `pretrained` when it is requested, matching the original calls.
    model = models.resnet18(pretrained=True) if pretrained else models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def get_mean_and_std(loader):
    """Approximate the per-channel mean and std over every image in *loader*.

    *loader* yields ``(images, labels)`` pairs where dim 0 of ``images`` is the
    batch and dim 1 the channel (assumed (batch, C, H, W), as ImageFolder +
    ToTensor produce). The std is the average of per-image stds, which is an
    approximation of the true dataset std.

    Returns:
        (mean, std): two 1-D tensors with one entry per channel.

    Raises:
        ValueError: if *loader* yields no images.
    """
    mean = 0.0
    std = 0.0
    total_images_count = 0
    for images, _ in loader:
        batch_size = images.size(0)
        # Flatten the spatial dims so statistics are taken per image/channel.
        images = images.reshape(batch_size, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += batch_size
    if total_images_count == 0:
        raise ValueError("loader yielded no images")
    mean /= total_images_count
    std /= total_images_count
    return mean, std


def show_transformed_images(dataset, mean=MEAN, std=STD):
    """Show six random samples of *dataset* in a 3-column grid and print labels.

    The grid is un-normalized with *mean*/*std* before display so matplotlib
    gets values in [0, 1] instead of clipping the normalized tensors.
    plt.show() normally blocks until the window is closed.
    """
    loader = DataLoader(dataset, batch_size=6, shuffle=True)
    images, labels = next(iter(loader))
    grid = torchvision.utils.make_grid(images, nrow=3)
    # Undo Normalize(mean, std) channel-wise: x * std + mean.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    grid = (grid * std_t + mean_t).clamp(0, 1)
    plt.figure(figsize=(11, 11))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    print('labels: ', labels)
    plt.show()


def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs,
             checkpoint_path=CHECKPOINT_PATH):
    """Train *model* for *n_epochs*, checkpointing the best test accuracy.

    After every epoch the model is evaluated on *test_loader*. Whenever the
    test accuracy beats the best seen so far, save_checkpoint() writes
    *checkpoint_path*.

    Returns:
        The trained model (last-epoch weights, not necessarily the best).
    """
    device = set_device()
    model.to(device)
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; with 0 a run that never gets a test image right would leave
    # nothing for main() to load.
    best_acc = -1.0
    for epoch in range(n_epochs):
        print("Epoch number %d " % (epoch + 1))
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)
            # Back propagation
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_correct += (outputs.argmax(1) == labels).sum().item()
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100.00 * running_correct / total
        print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
        test_dataset_acc = evaluate_model_on_test_set(model, test_loader, device)
        if test_dataset_acc > best_acc:
            best_acc = test_dataset_acc
            save_checkpoint(model, epoch, optimizer, best_acc, checkpoint_path)
    print("Finished")
    return model


def evaluate_model_on_test_set(model, test_loader, device=None):
    """Return the accuracy (in percent) of *model* on *test_loader*.

    Puts the model in eval mode and disables autograd. *device* defaults to
    set_device(); the model is expected to already live on that device.
    """
    if device is None:
        device = set_device()
    model.eval()
    predicted_correctly_on_epoch = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)
            outputs = model(images)
            predicted_correctly_on_epoch += (outputs.argmax(1) == labels).sum().item()
    epoch_acc = 100.0 * predicted_correctly_on_epoch / total
    print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
    return epoch_acc


def save_checkpoint(model, epoch, optimizer, best_acc, path=CHECKPOINT_PATH):
    """Save model/optimizer state plus the 1-based epoch and accuracy to *path*."""
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best_accuracy': best_acc,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)


def main():
    """Fine-tune a pretrained ResNet-18 on the garbage dataset and export it.

    Loads the train/test ImageFolder datasets and prints diagnostics (class
    names, recomputed mean/std, a sample grid). Then it trains for N_EPOCHS,
    reloads the best checkpoint into a fresh CPU model and saves that whole
    model to MODEL_PATH.
    """
    torch.manual_seed(42)

    train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DATASET_PATH, transform=_build_transforms())
    test_dataset = torchvision.datasets.ImageFolder(root=TEST_DATASET_PATH, transform=_build_transforms())
    # ImageFolder maps label i to classes[i] (sorted folder names). os.listdir()
    # order is arbitrary and may include stray files, so use this list instead.
    classes = train_dataset.classes
    print(classes)
    if test_dataset.classes != classes:
        # Different folder sets would silently shift the label indices.
        raise ValueError("train/test class folders differ: %s vs %s" % (classes, test_dataset.classes))
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Mean and standard deviation approximations, computed on *un-normalized*
    # images so they can be compared against MEAN/STD (on the normalized
    # loader they would always come out near 0 and 1).
    raw_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DATASET_PATH, transform=_build_transforms(normalize=False))
    print(get_mean_and_std(DataLoader(raw_dataset, batch_size=BATCH_SIZE)))

    # Show images with applied transformations.
    show_transformed_images(train_dataset)

    # Neural network training (increase N_EPOCHS if not using pretrained weights).
    device = set_device()
    model = _build_resnet18(len(classes), pretrained=True).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
    train_nn(model, train_loader, test_loader, loss_fn, optimizer, N_EPOCHS)

    # Export the best (not the last) weights: rebuild an untrained model on
    # the CPU and load the checkpoint that train_nn() wrote.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
    best_model = _build_resnet18(len(classes))
    best_model.load_state_dict(checkpoint['model'])
    torch.save(best_model, MODEL_PATH)
if __name__ == "__main__":
    # Script entry point: run the training pipeline only when executed
    # directly, never as a side effect of importing this module.
    main()