diff --git a/App.py b/App.py index b437928..3e0da06 100644 --- a/App.py +++ b/App.py @@ -124,13 +124,15 @@ def init_demo(): #Demo purpose if (newModel): print_to_console("uczenie sieci neuronowej") model = neuralnetwork.trainNewModel() - neuralnetwork.saveModel(model) + neuralnetwork.saveModel(model, 'model2.pth') + print_to_console("sieć nuronowa nauczona") print('model został wygenerowany') else: model = neuralnetwork.loadModel('model.pth') print_to_console("model został załądowny") testset = neuralnetwork.getDataset(False) print(neuralnetwork.accuracy(model, testset)) + traktor.snake_move_predict_plant(pole, model) start_flag=False # demo_move() old_info=get_info(old_info) diff --git a/Image.py b/Image.py index 732d622..a2adfe5 100644 --- a/Image.py +++ b/Image.py @@ -64,4 +64,6 @@ def getRandomImageFromDataBase(): random_image = random.choice(files) imgPath = os.path.join(folderPath, random_image) - return pygame.image.load(imgPath), label, imgPath + image = pygame.image.load(imgPath) + image=pygame.transform.scale(image,(dCon.CUBE_SIZE,dCon.CUBE_SIZE)) + return image, label, imgPath diff --git a/Slot.py b/Slot.py index 6f8e30c..69a971a 100644 --- a/Slot.py +++ b/Slot.py @@ -46,6 +46,7 @@ class Slot: self.plant=Roslina.Roslina(plant_name) else: self.plant_image, self.label, self.imagePath = self.random_plant_dataset() + # print(self.plant_image) self.plant=Roslina.Roslina(self.label) self.set_image() diff --git a/Tractor.py b/Tractor.py index 71a0ee9..4a5d10e 100644 --- a/Tractor.py +++ b/Tractor.py @@ -9,6 +9,7 @@ import Osprzet import Node import Condition import Drzewo +import neuralnetwork as nn condition=Condition.Condition() drzewo=Drzewo.Drzewo() @@ -191,6 +192,23 @@ class Tractor: self.turn_left() print("podlanych slotów: ", str(counter)) + def snake_move_predict_plant(self, pole, model): + initPos = (self.slot.x_axis, self.slot.y_axis) + for i in range(initPos[1], dCon.NUM_Y): + for j in range(initPos[0], dCon.NUM_X): + if self.slot.imagePath != None: + predictedLabel = nn.predictLabel(self.slot.imagePath, model) + print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real: ", self.slot.label, "predicted: ", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect") + self.move_forward(pole, False) + if i % 2 == 0 and i != dCon.NUM_Y - 1: + self.turn_right() + self.move_forward(pole, False) + self.turn_right() + elif i != dCon.NUM_Y - 1: + self.turn_left() + self.move_forward(pole, False) + self.turn_left() + def snake_move(self,pole,x,y): next_slot_coordinates=(x,y) if(self.do_move_if_valid(pole,next_slot_coordinates)): diff --git a/model.pth b/model.pth index 90ece27..c177c79 100644 Binary files a/model.pth and b/model.pth differ diff --git a/neuralnetwork.py b/neuralnetwork.py index 352f54a..aa21ffa 100644 --- a/neuralnetwork.py +++ b/neuralnetwork.py @@ -66,8 +66,8 @@ def getModel(): ).to(device) return model -def saveModel(model): - torch.save(model.state_dict(), 'model.pth') +def saveModel(model, path): + torch.save(model.state_dict(), path) def loadModel(path): model = getModel() @@ -91,12 +91,6 @@ def predictLabel(imagePath, model): predicted_class = torch.argmax(output).item() return labels[predicted_class] -def predictLabel(image, model): - image = preprocess_image(image) - with torch.no_grad(): - model.eval() # Ustawienie modelu w tryb ewaluacji - output = model(image) - # Znalezienie indeksu klasy o największej wartości prawdopodobieństwa predicted_class = torch.argmax(output).item() return labels[predicted_class]