Compare commits

...

2 Commits

Author SHA1 Message Date
Marek
57c6facea1 added photo display next to the field 2024-06-04 16:55:27 +02:00
Marek
3d2a88d1ea fixes 2024-05-27 05:28:48 +02:00
16 changed files with 52 additions and 43 deletions

View File

@ -4,10 +4,11 @@ from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt
from model import *
from NN.model import *
from PIL import Image
import pygame
device = torch.device('cuda')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#data transform to tensors:
data_transformer = transforms.Compose([
@ -56,12 +57,12 @@ def accuracy(model, dataset):
return correct.float() / len(dataset)
model = Conv_Neural_Network_Model()
model.to(device)
# model = Conv_Neural_Network_Model()
# model.to(device)
#loading the already saved model:
model.load_state_dict(torch.load('CNN_model.pth'))
model.eval()
# model.load_state_dict(torch.load('CNN_model.pth'))
# model.eval()
# #training the model:
# train(model, train_set)
@ -72,17 +73,27 @@ model.eval()
def load_model():
model = Conv_Neural_Network_Model()
model.load_state_dict(torch.load('CNN_model.pth'))
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
model.eval()
return model
def load_image(image_path):
testImage = Image.open(image_path)
testImage = Image.open(image_path).convert('RGB')
testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0)
return testImage
def display_image(screen, image_path, position):
image = pygame.image.load(image_path)
image = pygame.transform.scale(image, (250, 250))
screen.blit(image, position)
def display_result(screen, position, predicted_class):
font = pygame.font.Font(None, 30)
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
screen.blit(displayed_text, position)
def guess_image(model, image_tensor):
with torch.no_grad():
testOutput = model(image_tensor)
@ -92,23 +103,18 @@ def guess_image(model, image_tensor):
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
# image_tensor = load_image(image_path)
# prediction = guess_image(load_model(), image_tensor)
# print(f"The predicted image is: {prediction}")
#TEST - loading the image and getting results:
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
testImage = Image.open(testImage_path)
testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0)
testImage = testImage.to(device)
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
# testImage = Image.open(testImage_path)
# testImage = data_transformer(testImage)
# testImage = testImage.unsqueeze(0)
# testImage = testImage.to(device)
model.load_state_dict(torch.load('CNN_model.pth'))
model.to(device)
model.eval()
# model.load_state_dict(torch.load('CNN_model.pth'))
# model.to(device)
# model.eval()
testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}')
# testOutput = model(testImage)
# _, predicted = torch.max(testOutput, 1)
# predicted_class = train_set.classes[predicted.item()]
# print(f'The predicted class is: {predicted_class}')

View File

@ -12,9 +12,12 @@ from ground import Dirt
from plant import Plant
from bfs import graphsearch, Istate, succ
from astar import a_star
from NN.neural_network import load_model, load_image, guess_image
from NN.neural_network import load_model, load_image, guess_image, display_image, display_result
from PIL import Image
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.init()
WIN_WIDTH = WIDTH + 300
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
pygame.display.set_caption('Intelligent tractor')
@ -75,10 +78,15 @@ def main():
#guessing the image under the tile:
goalTile = tiles[tile_index]
goalTile.display_photo()
image_path = goalTile.photo
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
pygame.display.update()
image_tensor = load_image(image_path)
prediction = guess_image(load_model(), image_tensor)
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
pygame.display.update()
print(f"The predicted image is: {prediction}")
@ -143,7 +151,7 @@ def main():
#work on field:
if predykcje == 'work':
tractor.work_on_field(goalTile, d1, p1)
time.sleep(30)
time.sleep(50)
print("\n")

View File

@ -21,17 +21,17 @@ class Plant:
def update_name(self, predicted_class):
if predicted_class == "Apple":
self.name = "Apple"
self.plant_type = "fruit"
self.name = "apple"
self.plant_type = 'fruit'
elif predicted_class == "Radish":
self.name = "Radish"
self.plant_type = "vegetable"
self.name = "radish"
self.plant_type = 'vegetable'
elif predicted_class == "Cauliflower":
self.name = "Cauliflower"
self.plant_type = "vegetable"
self.name = "cauliflower"
self.plant_type = 'vegetable'
elif predicted_class == "Wheat":
self.name = "Wheat"
self.plant_type = "cereal"
self.name = "wheat"
self.plant_type = 'cereal'

View File

@ -1,4 +1,5 @@
import random
import time
import os
import numpy as np
import matplotlib.pyplot as plt
@ -51,9 +52,3 @@ class Tile:
self.image = "resources/images/rock_dirt.png"
def display_photo(self):
image_path = self.photo
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()