Compare commits

...

2 Commits

Author SHA1 Message Date
Marek
57c6facea1 added photo display next to the field 2024-06-04 16:55:27 +02:00
Marek
3d2a88d1ea fixes 2024-05-27 05:28:48 +02:00
16 changed files with 52 additions and 43 deletions

View File

@ -4,10 +4,11 @@ from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from model import * from NN.model import *
from PIL import Image from PIL import Image
import pygame
device = torch.device('cuda') device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#data transform to tensors: #data transform to tensors:
data_transformer = transforms.Compose([ data_transformer = transforms.Compose([
@ -56,12 +57,12 @@ def accuracy(model, dataset):
return correct.float() / len(dataset) return correct.float() / len(dataset)
model = Conv_Neural_Network_Model() # model = Conv_Neural_Network_Model()
model.to(device) # model.to(device)
#loading the already saved model: #loading the already saved model:
model.load_state_dict(torch.load('CNN_model.pth')) # model.load_state_dict(torch.load('CNN_model.pth'))
model.eval() # model.eval()
# #training the model: # #training the model:
# train(model, train_set) # train(model, train_set)
@ -72,17 +73,27 @@ model.eval()
def load_model(): def load_model():
model = Conv_Neural_Network_Model() model = Conv_Neural_Network_Model()
model.load_state_dict(torch.load('CNN_model.pth')) model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
model.eval() model.eval()
return model return model
def load_image(image_path): def load_image(image_path):
testImage = Image.open(image_path) testImage = Image.open(image_path).convert('RGB')
testImage = data_transformer(testImage) testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0) testImage = testImage.unsqueeze(0)
return testImage return testImage
def display_image(screen, image_path, position):
image = pygame.image.load(image_path)
image = pygame.transform.scale(image, (250, 250))
screen.blit(image, position)
def display_result(screen, position, predicted_class):
font = pygame.font.Font(None, 30)
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
screen.blit(displayed_text, position)
def guess_image(model, image_tensor): def guess_image(model, image_tensor):
with torch.no_grad(): with torch.no_grad():
testOutput = model(image_tensor) testOutput = model(image_tensor)
@ -92,23 +103,18 @@ def guess_image(model, image_tensor):
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
# image_tensor = load_image(image_path)
# prediction = guess_image(load_model(), image_tensor)
# print(f"The predicted image is: {prediction}")
#TEST - loading the image and getting results: #TEST - loading the image and getting results:
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg' # testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
testImage = Image.open(testImage_path) # testImage = Image.open(testImage_path)
testImage = data_transformer(testImage) # testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0) # testImage = testImage.unsqueeze(0)
testImage = testImage.to(device) # testImage = testImage.to(device)
model.load_state_dict(torch.load('CNN_model.pth')) # model.load_state_dict(torch.load('CNN_model.pth'))
model.to(device) # model.to(device)
model.eval() # model.eval()
testOutput = model(testImage) # testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1) # _, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()] # predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}') # print(f'The predicted class is: {predicted_class}')

View File

@ -12,9 +12,12 @@ from ground import Dirt
from plant import Plant from plant import Plant
from bfs import graphsearch, Istate, succ from bfs import graphsearch, Istate, succ
from astar import a_star from astar import a_star
from NN.neural_network import load_model, load_image, guess_image from NN.neural_network import load_model, load_image, guess_image, display_image, display_result
from PIL import Image
WIN = pygame.display.set_mode((WIDTH, HEIGHT)) pygame.init()
WIN_WIDTH = WIDTH + 300
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
pygame.display.set_caption('Intelligent tractor') pygame.display.set_caption('Intelligent tractor')
@ -75,10 +78,15 @@ def main():
#guessing the image under the tile: #guessing the image under the tile:
goalTile = tiles[tile_index] goalTile = tiles[tile_index]
goalTile.display_photo()
image_path = goalTile.photo image_path = goalTile.photo
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
pygame.display.update()
image_tensor = load_image(image_path) image_tensor = load_image(image_path)
prediction = guess_image(load_model(), image_tensor) prediction = guess_image(load_model(), image_tensor)
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
pygame.display.update()
print(f"The predicted image is: {prediction}") print(f"The predicted image is: {prediction}")
@ -143,7 +151,7 @@ def main():
#work on field: #work on field:
if predykcje == 'work': if predykcje == 'work':
tractor.work_on_field(goalTile, d1, p1) tractor.work_on_field(goalTile, d1, p1)
time.sleep(30) time.sleep(50)
print("\n") print("\n")

View File

@ -21,17 +21,17 @@ class Plant:
def update_name(self, predicted_class): def update_name(self, predicted_class):
if predicted_class == "Apple": if predicted_class == "Apple":
self.name = "Apple" self.name = "apple"
self.plant_type = "fruit" self.plant_type = 'fruit'
elif predicted_class == "Radish": elif predicted_class == "Radish":
self.name = "Radish" self.name = "radish"
self.plant_type = "vegetable" self.plant_type = 'vegetable'
elif predicted_class == "Cauliflower": elif predicted_class == "Cauliflower":
self.name = "Cauliflower" self.name = "cauliflower"
self.plant_type = "vegetable" self.plant_type = 'vegetable'
elif predicted_class == "Wheat": elif predicted_class == "Wheat":
self.name = "Wheat" self.name = "wheat"
self.plant_type = "cereal" self.plant_type = 'cereal'

View File

@ -1,4 +1,5 @@
import random import random
import time
import os import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -51,9 +52,3 @@ class Tile:
self.image = "resources/images/rock_dirt.png" self.image = "resources/images/rock_dirt.png"
def display_photo(self):
image_path = self.photo
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()