This commit is contained in:
Marek 2024-05-27 05:28:48 +02:00
parent df7c553c59
commit 3d2a88d1ea
16 changed files with 33 additions and 28 deletions

View File

@ -4,10 +4,10 @@ from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from model import * from NN.model import *
from PIL import Image from PIL import Image
device = torch.device('cuda') device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#data transform to tensors: #data transform to tensors:
data_transformer = transforms.Compose([ data_transformer = transforms.Compose([
@ -56,12 +56,12 @@ def accuracy(model, dataset):
return correct.float() / len(dataset) return correct.float() / len(dataset)
model = Conv_Neural_Network_Model() # model = Conv_Neural_Network_Model()
model.to(device) # model.to(device)
#loading the already saved model: #loading the already saved model:
model.load_state_dict(torch.load('CNN_model.pth')) # model.load_state_dict(torch.load('CNN_model.pth'))
model.eval() # model.eval()
# #training the model: # #training the model:
# train(model, train_set) # train(model, train_set)
@ -72,13 +72,13 @@ model.eval()
def load_model(): def load_model():
model = Conv_Neural_Network_Model() model = Conv_Neural_Network_Model()
model.load_state_dict(torch.load('CNN_model.pth')) model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
model.eval() model.eval()
return model return model
def load_image(image_path): def load_image(image_path):
testImage = Image.open(image_path) testImage = Image.open(image_path).convert('RGB')
testImage = data_transformer(testImage) testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0) testImage = testImage.unsqueeze(0)
return testImage return testImage
@ -98,17 +98,17 @@ def guess_image(model, image_tensor):
# print(f"The predicted image is: {prediction}") # print(f"The predicted image is: {prediction}")
#TEST - loading the image and getting results: #TEST - loading the image and getting results:
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg' # testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
testImage = Image.open(testImage_path) # testImage = Image.open(testImage_path)
testImage = data_transformer(testImage) # testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0) # testImage = testImage.unsqueeze(0)
testImage = testImage.to(device) # testImage = testImage.to(device)
model.load_state_dict(torch.load('CNN_model.pth')) # model.load_state_dict(torch.load('CNN_model.pth'))
model.to(device) # model.to(device)
model.eval() # model.eval()
testOutput = model(testImage) # testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1) # _, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()] # predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}') # print(f'The predicted class is: {predicted_class}')

View File

@ -13,6 +13,7 @@ from plant import Plant
from bfs import graphsearch, Istate, succ from bfs import graphsearch, Istate, succ
from astar import a_star from astar import a_star
from NN.neural_network import load_model, load_image, guess_image from NN.neural_network import load_model, load_image, guess_image
from PIL import Image
WIN = pygame.display.set_mode((WIDTH, HEIGHT)) WIN = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption('Intelligent tractor') pygame.display.set_caption('Intelligent tractor')

View File

@ -21,17 +21,17 @@ class Plant:
def update_name(self, predicted_class): def update_name(self, predicted_class):
if predicted_class == "Apple": if predicted_class == "Apple":
self.name = "Apple" self.name = "apple"
self.plant_type = "fruit" self.plant_type = 'fruit'
elif predicted_class == "Radish": elif predicted_class == "Radish":
self.name = "Radish" self.name = "radish"
self.plant_type = "vegetable" self.plant_type = 'vegetable'
elif predicted_class == "Cauliflower": elif predicted_class == "Cauliflower":
self.name = "Cauliflower" self.name = "cauliflower"
self.plant_type = "vegetable" self.plant_type = 'vegetable'
elif predicted_class == "Wheat": elif predicted_class == "Wheat":
self.name = "Wheat" self.name = "wheat"
self.plant_type = "cereal" self.plant_type = 'cereal'

View File

@ -1,4 +1,5 @@
import random import random
import time
import os import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -54,6 +55,9 @@ class Tile:
def display_photo(self): def display_photo(self):
image_path = self.photo image_path = self.photo
img = Image.open(image_path) img = Image.open(image_path)
plt.ion()
plt.imshow(img) plt.imshow(img)
plt.axis('off') plt.axis('off')
plt.show() plt.show()
time.sleep(5)
plt.close()