test photos

This commit is contained in:
dardwo 2023-06-03 11:24:47 +02:00
parent 09071ecbe7
commit dbb637af6c

71
main.py
View File

@ -13,6 +13,77 @@ from src.ID3 import make_decision
import torch import torch
import cnn_model import cnn_model
def recognize_plants(plants_array):
checkpoint = torch.load(f'plants.model')
model = cnn_model.Net(num_classes=3)
model.load_state_dict(checkpoint)
model.eval()
img = ''
b=0
j=0
field_array_small = []
field_array_big = []
for i in range(11):
field_array_small = []
if b == 0:
for j in range(11):
if plants_array[j][i] == 'carrot':
img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
#show_plant_img(img)
elif plants_array[j][i] == 'potato':
img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
# show_plant_img(img)
elif plants_array[j][i] == 'wheat':
img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
# show_plant_img(img)
else:
pred = 'none'
field_array_small.append(pred)
print(i,',', j,'-',pred)
# agent_movement(['f'], agent, fields_for_movement, fields_for_astar)
# agent_movement(['r','f','r'], agent, fields_for_movement, fields_for_astar)
field_array_big.append(field_array_small)
else:
for j in range(10,-1,-1):
if plants_array[j][i] == 'carrot':
img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
# show_plant_img(img)
elif plants_array[j][i] == 'potato':
img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
# show_plant_img(img)
elif plants_array[j][i] == 'wheat':
img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg'
pred = cnn_model.prediction(img, model)
# show_plant_img(img)
else:
pred = 'none'
field_array_small.append(pred)
print(i,',', j,'-',pred)
# agent_movement(['f'], agent, fields_for_movement, fields_for_astar)
field_array_small = field_array_small[::-1]
field_array_big.append(field_array_small)
# agent_movement(['l','f','l'], agent, fields_for_movement, fields_for_astar)
if b==0:
b=1
else:
b=0
correct = 0
incorrect = 0
for i in range(11):
for j in range(11):
if plants_array[i][j]=='none':
continue
else:
if plants_array[i][j]==field_array_big[j][i]:
correct+=1
else:
incorrect+=1
print("Accuracy: ",correct/(correct+incorrect)*100,'%')