add decision tree implementation

This commit is contained in:
s473554 2023-05-25 15:46:08 +02:00
parent 233f218899
commit 4c9682729e
3 changed files with 152 additions and 10 deletions

4
01.csv
View File

@ -171,8 +171,8 @@ can it get to the next point,will it be able to get to the gas station,will it b
1,0,1,0,1,0,0,1,5
1,0,1,0,1,0,1,0,5
1,0,1,0,1,0,1,1,5
1,0,1,0,1,1,0,0,3
1,0,1,0,1,1,0,1,3
1,0,1,0,1,1,0,0,5
1,0,1,0,1,1,0,1,5
1,0,1,0,1,1,1,0,4
1,0,1,0,1,1,1,1,4
1,0,1,1,0,0,0,0,1

1 can it get to the next point will it be able to get to the gas station will it be able to get to the gas station after arriving at the next point will it be able to take the next vegetable to the tractor storage will it be able to get to the vegetable warehouse will it be able to get to the gas station after it arrives at the vegetable warehouse is the vegetable warehouse closed is the gas station closed go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER
171 1 0 1 0 1 0 0 1 5
172 1 0 1 0 1 0 1 0 5
173 1 0 1 0 1 0 1 1 5
174 1 0 1 0 1 1 0 0 3 5
175 1 0 1 0 1 1 0 1 3 5
176 1 0 1 0 1 1 1 0 4
177 1 0 1 0 1 1 1 1 4
178 1 0 1 1 0 0 0 0 1

15
IC3.py
View File

