neural_network #25

Merged
s481834 merged 12 commits from neural_network into refactor 2024-06-04 13:23:20 +02:00
6 changed files with 27 additions and 10 deletions
Showing only changes of commit 3c6e79a1fd - Show all commits

4
App.py
View File

@ -124,13 +124,15 @@ def init_demo(): #Demo purpose
if (newModel):
print_to_console("uczenie sieci neuronowej")
model = neuralnetwork.trainNewModel()
neuralnetwork.saveModel(model)
neuralnetwork.saveModel(model, 'model2.pth')
print_to_console("sieć nuronowa nauczona")
print('model został wygenerowany')
else:
model = neuralnetwork.loadModel('model.pth')
print_to_console("model został załądowny")
testset = neuralnetwork.getDataset(False)
print(neuralnetwork.accuracy(model, testset))
traktor.snake_move_predict_plant(pole, model)
start_flag=False
# demo_move()
old_info=get_info(old_info)

View File

@ -64,4 +64,6 @@ def getRandomImageFromDataBase():
random_image = random.choice(files)
imgPath = os.path.join(folderPath, random_image)
return pygame.image.load(imgPath), label, imgPath
image = pygame.image.load(imgPath)
image=pygame.transform.scale(image,(dCon.CUBE_SIZE,dCon.CUBE_SIZE))
return image, label, imgPath

View File

@ -46,6 +46,7 @@ class Slot:
self.plant=Roslina.Roslina(plant_name)
else:
self.plant_image, self.label, self.imagePath = self.random_plant_dataset()
# print(self.plant_image)
self.plant=Roslina.Roslina(self.label)
self.set_image()

View File

@ -9,6 +9,7 @@ import Osprzet
import Node
import Condition
import Drzewo
import neuralnetwork as nn
condition=Condition.Condition()
drzewo=Drzewo.Drzewo()
@ -191,6 +192,23 @@ class Tractor:
self.turn_left()
print("podlanych slotów: ", str(counter))
def snake_move_predict_plant(self, pole, model):
initPos = (self.slot.x_axis, self.slot.y_axis)
for i in range(initPos[1], dCon.NUM_Y):
for j in range(initPos[0], dCon.NUM_X):
if self.slot.imagePath != None:
predictedLabel = nn.predictLabel(self.slot.imagePath, model)
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real: ", self.slot.label, "predicted: ", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect")
self.move_forward(pole, False)
if i % 2 == 0 and i != dCon.NUM_Y - 1:
self.turn_right()
self.move_forward(pole, False)
self.turn_right()
elif i != dCon.NUM_Y - 1:
self.turn_left()
self.move_forward(pole, False)
self.turn_left()
def snake_move(self,pole,x,y):
next_slot_coordinates=(x,y)
if(self.do_move_if_valid(pole,next_slot_coordinates)):

BIN
model.pth

Binary file not shown.

View File

@ -66,8 +66,8 @@ def getModel():
).to(device)
return model
def saveModel(model):
torch.save(model.state_dict(), 'model.pth')
def saveModel(model, path):
torch.save(model.state_dict(), path)
def loadModel(path):
model = getModel()
@ -91,12 +91,6 @@ def predictLabel(imagePath, model):
predicted_class = torch.argmax(output).item()
return labels[predicted_class]
def predictLabel(image, model):
image = preprocess_image(image)
with torch.no_grad():
model.eval() # Ustawienie modelu w tryb ewaluacji
output = model(image)
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
predicted_class = torch.argmax(output).item()
return labels[predicted_class]