fixes
This commit is contained in:
parent
df7c553c59
commit
3d2a88d1ea
Binary file not shown.
Binary file not shown.
@ -4,10 +4,10 @@ from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms, utils
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
import matplotlib.pyplot as plt
|
||||
from model import *
|
||||
from NN.model import *
|
||||
from PIL import Image
|
||||
|
||||
device = torch.device('cuda')
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
#data transform to tensors:
|
||||
data_transformer = transforms.Compose([
|
||||
@ -56,12 +56,12 @@ def accuracy(model, dataset):
|
||||
return correct.float() / len(dataset)
|
||||
|
||||
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.to(device)
|
||||
# model = Conv_Neural_Network_Model()
|
||||
# model.to(device)
|
||||
|
||||
#loading the already saved model:
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.eval()
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.eval()
|
||||
|
||||
# #training the model:
|
||||
# train(model, train_set)
|
||||
@ -72,13 +72,13 @@ model.eval()
|
||||
|
||||
def load_model():
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
testImage = Image.open(image_path)
|
||||
testImage = Image.open(image_path).convert('RGB')
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
return testImage
|
||||
@ -98,17 +98,17 @@ def guess_image(model, image_tensor):
|
||||
# print(f"The predicted image is: {prediction}")
|
||||
|
||||
#TEST - loading the image and getting results:
|
||||
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
testImage = Image.open(testImage_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
testImage = testImage.to(device)
|
||||
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||
# testImage = Image.open(testImage_path)
|
||||
# testImage = data_transformer(testImage)
|
||||
# testImage = testImage.unsqueeze(0)
|
||||
# testImage = testImage.to(device)
|
||||
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
# model.to(device)
|
||||
# model.eval()
|
||||
|
||||
testOutput = model(testImage)
|
||||
_, predicted = torch.max(testOutput, 1)
|
||||
predicted_class = train_set.classes[predicted.item()]
|
||||
print(f'The predicted class is: {predicted_class}')
|
||||
# testOutput = model(testImage)
|
||||
# _, predicted = torch.max(testOutput, 1)
|
||||
# predicted_class = train_set.classes[predicted.item()]
|
||||
# print(f'The predicted class is: {predicted_class}')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -13,6 +13,7 @@ from plant import Plant
|
||||
from bfs import graphsearch, Istate, succ
|
||||
from astar import a_star
|
||||
from NN.neural_network import load_model, load_image, guess_image
|
||||
from PIL import Image
|
||||
|
||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Intelligent tractor')
|
||||
|
@ -21,17 +21,17 @@ class Plant:
|
||||
|
||||
def update_name(self, predicted_class):
|
||||
if predicted_class == "Apple":
|
||||
self.name = "Apple"
|
||||
self.plant_type = "fruit"
|
||||
self.name = "apple"
|
||||
self.plant_type = 'fruit'
|
||||
|
||||
elif predicted_class == "Radish":
|
||||
self.name = "Radish"
|
||||
self.plant_type = "vegetable"
|
||||
self.name = "radish"
|
||||
self.plant_type = 'vegetable'
|
||||
|
||||
elif predicted_class == "Cauliflower":
|
||||
self.name = "Cauliflower"
|
||||
self.plant_type = "vegetable"
|
||||
self.name = "cauliflower"
|
||||
self.plant_type = 'vegetable'
|
||||
|
||||
elif predicted_class == "Wheat":
|
||||
self.name = "Wheat"
|
||||
self.plant_type = "cereal"
|
||||
self.name = "wheat"
|
||||
self.plant_type = 'cereal'
|
@ -1,4 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -54,6 +55,9 @@ class Tile:
|
||||
def display_photo(self):
|
||||
image_path = self.photo
|
||||
img = Image.open(image_path)
|
||||
plt.ion()
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
time.sleep(5)
|
||||
plt.close()
|
||||
|
Loading…
Reference in New Issue
Block a user