@ -117,16 +117,21 @@ def predict(tree, instance):
else:
root_node = next(iter(tree))
feature_value = instance[root_node]
print(root_node)
print(feature_value)
if feature_value in tree[root_node]:
return predict(tree[root_node][feature_value], instance)
else:
return None
def evaluate(tree, test_data_m, label):
correct_preditct = 0
wrong_preditct = 0
for index, row in test_data_m.iterrows():
print(test_data_m.iloc[index])
result = predict(tree, test_data_m.iloc[index])
print()
if result == test_data_m[label].iloc[index]:
correct_preditct += 1
else:
@ -147,8 +152,10 @@ class NpEncoder(json.JSONEncoder):
tree = id3(train_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER')
# print(tree)
json_str = json.dumps(tree, indent=2, cls=NpEncoder)
print(json_str)
# json_str = json.dumps(tree, indent=2, cls=NpEncoder)
# print(json_str)
accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER')
# print(accuracy)
accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER') #evaluating the test dataset
print(accuracy)

143
field.py
View File

@ -4,6 +4,7 @@ from heapq import *
from enum import Enum, IntEnum
from queue import PriorityQueue
from collections import deque
from IC3 import tree
import pygame
@ -108,9 +109,10 @@ def draw_interface():
elif event.type == pygame.MOUSEBUTTONDOWN:
startpoint = (tractor.x, tractor.y, tractor.direction)
endpoint = get_click_mouse_pos()
a, c = graph1.a_star(startpoint, endpoint)
b = getRoad(startpoint, c, a)
movement(tractor, grid, b)
decisionTree(startpoint, endpoint, tractor, grid, graph1)
# a, c = graph1.a_star(startpoint, endpoint)
# b = getRoad(startpoint, c, a)
# movement(tractor, grid, b)
updateDisplay(tractor, grid)
@ -289,7 +291,28 @@ def movement(tractor: Tractor, grid: Grid, road):
else:
tractor.rot_center(Direction.LEFT)
updateDisplay(tractor, grid)
def getCost(tractor: Tractor, grid: Grid, road):
n = len(road)
cost = 0
for i in range(n - 1):
aA = road[i]
bB = road[i + 1]
if aA[0] != bB[0]:
if grid.grid[bB[0]][bB[1]] == types.ROCK:
cost += 12
else:
cost += 2
if aA[1] != bB[1]:
if grid.grid[bB[0]][bB[1]] == types.ROCK:
cost += 12
else:
cost += 2
if aA[2] != bB[2]:
if (bB[2].value - aA[2].value == 1) or (bB[2].value - aA[2].value == -3):
cost +=1
else:
cost +=1
return cost
def getRoad(start, goal, visited):
arr = []
@ -340,3 +363,115 @@ def updateDisplay(tractor: Tractor, grid: Grid):
pygame.time.Clock().tick(60)
def decisionTree(startpoint, endpoint, tractor, grid, graph1):
one="can it get to the next point"
two="will it be able to get to the gas station"
three="will it be able to get to the gas station after arriving at the next point"
four="will it be able to take the next vegetable to the tractor storage"
five="will it be able to get to the vegetable warehouse"
six="will it be able to get to the gas station after it arrives at the vegetable warehouse"
seven="is the vegetable warehouse closed"
eight="is the gas station closed"
arr = []
arr.append(one)
arr.append(two)
arr.append(three)
arr.append(four)
arr.append(five)
arr.append(six)
arr.append(seven)
arr.append(eight)
a1, c1 = graph1.a_star(startpoint, endpoint)
b1 = getRoad(startpoint, c1, a1)
cost1 = getCost(tractor, grid, b1)
a2, c2 = graph1.a_star(startpoint, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
b2 = getRoad(startpoint, c2, a2)
cost2 = getCost(tractor, grid, b2)
a3, c3 = graph1.a_star(startpoint, (SKLEP_POINT[0], SKLEP_POINT[1], Direction.RIGHT))
b3 = getRoad(startpoint, c3, a3)
cost3 = getCost(tractor, grid, b3)
a4, c4 = graph1.a_star(c1, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
b4 = getRoad(c1, c4, a4)
cost4 = getCost(tractor, grid, b4)
a5, c5 = graph1.a_star(c3, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
b5 = getRoad(c3, c5, a5)
cost5 = getCost(tractor, grid, b5)
if tractor.gas - cost1 > 0:
can_it_get_to_the_next_point = 1
else:
can_it_get_to_the_next_point = 0
if tractor.gas - cost2 > 0:
will_it_be_able_to_get_to_the_gas_station = 1
else:
will_it_be_able_to_get_to_the_gas_station = 0
if tractor.gas - cost1 - cost4 > 0:
will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 1
else:
will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 0
if grid.grid[endpoint[0]][endpoint[1]] in vegetables:
if tractor.collected_vegetables[grid.grid[endpoint[0]][endpoint[1]]] < 5:
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1
else:
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 0
else:
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1
if tractor.gas - cost3 > 0:
will_it_be_able_to_get_to_the_vegetable_warehouse=1
else:
will_it_be_able_to_get_to_the_vegetable_warehouse=0
if tractor.gas - cost3 - cost5 > 0:
will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse=1
else:
will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse = 0
is_the_vegetable_warehouse_closed=0
is_the_gas_station_closed=0
brr = []
brr.append(can_it_get_to_the_next_point)
brr.append(will_it_be_able_to_get_to_the_gas_station)
brr.append(will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point)
brr.append(will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage)
brr.append(will_it_be_able_to_get_to_the_vegetable_warehouse)
brr.append(will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse)
brr.append(is_the_vegetable_warehouse_closed)
brr.append(is_the_gas_station_closed)
def predict(tree):
if not isinstance(tree, dict):
return tree
else:
root_node = next(iter(tree))
feature_value = brr[arr.index(root_node)]
if feature_value in tree[root_node]:
return predict(tree[root_node][feature_value])
else:
return None
decision = predict(tree)
print(decision)
if decision==1:
movement(tractor, grid, b1)
if decision==2:
movement(tractor, grid, b2)
if decision==3:
movement(tractor, grid, b3)
if decision==4:
print("waiting")
if decision==5:
print("GAME OVER")