Compare commits

..

No commits in common. "57c6facea1168211fb3bd0d6b7f647d39f85c407" and "df7c553c59cb75abf9c24082e8c29e3df759eb1a" have entirely different histories.

16 changed files with 43 additions and 52 deletions

View File

@ -4,11 +4,10 @@ from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from NN.model import * from model import *
from PIL import Image from PIL import Image
import pygame
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') device = torch.device('cuda')
#data transform to tensors: #data transform to tensors:
data_transformer = transforms.Compose([ data_transformer = transforms.Compose([
@ -57,12 +56,12 @@ def accuracy(model, dataset):
return correct.float() / len(dataset) return correct.float() / len(dataset)
# model = Conv_Neural_Network_Model() model = Conv_Neural_Network_Model()
# model.to(device) model.to(device)
#loading the already saved model: #loading the already saved model:
# model.load_state_dict(torch.load('CNN_model.pth')) model.load_state_dict(torch.load('CNN_model.pth'))
# model.eval() model.eval()
# #training the model: # #training the model:
# train(model, train_set) # train(model, train_set)
@ -73,27 +72,17 @@ def accuracy(model, dataset):
def load_model(): def load_model():
model = Conv_Neural_Network_Model() model = Conv_Neural_Network_Model()
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu'))) model.load_state_dict(torch.load('CNN_model.pth'))
model.eval() model.eval()
return model return model
def load_image(image_path): def load_image(image_path):
testImage = Image.open(image_path).convert('RGB') testImage = Image.open(image_path)
testImage = data_transformer(testImage) testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0) testImage = testImage.unsqueeze(0)
return testImage return testImage
def display_image(screen, image_path, position):
image = pygame.image.load(image_path)
image = pygame.transform.scale(image, (250, 250))
screen.blit(image, position)
def display_result(screen, position, predicted_class):
font = pygame.font.Font(None, 30)
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
screen.blit(displayed_text, position)
def guess_image(model, image_tensor): def guess_image(model, image_tensor):
with torch.no_grad(): with torch.no_grad():
testOutput = model(image_tensor) testOutput = model(image_tensor)
@ -103,18 +92,23 @@ def guess_image(model, image_tensor):
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
# image_tensor = load_image(image_path)
# prediction = guess_image(load_model(), image_tensor)
# print(f"The predicted image is: {prediction}")
#TEST - loading the image and getting results: #TEST - loading the image and getting results:
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg' testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
# testImage = Image.open(testImage_path) testImage = Image.open(testImage_path)
# testImage = data_transformer(testImage) testImage = data_transformer(testImage)
# testImage = testImage.unsqueeze(0) testImage = testImage.unsqueeze(0)
# testImage = testImage.to(device) testImage = testImage.to(device)
# model.load_state_dict(torch.load('CNN_model.pth')) model.load_state_dict(torch.load('CNN_model.pth'))
# model.to(device) model.to(device)
# model.eval() model.eval()
# testOutput = model(testImage) testOutput = model(testImage)
# _, predicted = torch.max(testOutput, 1) _, predicted = torch.max(testOutput, 1)
# predicted_class = train_set.classes[predicted.item()] predicted_class = train_set.classes[predicted.item()]
# print(f'The predicted class is: {predicted_class}') print(f'The predicted class is: {predicted_class}')

View File

@ -12,12 +12,9 @@ from ground import Dirt
from plant import Plant from plant import Plant
from bfs import graphsearch, Istate, succ from bfs import graphsearch, Istate, succ
from astar import a_star from astar import a_star
from NN.neural_network import load_model, load_image, guess_image, display_image, display_result from NN.neural_network import load_model, load_image, guess_image
from PIL import Image
pygame.init() WIN = pygame.display.set_mode((WIDTH, HEIGHT))
WIN_WIDTH = WIDTH + 300
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
pygame.display.set_caption('Intelligent tractor') pygame.display.set_caption('Intelligent tractor')
@ -78,15 +75,10 @@ def main():
#guessing the image under the tile: #guessing the image under the tile:
goalTile = tiles[tile_index] goalTile = tiles[tile_index]
goalTile.display_photo()
image_path = goalTile.photo image_path = goalTile.photo
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
pygame.display.update()
image_tensor = load_image(image_path) image_tensor = load_image(image_path)
prediction = guess_image(load_model(), image_tensor) prediction = guess_image(load_model(), image_tensor)
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
pygame.display.update()
print(f"The predicted image is: {prediction}") print(f"The predicted image is: {prediction}")
@ -151,7 +143,7 @@ def main():
#work on field: #work on field:
if predykcje == 'work': if predykcje == 'work':
tractor.work_on_field(goalTile, d1, p1) tractor.work_on_field(goalTile, d1, p1)
time.sleep(50) time.sleep(30)
print("\n") print("\n")

View File

@ -21,17 +21,17 @@ class Plant:
def update_name(self, predicted_class): def update_name(self, predicted_class):
if predicted_class == "Apple": if predicted_class == "Apple":
self.name = "apple" self.name = "Apple"
self.plant_type = 'fruit' self.plant_type = "fruit"
elif predicted_class == "Radish": elif predicted_class == "Radish":
self.name = "radish" self.name = "Radish"
self.plant_type = 'vegetable' self.plant_type = "vegetable"
elif predicted_class == "Cauliflower": elif predicted_class == "Cauliflower":
self.name = "cauliflower" self.name = "Cauliflower"
self.plant_type = 'vegetable' self.plant_type = "vegetable"
elif predicted_class == "Wheat": elif predicted_class == "Wheat":
self.name = "wheat" self.name = "Wheat"
self.plant_type = 'cereal' self.plant_type = "cereal"

View File

@ -1,5 +1,4 @@
import random import random
import time
import os import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -52,3 +51,9 @@ class Tile:
self.image = "resources/images/rock_dirt.png" self.image = "resources/images/rock_dirt.png"
def display_photo(self):
image_path = self.photo
img = Image.open(image_path)
plt.imshow(img)
plt.axis('off')
plt.show()