neural_network #25

Merged
s481834 merged 12 commits from neural_network into refactor 2024-06-04 13:23:20 +02:00
3 changed files with 8 additions and 8 deletions
Showing only changes of commit 38c1adb054 - Show all commits

View File

@ -15,6 +15,7 @@ condition=Condition.Condition()
drzewo=Drzewo.Drzewo()
format_string = "{:<25}{:<25}{:<25}{:<10}{:<10}{:<10}{:<25}{:<15}{:<20}{:<10}{:<15}"
format_string_nn="{:<10}{:<20}{:<20}{:<15}{:<20}"
tab = [-1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
@ -193,6 +194,8 @@ class Tractor:
print("podlanych slotów: ", str(counter))
def snake_move_predict_plant(self, pole, model):
headers=['Coords','Real plant','Predicted plant','Result','Fertilizer']
print(format_string_nn.format(*headers))
initPos = (self.slot.x_axis, self.slot.y_axis)
count = 0
for i in range(initPos[1], dCon.NUM_Y):
@ -203,11 +206,8 @@ class Tractor:
if self.slot.imagePath != None:
predictedLabel = nn.predictLabel(self.slot.imagePath, model)
# print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real:", self.slot.label, "predicted:", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect", 'nawożę za pomocą:', nn.fertilizer[predictedLabel])
if str(self.slot.label) != str(predictedLabel):
print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis))
print("real: ", str(self.slot.label), "\tpredicted: ", str(predictedLabel), "\n")
#print(str("Coords: ({:02d}, {:02d})").format(self.slot.x_axis, self.slot.y_axis), "real:", self.slot.label, "predicted:", predictedLabel, "correct" if (self.slot.label == predictedLabel) else "incorrect", 'nawożę za pomocą:', nn.fertilizer[predictedLabel])
print(format_string_nn.format(f"{self.slot.x_axis,self.slot.y_axis}",self.slot.label,predictedLabel,"correct" if (self.slot.label == predictedLabel) else "incorrect",nn.fertilizer[predictedLabel]))
if self.slot.label != predictedLabel:
self.slot.mark_visited()
count += 1
@ -220,7 +220,7 @@ class Tractor:
self.turn_left()
self.move_forward(pole, False)
self.turn_left()
print(f"źle nawiezionych roślin: {count}")
print(f"Dobrze nawiezionych roślin: {20*12-count}, źle nawiezionych roślin: {count}")
def snake_move(self,pole,x,y):
next_slot_coordinates=(x,y)

Binary file not shown.

View File

@ -10,7 +10,7 @@ import random
imageSize = (128, 128)
labels = ['carrot','corn', 'potato', 'tomato'] # musi być w kolejności alfabetycznej
fertilizer = {labels[0]: 'kompost', labels[1]: 'saletra amonowa', labels[2]: 'superfosfat'}
fertilizer = {labels[0]: 'kompost', labels[1]: 'saletra amonowa', labels[2]: 'superfosfat', labels[3]:'obornik kurzy'}
torch.manual_seed(42)
@ -59,7 +59,7 @@ def accuracy(model, dataset):
return correct.float() / len(dataset)
def getModel():
hidden_size = 300
hidden_size = 500
model = nn.Sequential(
nn.Linear(imageSize[0] * imageSize[1] * 3, hidden_size),
nn.ReLU(),