neural_network #25

Merged
s481834 merged 12 commits from neural_network into refactor 2024-06-04 13:23:20 +02:00
4 changed files with 10 additions and 3 deletions
Showing only changes of commit ccd1cb2359 - Show all commits

4
App.py
View File

@ -21,7 +21,7 @@ if bfs3_flag or Astar or Astar2:
Pole.stoneFlag = True Pole.stoneFlag = True
TreeFlag=False TreeFlag=False
nnFlag=True nnFlag=True
newModel=True newModel=False
pygame.init() pygame.init()
show_console=True show_console=True
@ -124,7 +124,7 @@ def init_demo(): #Demo purpose
if (newModel): if (newModel):
print_to_console("uczenie sieci neuronowej") print_to_console("uczenie sieci neuronowej")
model = neuralnetwork.trainNewModel() model = neuralnetwork.trainNewModel()
neuralnetwork.saveModel(model, 'model2.pth') neuralnetwork.saveModel(model, 'model.pth')
print_to_console("sieć nuronowa nauczona") print_to_console("sieć nuronowa nauczona")
print('model został wygenerowany') print('model został wygenerowany')
else: else:

View File

@ -202,7 +202,12 @@ class Tractor:
quit() quit()
if self.slot.imagePath != None: if self.slot.imagePath != None:
predictedLabel = nn.predictLabel(self.slot.imagePath, model) predictedLabel = nn.predictLabel(self.slot.imagePath, model)
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real:", self.slot.label, "predicted:", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect", 'nawożę za pomocą:', nn.fertilizer[predictedLabel])
# print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real:", self.slot.label, "predicted:", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect", 'nawożę za pomocą:', nn.fertilizer[predictedLabel])
if str(self.slot.label) != str(predictedLabel):
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis))
print("real: ", str(self.slot.label), "\tpredicted: ", str(predictedLabel), "\n")
if self.slot.label != predictedLabel: if self.slot.label != predictedLabel:
self.slot.mark_visited() self.slot.mark_visited()
count += 1 count += 1

BIN
model.pth Normal file

Binary file not shown.

View File

@ -15,6 +15,8 @@ fertilizer = {labels[0]: 'kompost', labels[1]: 'saletra amonowa', labels[2]: 'su
torch.manual_seed(42) torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
# print(device)
def getTransformation(): def getTransformation():
transform=transforms.Compose([ transform=transforms.Compose([