Compare commits
2 Commits
df7c553c59
...
57c6facea1
Author | SHA1 | Date | |
---|---|---|---|
|
57c6facea1 | ||
|
3d2a88d1ea |
Binary file not shown.
Binary file not shown.
@ -4,10 +4,11 @@ from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms, utils
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
import matplotlib.pyplot as plt
|
||||
from model import *
|
||||
from NN.model import *
|
||||
from PIL import Image
|
||||
import pygame
|
||||
|
||||
device = torch.device('cuda')
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
#data transform to tensors:
|
||||
data_transformer = transforms.Compose([
|
||||
@ -56,12 +57,12 @@ def accuracy(model, dataset):
|
||||
return correct.float() / len(dataset)
|
||||
|
||||
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.to(device)
|
||||
# model = Conv_Neural_Network_Model()
|
||||
# model.to(device)
|
||||
|
||||
#loading the already saved model:
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.eval()
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.eval()
|
||||
|
||||
# #training the model:
|
||||
# train(model, train_set)
|
||||
@ -72,17 +73,27 @@ model.eval()
|
||||
|
||||
def load_model():
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
testImage = Image.open(image_path)
|
||||
testImage = Image.open(image_path).convert('RGB')
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
return testImage
|
||||
|
||||
def display_image(screen, image_path, position):
|
||||
image = pygame.image.load(image_path)
|
||||
image = pygame.transform.scale(image, (250, 250))
|
||||
screen.blit(image, position)
|
||||
|
||||
def display_result(screen, position, predicted_class):
|
||||
font = pygame.font.Font(None, 30)
|
||||
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
|
||||
screen.blit(displayed_text, position)
|
||||
|
||||
def guess_image(model, image_tensor):
|
||||
with torch.no_grad():
|
||||
testOutput = model(image_tensor)
|
||||
@ -92,23 +103,18 @@ def guess_image(model, image_tensor):
|
||||
|
||||
|
||||
|
||||
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
|
||||
# image_tensor = load_image(image_path)
|
||||
# prediction = guess_image(load_model(), image_tensor)
|
||||
# print(f"The predicted image is: {prediction}")
|
||||
|
||||
#TEST - loading the image and getting results:
|
||||
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
testImage = Image.open(testImage_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
testImage = testImage.to(device)
|
||||
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
# testImage = Image.open(testImage_path)
|
||||
# testImage = data_transformer(testImage)
|
||||
# testImage = testImage.unsqueeze(0)
|
||||
# testImage = testImage.to(device)
|
||||
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.to(device)
|
||||
# model.eval()
|
||||
|
||||
testOutput = model(testImage)
|
||||
_, predicted = torch.max(testOutput, 1)
|
||||
predicted_class = train_set.classes[predicted.item()]
|
||||
print(f'The predicted class is: {predicted_class}')
|
||||
# testOutput = model(testImage)
|
||||
# _, predicted = torch.max(testOutput, 1)
|
||||
# predicted_class = train_set.classes[predicted.item()]
|
||||
# print(f'The predicted class is: {predicted_class}')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -12,9 +12,12 @@ from ground import Dirt
|
||||
from plant import Plant
|
||||
from bfs import graphsearch, Istate, succ
|
||||
from astar import a_star
|
||||
from NN.neural_network import load_model, load_image, guess_image
|
||||
from NN.neural_network import load_model, load_image, guess_image, display_image, display_result
|
||||
from PIL import Image
|
||||
|
||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
pygame.init()
|
||||
WIN_WIDTH = WIDTH + 300
|
||||
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Intelligent tractor')
|
||||
|
||||
|
||||
@ -75,10 +78,15 @@ def main():
|
||||
|
||||
#guessing the image under the tile:
|
||||
goalTile = tiles[tile_index]
|
||||
goalTile.display_photo()
|
||||
image_path = goalTile.photo
|
||||
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
|
||||
pygame.display.update()
|
||||
|
||||
image_tensor = load_image(image_path)
|
||||
prediction = guess_image(load_model(), image_tensor)
|
||||
|
||||
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
|
||||
pygame.display.update()
|
||||
print(f"The predicted image is: {prediction}")
|
||||
|
||||
|
||||
@ -143,7 +151,7 @@ def main():
|
||||
#work on field:
|
||||
if predykcje == 'work':
|
||||
tractor.work_on_field(goalTile, d1, p1)
|
||||
time.sleep(30)
|
||||
time.sleep(50)
|
||||
print("\n")
|
||||
|
||||
|
||||
|
@ -21,17 +21,17 @@ class Plant:
|
||||
|
||||
def update_name(self, predicted_class):
|
||||
if predicted_class == "Apple":
|
||||
self.name = "Apple"
|
||||
self.plant_type = "fruit"
|
||||
self.name = "apple"
|
||||
self.plant_type = 'fruit'
|
||||
|
||||
elif predicted_class == "Radish":
|
||||
self.name = "Radish"
|
||||
self.plant_type = "vegetable"
|
||||
self.name = "radish"
|
||||
self.plant_type = 'vegetable'
|
||||
|
||||
elif predicted_class == "Cauliflower":
|
||||
self.name = "Cauliflower"
|
||||
self.plant_type = "vegetable"
|
||||
self.name = "cauliflower"
|
||||
self.plant_type = 'vegetable'
|
||||
|
||||
elif predicted_class == "Wheat":
|
||||
self.name = "Wheat"
|
||||
self.plant_type = "cereal"
|
||||
self.name = "wheat"
|
||||
self.plant_type = 'cereal'
|
@ -1,4 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -51,9 +52,3 @@ class Tile:
|
||||
self.image = "resources/images/rock_dirt.png"
|
||||
|
||||
|
||||
def display_photo(self):
|
||||
image_path = self.photo
|
||||
img = Image.open(image_path)
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
Loading…
Reference in New Issue
Block a user