fixes
This commit is contained in:
parent
df7c553c59
commit
3d2a88d1ea
Binary file not shown.
Binary file not shown.
@ -4,10 +4,10 @@ from torch.utils.data import DataLoader
|
|||||||
from torchvision import datasets, transforms, utils
|
from torchvision import datasets, transforms, utils
|
||||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from model import *
|
from NN.model import *
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
|
||||||
#data transform to tensors:
|
#data transform to tensors:
|
||||||
data_transformer = transforms.Compose([
|
data_transformer = transforms.Compose([
|
||||||
@ -56,12 +56,12 @@ def accuracy(model, dataset):
|
|||||||
return correct.float() / len(dataset)
|
return correct.float() / len(dataset)
|
||||||
|
|
||||||
|
|
||||||
model = Conv_Neural_Network_Model()
|
# model = Conv_Neural_Network_Model()
|
||||||
model.to(device)
|
# model.to(device)
|
||||||
|
|
||||||
#loading the already saved model:
|
#loading the already saved model:
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||||
model.eval()
|
# model.eval()
|
||||||
|
|
||||||
# #training the model:
|
# #training the model:
|
||||||
# train(model, train_set)
|
# train(model, train_set)
|
||||||
@ -72,13 +72,13 @@ model.eval()
|
|||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
model = Conv_Neural_Network_Model()
|
model = Conv_Neural_Network_Model()
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
model.load_state_dict(torch.load('CNN_model.pth', map_location=torch.device('cpu')))
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_path):
|
def load_image(image_path):
|
||||||
testImage = Image.open(image_path)
|
testImage = Image.open(image_path).convert('RGB')
|
||||||
testImage = data_transformer(testImage)
|
testImage = data_transformer(testImage)
|
||||||
testImage = testImage.unsqueeze(0)
|
testImage = testImage.unsqueeze(0)
|
||||||
return testImage
|
return testImage
|
||||||
@ -98,17 +98,17 @@ def guess_image(model, image_tensor):
|
|||||||
# print(f"The predicted image is: {prediction}")
|
# print(f"The predicted image is: {prediction}")
|
||||||
|
|
||||||
#TEST - loading the image and getting results:
|
#TEST - loading the image and getting results:
|
||||||
testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
||||||
testImage = Image.open(testImage_path)
|
# testImage = Image.open(testImage_path)
|
||||||
testImage = data_transformer(testImage)
|
# testImage = data_transformer(testImage)
|
||||||
testImage = testImage.unsqueeze(0)
|
# testImage = testImage.unsqueeze(0)
|
||||||
testImage = testImage.to(device)
|
# testImage = testImage.to(device)
|
||||||
|
|
||||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
||||||
model.to(device)
|
# model.to(device)
|
||||||
model.eval()
|
# model.eval()
|
||||||
|
|
||||||
testOutput = model(testImage)
|
# testOutput = model(testImage)
|
||||||
_, predicted = torch.max(testOutput, 1)
|
# _, predicted = torch.max(testOutput, 1)
|
||||||
predicted_class = train_set.classes[predicted.item()]
|
# predicted_class = train_set.classes[predicted.item()]
|
||||||
print(f'The predicted class is: {predicted_class}')
|
# print(f'The predicted class is: {predicted_class}')
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -13,6 +13,7 @@ from plant import Plant
|
|||||||
from bfs import graphsearch, Istate, succ
|
from bfs import graphsearch, Istate, succ
|
||||||
from astar import a_star
|
from astar import a_star
|
||||||
from NN.neural_network import load_model, load_image, guess_image
|
from NN.neural_network import load_model, load_image, guess_image
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||||
pygame.display.set_caption('Intelligent tractor')
|
pygame.display.set_caption('Intelligent tractor')
|
||||||
|
@ -21,17 +21,17 @@ class Plant:
|
|||||||
|
|
||||||
def update_name(self, predicted_class):
|
def update_name(self, predicted_class):
|
||||||
if predicted_class == "Apple":
|
if predicted_class == "Apple":
|
||||||
self.name = "Apple"
|
self.name = "apple"
|
||||||
self.plant_type = "fruit"
|
self.plant_type = 'fruit'
|
||||||
|
|
||||||
elif predicted_class == "Radish":
|
elif predicted_class == "Radish":
|
||||||
self.name = "Radish"
|
self.name = "radish"
|
||||||
self.plant_type = "vegetable"
|
self.plant_type = 'vegetable'
|
||||||
|
|
||||||
elif predicted_class == "Cauliflower":
|
elif predicted_class == "Cauliflower":
|
||||||
self.name = "Cauliflower"
|
self.name = "cauliflower"
|
||||||
self.plant_type = "vegetable"
|
self.plant_type = 'vegetable'
|
||||||
|
|
||||||
elif predicted_class == "Wheat":
|
elif predicted_class == "Wheat":
|
||||||
self.name = "Wheat"
|
self.name = "wheat"
|
||||||
self.plant_type = "cereal"
|
self.plant_type = 'cereal'
|
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -54,6 +55,9 @@ class Tile:
|
|||||||
def display_photo(self):
|
def display_photo(self):
|
||||||
image_path = self.photo
|
image_path = self.photo
|
||||||
img = Image.open(image_path)
|
img = Image.open(image_path)
|
||||||
|
plt.ion()
|
||||||
plt.imshow(img)
|
plt.imshow(img)
|
||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
time.sleep(5)
|
||||||
|
plt.close()
|
||||||
|
Loading…
Reference in New Issue
Block a user