neural_network #13
71
main.py
71
main.py
@ -13,6 +13,77 @@ from src.ID3 import make_decision
|
|||||||
import torch
|
import torch
|
||||||
import cnn_model
|
import cnn_model
|
||||||
|
|
||||||
|
def recognize_plants(plants_array):
|
||||||
|
checkpoint = torch.load(f'plants.model')
|
||||||
|
model = cnn_model.Net(num_classes=3)
|
||||||
|
model.load_state_dict(checkpoint)
|
||||||
|
model.eval()
|
||||||
|
img = ''
|
||||||
|
b=0
|
||||||
|
j=0
|
||||||
|
field_array_small = []
|
||||||
|
field_array_big = []
|
||||||
|
for i in range(11):
|
||||||
|
field_array_small = []
|
||||||
|
if b == 0:
|
||||||
|
for j in range(11):
|
||||||
|
if plants_array[j][i] == 'carrot':
|
||||||
|
img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
#show_plant_img(img)
|
||||||
|
elif plants_array[j][i] == 'potato':
|
||||||
|
img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
# show_plant_img(img)
|
||||||
|
elif plants_array[j][i] == 'wheat':
|
||||||
|
img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
# show_plant_img(img)
|
||||||
|
else:
|
||||||
|
pred = 'none'
|
||||||
|
field_array_small.append(pred)
|
||||||
|
print(i,',', j,'-',pred)
|
||||||
|
# agent_movement(['f'], agent, fields_for_movement, fields_for_astar)
|
||||||
|
# agent_movement(['r','f','r'], agent, fields_for_movement, fields_for_astar)
|
||||||
|
field_array_big.append(field_array_small)
|
||||||
|
else:
|
||||||
|
for j in range(10,-1,-1):
|
||||||
|
if plants_array[j][i] == 'carrot':
|
||||||
|
img = 'assets/learning/test/carrot/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
# show_plant_img(img)
|
||||||
|
elif plants_array[j][i] == 'potato':
|
||||||
|
img = 'assets/learning/test/potato/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
# show_plant_img(img)
|
||||||
|
elif plants_array[j][i] == 'wheat':
|
||||||
|
img = 'assets/learning/test/wheat/' + str(random.randint(1, 200)) + '.jpg'
|
||||||
|
pred = cnn_model.prediction(img, model)
|
||||||
|
# show_plant_img(img)
|
||||||
|
else:
|
||||||
|
pred = 'none'
|
||||||
|
field_array_small.append(pred)
|
||||||
|
print(i,',', j,'-',pred)
|
||||||
|
# agent_movement(['f'], agent, fields_for_movement, fields_for_astar)
|
||||||
|
field_array_small = field_array_small[::-1]
|
||||||
|
field_array_big.append(field_array_small)
|
||||||
|
# agent_movement(['l','f','l'], agent, fields_for_movement, fields_for_astar)
|
||||||
|
if b==0:
|
||||||
|
b=1
|
||||||
|
else:
|
||||||
|
b=0
|
||||||
|
correct = 0
|
||||||
|
incorrect = 0
|
||||||
|
for i in range(11):
|
||||||
|
for j in range(11):
|
||||||
|
if plants_array[i][j]=='none':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if plants_array[i][j]==field_array_big[j][i]:
|
||||||
|
correct+=1
|
||||||
|
else:
|
||||||
|
incorrect+=1
|
||||||
|
print("Accuracy: ",correct/(correct+incorrect)*100,'%')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user