Traktor/source/NN/neural_network.py

114 lines
3.4 KiB
Python
Raw Normal View History

2024-05-25 02:07:27 +02:00
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
2024-05-25 18:41:25 +02:00
from torchvision import datasets, transforms, utils
2024-05-25 02:07:27 +02:00
from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt
from model import *
2024-05-25 22:30:04 +02:00
from PIL import Image
2024-05-25 02:07:27 +02:00
2024-05-25 18:41:25 +02:00
# Compute device used for the module-level model and the smoke test below.
# Prefer the GPU when one is available, but fall back to the CPU so the module
# still works on machines without CUDA (a hard-coded 'cuda' device makes every
# `.to(device)` call fail there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing shared by training, evaluation and single-image inference:
# resize every image to 100x100, convert it to a float tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
data_transformer = transforms.Compose(
    [
        transforms.Resize(size=(100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Labelled image datasets read from disk (ImageFolder takes the class labels
# from the sub-directory names under each root). The training split is loaded
# first, then the test split.
train_set, test_set = (
    datasets.ImageFolder(root=split_root, transform=data_transformer)
    for split_root in ('resources/train', 'resources/test')
)

# DataLoaders are created inside train() and accuracy() rather than here.
def train(model, dataset, iter=100, batch_size=64, lr=0.001, device=None):
    """Train ``model`` in place on ``dataset`` with plain SGD and NLL loss.

    Args:
        model: network to optimise; its parameters are updated in place and it
            is left in training mode.
        dataset: map-style dataset yielding ``(inputs, label)`` pairs.
        iter: number of epochs (full passes over ``dataset``). The name shadows
            the builtin but is kept for backward compatibility with callers.
        batch_size: mini-batch size of the shuffled DataLoader.
        lr: SGD learning rate (previously hard-coded to 0.001).
        device: device each batch is moved to; defaults to the device of the
            model's parameters, so CPU and GPU models both work.

    Returns:
        List with the mean training loss of every epoch, in order.
    """
    # nn.NLLLoss expects log-probabilities, so this assumes the model ends in a
    # log_softmax. TODO confirm against Conv_Neural_Network_Model.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if device is None:
        # SGD above already rejects a model without parameters, so next() is safe.
        device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for epoch in range(iter):
        running_loss, seen = 0.0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            running_loss += loss.item() * labels.size(0)
            seen += labels.size(0)
        epoch_loss = running_loss / max(seen, 1)
        epoch_losses.append(epoch_loss)
        if epoch % 10 == 0:
            print('epoch: %3d loss: %.4f' % (epoch, epoch_loss))
    return epoch_losses
def accuracy(model, dataset, batch_size=64, device=None):
    """Return the fraction of samples in ``dataset`` that ``model`` classifies correctly.

    The predicted class is the argmax of the model output along dim 1. The
    model's previous train/eval mode is restored afterwards, so this can be
    called in the middle of training.

    Args:
        model: classifier whose output has one score per class along dim 1.
        dataset: map-style dataset yielding ``(inputs, label)`` pairs.
        batch_size: evaluation batch size.
        device: device batches are moved to; defaults to the device of the
            model's parameters (CPU for a parameter-less model).

    Returns:
        0-dim float tensor with the accuracy in [0, 1].

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError('cannot compute accuracy on an empty dataset')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    # Order is irrelevant when counting hits, so there is no need to shuffle.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = torch.zeros((), dtype=torch.long, device=device)
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                predictions = model(inputs.to(device)).argmax(dim=1)
                correct += (predictions == labels.to(device)).sum()
    finally:
        model.train(was_training)
    return correct.float() / total
# Module-level network instance, placed on `device`. It holds the model's
# initial weights until a checkpoint is loaded into it.
model = Conv_Neural_Network_Model().to(device)

# Workflow notes (uncomment the steps you need):
#   Reuse previously saved weights instead of training:
#       model.load_state_dict(torch.load('model.pth'))
#       model.eval()
#   Train from scratch, report test accuracy and save the weights:
#       train(model, train_set)
#       print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
#       torch.save(model.state_dict(), 'model.pth')
def load_model(path='CNN_model.pth'):
    """Build a Conv_Neural_Network_Model and load the weights saved at ``path``.

    The checkpoint tensors are mapped to CPU while loading, so a checkpoint
    saved from a CUDA run still loads on a machine without a GPU. The returned
    model is not moved to any device here, and it is in eval mode.

    Args:
        path: state_dict file to load (defaults to the original 'CNN_model.pth').

    Returns:
        The model with the loaded weights, in eval mode.
    """
    model = Conv_Neural_Network_Model()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_image(image_path):
    """Load an image file and preprocess it into a batch of one for the model.

    Args:
        image_path: path of the image file to read.

    Returns:
        Tensor produced by `data_transformer` with a leading batch dimension of
        size 1 added.
    """
    # The context manager releases the file handle that Image.open keeps open.
    # Converting to RGB mirrors ImageFolder's default loader used for training,
    # so grayscale, RGBA or palette images still match the 3-channel Normalize.
    with Image.open(image_path) as img:
        image_tensor = data_transformer(img.convert('RGB'))
    return image_tensor.unsqueeze(0)
def guess_image(model, image_tensor, classes=None):
    """Return the class name that ``model`` predicts for one preprocessed image.

    Args:
        model: classifier whose output has one score per class along dim 1.
        image_tensor: a batch containing exactly one image (for example the
            result of load_image); it is moved to the model's device.
        classes: sequence of class names indexed by class id. Defaults to
            `train_set.classes`, which was the only behaviour before.

    Returns:
        The name of the highest-scoring class.
    """
    if classes is None:
        classes = train_set.classes
    # Keep the input on the same device as the model's weights (CPU or GPU).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)
    with torch.no_grad():
        scores = model(image_tensor)
        _, predicted = torch.max(scores, 1)
    # .item() requires a single prediction, i.e. a batch of exactly one image.
    return classes[predicted.item()]
# Example of the helper-based API (kept for reference):
#   image_tensor = load_image('resources/images/plant_photos/pexels-dxt-73640.jpg')
#   prediction = guess_image(load_model(), image_tensor)
#   print(f"The predicted image is: {prediction}")

# Smoke test: classify one sample photo using the saved checkpoint.
# NOTE(review): this runs on every import of the module. Consider moving it
# under an `if __name__ == '__main__':` guard once no importer relies on the
# module-level `model` having its weights loaded here.
testImage_path = 'resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg'
# Close the file after decoding, and force 3-channel RGB to match the
# preprocessing ImageFolder applies to the training images.
with Image.open(testImage_path) as raw_test_image:
    testImage = data_transformer(raw_test_image.convert('RGB'))
testImage = testImage.unsqueeze(0).to(device)
# map_location places the checkpoint tensors on the current device, so a
# checkpoint saved on a GPU still loads when `device` is the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('CNN_model.pth', map_location=device))
model.to(device)
model.eval()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}')