losowanie zdjecia dla slotu, poruszanie sie traktora, uzycie sieci
This commit is contained in:
parent
f38a52e135
commit
3c6e79a1fd
4
App.py
4
App.py
@ -124,13 +124,15 @@ def init_demo(): #Demo purpose
|
|||||||
if (newModel):
|
if (newModel):
|
||||||
print_to_console("uczenie sieci neuronowej")
|
print_to_console("uczenie sieci neuronowej")
|
||||||
model = neuralnetwork.trainNewModel()
|
model = neuralnetwork.trainNewModel()
|
||||||
neuralnetwork.saveModel(model)
|
neuralnetwork.saveModel(model, 'model2.pth')
|
||||||
|
print_to_console("sieć nuronowa nauczona")
|
||||||
print('model został wygenerowany')
|
print('model został wygenerowany')
|
||||||
else:
|
else:
|
||||||
model = neuralnetwork.loadModel('model.pth')
|
model = neuralnetwork.loadModel('model.pth')
|
||||||
print_to_console("model został załądowny")
|
print_to_console("model został załądowny")
|
||||||
testset = neuralnetwork.getDataset(False)
|
testset = neuralnetwork.getDataset(False)
|
||||||
print(neuralnetwork.accuracy(model, testset))
|
print(neuralnetwork.accuracy(model, testset))
|
||||||
|
traktor.snake_move_predict_plant(pole, model)
|
||||||
start_flag=False
|
start_flag=False
|
||||||
# demo_move()
|
# demo_move()
|
||||||
old_info=get_info(old_info)
|
old_info=get_info(old_info)
|
||||||
|
4
Image.py
4
Image.py
@ -64,4 +64,6 @@ def getRandomImageFromDataBase():
|
|||||||
random_image = random.choice(files)
|
random_image = random.choice(files)
|
||||||
imgPath = os.path.join(folderPath, random_image)
|
imgPath = os.path.join(folderPath, random_image)
|
||||||
|
|
||||||
return pygame.image.load(imgPath), label, imgPath
|
image = pygame.image.load(imgPath)
|
||||||
|
image=pygame.transform.scale(image,(dCon.CUBE_SIZE,dCon.CUBE_SIZE))
|
||||||
|
return image, label, imgPath
|
||||||
|
1
Slot.py
1
Slot.py
@ -46,6 +46,7 @@ class Slot:
|
|||||||
self.plant=Roslina.Roslina(plant_name)
|
self.plant=Roslina.Roslina(plant_name)
|
||||||
else:
|
else:
|
||||||
self.plant_image, self.label, self.imagePath = self.random_plant_dataset()
|
self.plant_image, self.label, self.imagePath = self.random_plant_dataset()
|
||||||
|
# print(self.plant_image)
|
||||||
self.plant=Roslina.Roslina(self.label)
|
self.plant=Roslina.Roslina(self.label)
|
||||||
self.set_image()
|
self.set_image()
|
||||||
|
|
||||||
|
18
Tractor.py
18
Tractor.py
@ -9,6 +9,7 @@ import Osprzet
|
|||||||
import Node
|
import Node
|
||||||
import Condition
|
import Condition
|
||||||
import Drzewo
|
import Drzewo
|
||||||
|
import neuralnetwork as nn
|
||||||
|
|
||||||
condition=Condition.Condition()
|
condition=Condition.Condition()
|
||||||
drzewo=Drzewo.Drzewo()
|
drzewo=Drzewo.Drzewo()
|
||||||
@ -191,6 +192,23 @@ class Tractor:
|
|||||||
self.turn_left()
|
self.turn_left()
|
||||||
print("podlanych slotów: ", str(counter))
|
print("podlanych slotów: ", str(counter))
|
||||||
|
|
||||||
|
def snake_move_predict_plant(self, pole, model):
|
||||||
|
initPos = (self.slot.x_axis, self.slot.y_axis)
|
||||||
|
for i in range(initPos[1], dCon.NUM_Y):
|
||||||
|
for j in range(initPos[0], dCon.NUM_X):
|
||||||
|
if self.slot.imagePath != None:
|
||||||
|
predictedLabel = nn.predictLabel(self.slot.imagePath, model)
|
||||||
|
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real: ", self.slot.label, "predicted: ", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect")
|
||||||
|
self.move_forward(pole, False)
|
||||||
|
if i % 2 == 0 and i != dCon.NUM_Y - 1:
|
||||||
|
self.turn_right()
|
||||||
|
self.move_forward(pole, False)
|
||||||
|
self.turn_right()
|
||||||
|
elif i != dCon.NUM_Y - 1:
|
||||||
|
self.turn_left()
|
||||||
|
self.move_forward(pole, False)
|
||||||
|
self.turn_left()
|
||||||
|
|
||||||
def snake_move(self,pole,x,y):
|
def snake_move(self,pole,x,y):
|
||||||
next_slot_coordinates=(x,y)
|
next_slot_coordinates=(x,y)
|
||||||
if(self.do_move_if_valid(pole,next_slot_coordinates)):
|
if(self.do_move_if_valid(pole,next_slot_coordinates)):
|
||||||
|
@ -66,8 +66,8 @@ def getModel():
|
|||||||
).to(device)
|
).to(device)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def saveModel(model):
|
def saveModel(model, path):
|
||||||
torch.save(model.state_dict(), 'model.pth')
|
torch.save(model.state_dict(), path)
|
||||||
|
|
||||||
def loadModel(path):
|
def loadModel(path):
|
||||||
model = getModel()
|
model = getModel()
|
||||||
@ -91,12 +91,6 @@ def predictLabel(imagePath, model):
|
|||||||
predicted_class = torch.argmax(output).item()
|
predicted_class = torch.argmax(output).item()
|
||||||
return labels[predicted_class]
|
return labels[predicted_class]
|
||||||
|
|
||||||
def predictLabel(image, model):
|
|
||||||
image = preprocess_image(image)
|
|
||||||
with torch.no_grad():
|
|
||||||
model.eval() # Ustawienie modelu w tryb ewaluacji
|
|
||||||
output = model(image)
|
|
||||||
|
|
||||||
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
|
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
|
||||||
predicted_class = torch.argmax(output).item()
|
predicted_class = torch.argmax(output).item()
|
||||||
return labels[predicted_class]
|
return labels[predicted_class]
|
||||||
|
Loading…
Reference in New Issue
Block a user