Compare commits
No commits in common. "57c6facea1168211fb3bd0d6b7f647d39f85c407" and "df7c553c59cb75abf9c24082e8c29e3df759eb1a" have entirely different histories.
57c6facea1
...
df7c553c59
Binary file not shown.
Binary file not shown.
@ -4,11 +4,10 @@ from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms, utils
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
import matplotlib.pyplot as plt
|
||||
from NN.model import *
|
||||
from model import *
|
||||
from PIL import Image
|
||||
import pygame
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
device = torch.device('cuda')
|
||||
|
||||
#data transform to tensors:
|
||||
data_transformer = transforms.Compose([
|
||||
@ -57,12 +56,12 @@ def accuracy(model, dataset):
|
||||
return correct.float() / len(dataset)
|
||||
|
||||
|
||||
# model = Conv_Neural_Network_Model()
|
||||
# model.to(device)
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.to(device)
|
||||
|
||||
#loading the already saved model:
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.eval()
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.eval()
|
||||
|
||||
# #training the model:
|
||||
# train(model, train_set)
|
||||
@ -73,27 +72,17 @@ def accuracy(model, dataset):
|
||||
|
||||
def load_model():
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
testImage = Image.open(image_path).convert('RGB')
|
||||
testImage = Image.open(image_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
return testImage
|
||||
|
||||
def display_image(screen, image_path, position):
|
||||
image = pygame.image.load(image_path)
|
||||
image = pygame.transform.scale(image, (250, 250))
|
||||
screen.blit(image, position)
|
||||
|
||||
def display_result(screen, position, predicted_class):
|
||||
font = pygame.font.Font(None, 30)
|
||||
displayed_text = font.render("The predicted image is: "+str(predicted_class), 1, (255,255,255))
|
||||
screen.blit(displayed_text, position)
|
||||
|
||||
def guess_image(model, image_tensor):
|
||||
with torch.no_grad():
|
||||
testOutput = model(image_tensor)
|
||||
@ -103,18 +92,23 @@ def guess_image(model, image_tensor):
|
||||
|
||||
|
||||
|
||||
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
|
||||
# image_tensor = load_image(image_path)
|
||||
# prediction = guess_image(load_model(), image_tensor)
|
||||
# print(f"The predicted image is: {prediction}")
|
||||
|
||||
#TEST - loading the image and getting results:
|
||||
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
# testImage = Image.open(testImage_path)
|
||||
# testImage = data_transformer(testImage)
|
||||
# testImage = testImage.unsqueeze(0)
|
||||
# testImage = testImage.to(device)
|
||||
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
testImage = Image.open(testImage_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
testImage = testImage.to(device)
|
||||
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.to(device)
|
||||
# model.eval()
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# testOutput = model(testImage)
|
||||
# _, predicted = torch.max(testOutput, 1)
|
||||
# predicted_class = train_set.classes[predicted.item()]
|
||||
# print(f'The predicted class is: {predicted_class}')
|
||||
testOutput = model(testImage)
|
||||
_, predicted = torch.max(testOutput, 1)
|
||||
predicted_class = train_set.classes[predicted.item()]
|
||||
print(f'The predicted class is: {predicted_class}')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -12,12 +12,9 @@ from ground import Dirt
|
||||
from plant import Plant
|
||||
from bfs import graphsearch, Istate, succ
|
||||
from astar import a_star
|
||||
from NN.neural_network import load_model, load_image, guess_image, display_image, display_result
|
||||
from PIL import Image
|
||||
from NN.neural_network import load_model, load_image, guess_image
|
||||
|
||||
pygame.init()
|
||||
WIN_WIDTH = WIDTH + 300
|
||||
WIN = pygame.display.set_mode((WIN_WIDTH, HEIGHT))
|
||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Intelligent tractor')
|
||||
|
||||
|
||||
@ -78,15 +75,10 @@ def main():
|
||||
|
||||
#guessing the image under the tile:
|
||||
goalTile = tiles[tile_index]
|
||||
goalTile.display_photo()
|
||||
image_path = goalTile.photo
|
||||
display_image(WIN, goalTile.photo, (WIDTH-20 , 300)) #displays photo next to the field
|
||||
pygame.display.update()
|
||||
|
||||
image_tensor = load_image(image_path)
|
||||
prediction = guess_image(load_model(), image_tensor)
|
||||
|
||||
display_result(WIN, (WIDTH - 50 , 600), prediction) #display text under the photo
|
||||
pygame.display.update()
|
||||
print(f"The predicted image is: {prediction}")
|
||||
|
||||
|
||||
@ -151,7 +143,7 @@ def main():
|
||||
#work on field:
|
||||
if predykcje == 'work':
|
||||
tractor.work_on_field(goalTile, d1, p1)
|
||||
time.sleep(50)
|
||||
time.sleep(30)
|
||||
print("\n")
|
||||
|
||||
|
||||
|
@ -21,17 +21,17 @@ class Plant:
|
||||
|
||||
def update_name(self, predicted_class):
|
||||
if predicted_class == "Apple":
|
||||
self.name = "apple"
|
||||
self.plant_type = 'fruit'
|
||||
self.name = "Apple"
|
||||
self.plant_type = "fruit"
|
||||
|
||||
elif predicted_class == "Radish":
|
||||
self.name = "radish"
|
||||
self.plant_type = 'vegetable'
|
||||
self.name = "Radish"
|
||||
self.plant_type = "vegetable"
|
||||
|
||||
elif predicted_class == "Cauliflower":
|
||||
self.name = "cauliflower"
|
||||
self.plant_type = 'vegetable'
|
||||
self.name = "Cauliflower"
|
||||
self.plant_type = "vegetable"
|
||||
|
||||
elif predicted_class == "Wheat":
|
||||
self.name = "wheat"
|
||||
self.plant_type = 'cereal'
|
||||
self.name = "Wheat"
|
||||
self.plant_type = "cereal"
|
@ -1,5 +1,4 @@
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -52,3 +51,9 @@ class Tile:
|
||||
self.image = "resources/images/rock_dirt.png"
|
||||
|
||||
|
||||
def display_photo(self):
|
||||
image_path = self.photo
|
||||
img = Image.open(image_path)
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
Loading…
Reference in New Issue
Block a user