Compare commits
2 Commits
df7c553c59
...
57c6facea1
Author | SHA1 | Date | |
---|---|---|---|
|
57c6facea1 | ||
|
3d2a88d1ea |
Binary file not shown.
Binary file not shown.
@ -4,10 +4,11 @@ from torch.utils.data import DataLoader
|
|||||||
from torchvision import datasets, transforms, utils
|
from torchvision import datasets, transforms, utils
|
||||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from model import *
|
from NN.model import *
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import pygame
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
|
||||||
#data transform to tensors:
|
#data transform to tensors:
|
||||||
data_transformer = transforms.Compose([
|
data_transformer = transforms.Compose([
|
||||||
@ -56,12 +57,12 @@ def accuracy(model, dataset):
|
|||||||
return correct.float() / len(dataset)
|
return correct.float() / len(dataset)
|
||||||
|
|
||||||
|
|
||||||
model = Conv_Neural_Network_Model()
|
# model = Conv_Neural_Network_Model()
|
||||||
model.to(device)
|
# model.to(device)
|
||||||
|
|
||||||
#loading the already saved model:
|
#loading the already saved model:
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||||
model.eval()
|
# model.eval()
|
||||||
|
|
||||||
# #training the model:
|
# #training the model:
|
||||||
# train(model, train_set)
|
# train(model, train_set)
|
||||||
@ -72,17 +73,27 @@ model.eval()
|
|||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
model = Conv_Neural_Network_Model()
|
model = Conv_Neural_Network_Model()
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_path):
|
def load_image(image_path):
|
||||||
testImage = Image.open(image_path)
|
testImage = Image.open(image_path).convert('RGB')
|
||||||
testImage = data_transformer(testImage)
|
testImage = data_transformer(testImage)
|
||||||
testImage = testImage.unsqueeze(0)
|
testImage = testImage.unsqueeze(0)
|
||||||
return testImage
|
return testImage
|
||||||
|
|
||||||
|
def display_image(screen, image_path, position):
|
||||||
|
image = pygame.image.load(image_path)
|
||||||
|
image = pygame.transform.scale(image, (250, 250))
|
||||||
|
screen.blit(image, position)
|
||||||
|
|
||||||
|
def display_result(screen, position, predicted_class):
|
||||||
|
font = pygame.font.Font(None, 30)
|
||||||
|
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
|
||||||
|
screen.blit(displayed_text, position)
|
||||||
|
|
||||||
def guess_image(model, image_tensor):
|
def guess_image(model, image_tensor):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
testOutput = model(image_tensor)
|
testOutput = model(image_tensor)
|
||||||
@ -92,23 +103,18 @@ def guess_image(model, image_tensor):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
|
|
||||||
# image_tensor = load_image(image_path)
|
|
||||||
# prediction = guess_image(load_model(), image_tensor)
|
|
||||||
# print(f"The predicted image is: {prediction}")
|
|
||||||
|
|
||||||
#TEST - loading the image and getting results:
|
#TEST - loading the image and getting results:
|
||||||
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||||
testImage = Image.open(testImage_path)
|
# testImage = Image.open(testImage_path)
|
||||||
testImage = data_transformer(testImage)
|
# testImage = data_transformer(testImage)
|
||||||
testImage = testImage.unsqueeze(0)
|
# testImage = testImage.unsqueeze(0)
|
||||||
testImage = testImage.to(device)
|
# testImage = testImage.to(device)
|
||||||
|
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||||
model.to(device)
|
# model.to(device)
|
||||||
model.eval()
|
# model.eval()
|
||||||
|
|
||||||
testOutput = model(testImage)
|
# testOutput = model(testImage)
|
||||||
_, predicted = torch.max(testOutput, 1)
|
# _, predicted = torch.max(testOutput, 1)
|
||||||
predicted_class = train_set.classes[predicted.item()]
|
# predicted_class = train_set.classes[predicted.item()]
|
||||||
print(f'The predicted class is: {predicted_class}')
|
# print(f'The predicted class is: {predicted_class}')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -12,9 +12,12 @@ from ground import Dirt
|
|||||||
from plant import Plant
|
from plant import Plant
|
||||||
from bfs import graphsearch, Istate, succ
|
from bfs import graphsearch, Istate, succ
|
||||||
from astar import a_star
|
from astar import a_star
|
||||||
from NN.neural_network import load_model, load_image, guess_image
|
from NN.neural_network import load_model, load_image, guess_image, display_image, display_result
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
pygame.init()
|
||||||
|
WIN_WIDTH = WIDTH + 300
|
||||||
|
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
|
||||||
pygame.display.set_caption('Intelligent tractor')
|
pygame.display.set_caption('Intelligent tractor')
|
||||||
|
|
||||||
|
|
||||||
@ -75,10 +78,15 @@ def main():
|
|||||||
|
|
||||||
#guessing the image under the tile:
|
#guessing the image under the tile:
|
||||||
goalTile = tiles[tile_index]
|
goalTile = tiles[tile_index]
|
||||||
goalTile.display_photo()
|
|
||||||
image_path = goalTile.photo
|
image_path = goalTile.photo
|
||||||
|
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
|
||||||
|
pygame.display.update()
|
||||||
|
|
||||||
image_tensor = load_image(image_path)
|
image_tensor = load_image(image_path)
|
||||||
prediction = guess_image(load_model(), image_tensor)
|
prediction = guess_image(load_model(), image_tensor)
|
||||||
|
|
||||||
|
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
|
||||||
|
pygame.display.update()
|
||||||
print(f"The predicted image is: {prediction}")
|
print(f"The predicted image is: {prediction}")
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +151,7 @@ def main():
|
|||||||
#work on field:
|
#work on field:
|
||||||
if predykcje == 'work':
|
if predykcje == 'work':
|
||||||
tractor.work_on_field(goalTile, d1, p1)
|
tractor.work_on_field(goalTile, d1, p1)
|
||||||
time.sleep(30)
|
time.sleep(50)
|
||||||
print("\n")
|
print("\n")
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,17 +21,17 @@ class Plant:
|
|||||||
|
|
||||||
def update_name(self, predicted_class):
|
def update_name(self, predicted_class):
|
||||||
if predicted_class == "Apple":
|
if predicted_class == "Apple":
|
||||||
self.name = "Apple"
|
self.name = "apple"
|
||||||
self.plant_type = "fruit"
|
self.plant_type = 'fruit'
|
||||||
|
|
||||||
elif predicted_class == "Radish":
|
elif predicted_class == "Radish":
|
||||||
self.name = "Radish"
|
self.name = "radish"
|
||||||
self.plant_type = "vegetable"
|
self.plant_type = 'vegetable'
|
||||||
|
|
||||||
elif predicted_class == "Cauliflower":
|
elif predicted_class == "Cauliflower":
|
||||||
self.name = "Cauliflower"
|
self.name = "cauliflower"
|
||||||
self.plant_type = "vegetable"
|
self.plant_type = 'vegetable'
|
||||||
|
|
||||||
elif predicted_class == "Wheat":
|
elif predicted_class == "Wheat":
|
||||||
self.name = "Wheat"
|
self.name = "wheat"
|
||||||
self.plant_type = "cereal"
|
self.plant_type = 'cereal'
|
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -51,9 +52,3 @@ class Tile:
|
|||||||
self.image = "resources/images/rock_dirt.png"
|
self.image = "resources/images/rock_dirt.png"
|
||||||
|
|
||||||
|
|
||||||
def display_photo(self):
|
|
||||||
image_path = self.photo
|
|
||||||
img = Image.open(image_path)
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.axis('off')
|
|
||||||
plt.show()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user