From dbb637af6c11df9580d243d32c495c9dee2e457b Mon Sep 17 00:00:00 2001 From: dardwo Date: Sat, 3 Jun 2023 11:24:47 +0200 Subject: [PATCH] test photos --- main.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/main.py b/main.py index 7e85fc2..09cbbf5 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,77 @@ from src.ID3 import make_decision import torch import cnn_model +def recognize_plants(plants_array): + checkpoint = torch.load(f'plants.model') + model = cnn_model.Net(num_classes=3) + model.load_state_dict(checkpoint) + model.eval() + img = '' + b=0 + j=0 + field_array_small = [] + field_array_big = [] + for i in range(11): + field_array_small = [] + if b == 0: + for j in range(11): + if plants_array[j][i] == 'carrot': + img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + #show_plant_img(img) + elif plants_array[j][i] == 'potato': + img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + # show_plant_img(img) + elif plants_array[j][i] == 'wheat': + img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + # show_plant_img(img) + else: + pred = 'none' + field_array_small.append(pred) + print(i,',', j,'-',pred) + # agent_movement(['f'], agent, fields_for_movement, fields_for_astar) + # agent_movement(['r','f','r'], agent, fields_for_movement, fields_for_astar) + field_array_big.append(field_array_small) + else: + for j in range(10,-1,-1): + if plants_array[j][i] == 'carrot': + img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + # show_plant_img(img) + elif plants_array[j][i] == 'potato': + img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + # show_plant_img(img) + elif plants_array[j][i] == 'wheat': + img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg' + pred = cnn_model.prediction(img, model) + # show_plant_img(img) + else: + pred = 'none' + field_array_small.append(pred) + print(i,',', j,'-',pred) + # agent_movement(['f'], agent, fields_for_movement, fields_for_astar) + field_array_small = field_array_small[::-1] + field_array_big.append(field_array_small) + # agent_movement(['l','f','l'], agent, fields_for_movement, fields_for_astar) + if b==0: + b=1 + else: + b=0 + correct = 0 + incorrect = 0 + for i in range(11): + for j in range(11): + if plants_array[i][j]=='none': + continue + else: + if plants_array[i][j]==field_array_big[j][i]: + correct+=1 + else: + incorrect+=1 + print("Accuracy: ",correct/(correct+incorrect)*100,'%')