losowanie zdjecia dla slotu, poruszanie sie traktora, uzycie sieci
This commit is contained in:
parent
f38a52e135
commit
3c6e79a1fd
4
App.py
4
App.py
@ -124,13 +124,15 @@ def init_demo(): #Demo purpose
|
||||
if (newModel):
|
||||
print_to_console("uczenie sieci neuronowej")
|
||||
model = neuralnetwork.trainNewModel()
|
||||
neuralnetwork.saveModel(model)
|
||||
neuralnetwork.saveModel(model, 'model2.pth')
|
||||
print_to_console("sieć nuronowa nauczona")
|
||||
print('model został wygenerowany')
|
||||
else:
|
||||
model = neuralnetwork.loadModel('model.pth')
|
||||
print_to_console("model został załądowny")
|
||||
testset = neuralnetwork.getDataset(False)
|
||||
print(neuralnetwork.accuracy(model, testset))
|
||||
traktor.snake_move_predict_plant(pole, model)
|
||||
start_flag=False
|
||||
# demo_move()
|
||||
old_info=get_info(old_info)
|
||||
|
4
Image.py
4
Image.py
@ -64,4 +64,6 @@ def getRandomImageFromDataBase():
|
||||
random_image = random.choice(files)
|
||||
imgPath = os.path.join(folderPath, random_image)
|
||||
|
||||
return pygame.image.load(imgPath), label, imgPath
|
||||
image = pygame.image.load(imgPath)
|
||||
image=pygame.transform.scale(image,(dCon.CUBE_SIZE,dCon.CUBE_SIZE))
|
||||
return image, label, imgPath
|
||||
|
1
Slot.py
1
Slot.py
@ -46,6 +46,7 @@ class Slot:
|
||||
self.plant=Roslina.Roslina(plant_name)
|
||||
else:
|
||||
self.plant_image, self.label, self.imagePath = self.random_plant_dataset()
|
||||
# print(self.plant_image)
|
||||
self.plant=Roslina.Roslina(self.label)
|
||||
self.set_image()
|
||||
|
||||
|
18
Tractor.py
18
Tractor.py
@ -9,6 +9,7 @@ import Osprzet
|
||||
import Node
|
||||
import Condition
|
||||
import Drzewo
|
||||
import neuralnetwork as nn
|
||||
|
||||
condition=Condition.Condition()
|
||||
drzewo=Drzewo.Drzewo()
|
||||
@ -191,6 +192,23 @@ class Tractor:
|
||||
self.turn_left()
|
||||
print("podlanych slotów: ", str(counter))
|
||||
|
||||
def snake_move_predict_plant(self, pole, model):
|
||||
initPos = (self.slot.x_axis, self.slot.y_axis)
|
||||
for i in range(initPos[1], dCon.NUM_Y):
|
||||
for j in range(initPos[0], dCon.NUM_X):
|
||||
if self.slot.imagePath != None:
|
||||
predictedLabel = nn.predictLabel(self.slot.imagePath, model)
|
||||
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real: ", self.slot.label, "predicted: ", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect")
|
||||
self.move_forward(pole, False)
|
||||
if i % 2 == 0 and i != dCon.NUM_Y - 1:
|
||||
self.turn_right()
|
||||
self.move_forward(pole, False)
|
||||
self.turn_right()
|
||||
elif i != dCon.NUM_Y - 1:
|
||||
self.turn_left()
|
||||
self.move_forward(pole, False)
|
||||
self.turn_left()
|
||||
|
||||
def snake_move(self,pole,x,y):
|
||||
next_slot_coordinates=(x,y)
|
||||
if(self.do_move_if_valid(pole,next_slot_coordinates)):
|
||||
|
@ -66,8 +66,8 @@ def getModel():
|
||||
).to(device)
|
||||
return model
|
||||
|
||||
def saveModel(model):
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
def saveModel(model, path):
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
def loadModel(path):
|
||||
model = getModel()
|
||||
@ -91,12 +91,6 @@ def predictLabel(imagePath, model):
|
||||
predicted_class = torch.argmax(output).item()
|
||||
return labels[predicted_class]
|
||||
|
||||
def predictLabel(image, model):
|
||||
image = preprocess_image(image)
|
||||
with torch.no_grad():
|
||||
model.eval() # Ustawienie modelu w tryb ewaluacji
|
||||
output = model(image)
|
||||
|
||||
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
|
||||
predicted_class = torch.argmax(output).item()
|
||||
return labels[predicted_class]
|
||||
|
Loading…
Reference in New Issue
Block a user