add decision tree implementation
This commit is contained in:
parent
233f218899
commit
4c9682729e
4
01.csv
4
01.csv
@ -171,8 +171,8 @@ can it get to the next point,will it be able to get to the gas station,will it b
|
||||
1,0,1,0,1,0,0,1,5
|
||||
1,0,1,0,1,0,1,0,5
|
||||
1,0,1,0,1,0,1,1,5
|
||||
1,0,1,0,1,1,0,0,3
|
||||
1,0,1,0,1,1,0,1,3
|
||||
1,0,1,0,1,1,0,0,5
|
||||
1,0,1,0,1,1,0,1,5
|
||||
1,0,1,0,1,1,1,0,4
|
||||
1,0,1,0,1,1,1,1,4
|
||||
1,0,1,1,0,0,0,0,1
|
||||
|
|
15
IC3.py
15
IC3.py
@ -117,16 +117,21 @@ def predict(tree, instance):
|
||||
else:
|
||||
root_node = next(iter(tree))
|
||||
feature_value = instance[root_node]
|
||||
print(root_node)
|
||||
print(feature_value)
|
||||
if feature_value in tree[root_node]:
|
||||
return predict(tree[root_node][feature_value], instance)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def evaluate(tree, test_data_m, label):
|
||||
correct_preditct = 0
|
||||
wrong_preditct = 0
|
||||
for index, row in test_data_m.iterrows():
|
||||
print(test_data_m.iloc[index])
|
||||
result = predict(tree, test_data_m.iloc[index])
|
||||
print()
|
||||
if result == test_data_m[label].iloc[index]:
|
||||
correct_preditct += 1
|
||||
else:
|
||||
@ -147,8 +152,10 @@ class NpEncoder(json.JSONEncoder):
|
||||
|
||||
tree = id3(train_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER')
|
||||
# print(tree)
|
||||
json_str = json.dumps(tree, indent=2, cls=NpEncoder)
|
||||
print(json_str)
|
||||
# json_str = json.dumps(tree, indent=2, cls=NpEncoder)
|
||||
# print(json_str)
|
||||
|
||||
accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER')
|
||||
# print(accuracy)
|
||||
|
||||
|
||||
accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER') #evaluating the test dataset
|
||||
print(accuracy)
|
||||
|
143
field.py
143
field.py
@ -4,6 +4,7 @@ from heapq import *
|
||||
from enum import Enum, IntEnum
|
||||
from queue import PriorityQueue
|
||||
from collections import deque
|
||||
from IC3 import tree
|
||||
|
||||
import pygame
|
||||
|
||||
@ -108,9 +109,10 @@ def draw_interface():
|
||||
elif event.type == pygame.MOUSEBUTTONDOWN:
|
||||
startpoint = (tractor.x, tractor.y, tractor.direction)
|
||||
endpoint = get_click_mouse_pos()
|
||||
a, c = graph1.a_star(startpoint, endpoint)
|
||||
b = getRoad(startpoint, c, a)
|
||||
movement(tractor, grid, b)
|
||||
decisionTree(startpoint, endpoint, tractor, grid, graph1)
|
||||
# a, c = graph1.a_star(startpoint, endpoint)
|
||||
# b = getRoad(startpoint, c, a)
|
||||
# movement(tractor, grid, b)
|
||||
updateDisplay(tractor, grid)
|
||||
|
||||
|
||||
@ -289,7 +291,28 @@ def movement(tractor: Tractor, grid: Grid, road):
|
||||
else:
|
||||
tractor.rot_center(Direction.LEFT)
|
||||
updateDisplay(tractor, grid)
|
||||
|
||||
def getCost(tractor: Tractor, grid: Grid, road):
|
||||
n = len(road)
|
||||
cost = 0
|
||||
for i in range(n - 1):
|
||||
aA = road[i]
|
||||
bB = road[i + 1]
|
||||
if aA[0] != bB[0]:
|
||||
if grid.grid[bB[0]][bB[1]] == types.ROCK:
|
||||
cost += 12
|
||||
else:
|
||||
cost += 2
|
||||
if aA[1] != bB[1]:
|
||||
if grid.grid[bB[0]][bB[1]] == types.ROCK:
|
||||
cost += 12
|
||||
else:
|
||||
cost += 2
|
||||
if aA[2] != bB[2]:
|
||||
if (bB[2].value - aA[2].value == 1) or (bB[2].value - aA[2].value == -3):
|
||||
cost +=1
|
||||
else:
|
||||
cost +=1
|
||||
return cost
|
||||
|
||||
def getRoad(start, goal, visited):
|
||||
arr = []
|
||||
@ -340,3 +363,115 @@ def updateDisplay(tractor: Tractor, grid: Grid):
|
||||
|
||||
pygame.time.Clock().tick(60)
|
||||
|
||||
def decisionTree(startpoint, endpoint, tractor, grid, graph1):
|
||||
|
||||
one="can it get to the next point"
|
||||
two="will it be able to get to the gas station"
|
||||
three="will it be able to get to the gas station after arriving at the next point"
|
||||
four="will it be able to take the next vegetable to the tractor storage"
|
||||
five="will it be able to get to the vegetable warehouse"
|
||||
six="will it be able to get to the gas station after it arrives at the vegetable warehouse"
|
||||
seven="is the vegetable warehouse closed"
|
||||
eight="is the gas station closed"
|
||||
arr = []
|
||||
arr.append(one)
|
||||
arr.append(two)
|
||||
arr.append(three)
|
||||
arr.append(four)
|
||||
arr.append(five)
|
||||
arr.append(six)
|
||||
arr.append(seven)
|
||||
arr.append(eight)
|
||||
|
||||
|
||||
a1, c1 = graph1.a_star(startpoint, endpoint)
|
||||
b1 = getRoad(startpoint, c1, a1)
|
||||
cost1 = getCost(tractor, grid, b1)
|
||||
|
||||
a2, c2 = graph1.a_star(startpoint, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
|
||||
b2 = getRoad(startpoint, c2, a2)
|
||||
cost2 = getCost(tractor, grid, b2)
|
||||
|
||||
a3, c3 = graph1.a_star(startpoint, (SKLEP_POINT[0], SKLEP_POINT[1], Direction.RIGHT))
|
||||
b3 = getRoad(startpoint, c3, a3)
|
||||
cost3 = getCost(tractor, grid, b3)
|
||||
|
||||
a4, c4 = graph1.a_star(c1, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
|
||||
b4 = getRoad(c1, c4, a4)
|
||||
cost4 = getCost(tractor, grid, b4)
|
||||
|
||||
a5, c5 = graph1.a_star(c3, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT))
|
||||
b5 = getRoad(c3, c5, a5)
|
||||
cost5 = getCost(tractor, grid, b5)
|
||||
|
||||
|
||||
if tractor.gas - cost1 > 0:
|
||||
can_it_get_to_the_next_point = 1
|
||||
else:
|
||||
can_it_get_to_the_next_point = 0
|
||||
|
||||
if tractor.gas - cost2 > 0:
|
||||
will_it_be_able_to_get_to_the_gas_station = 1
|
||||
else:
|
||||
will_it_be_able_to_get_to_the_gas_station = 0
|
||||
|
||||
if tractor.gas - cost1 - cost4 > 0:
|
||||
will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 1
|
||||
else:
|
||||
will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 0
|
||||
|
||||
if grid.grid[endpoint[0]][endpoint[1]] in vegetables:
|
||||
if tractor.collected_vegetables[grid.grid[endpoint[0]][endpoint[1]]] < 5:
|
||||
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1
|
||||
else:
|
||||
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 0
|
||||
else:
|
||||
will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1
|
||||
|
||||
if tractor.gas - cost3 > 0:
|
||||
will_it_be_able_to_get_to_the_vegetable_warehouse=1
|
||||
else:
|
||||
will_it_be_able_to_get_to_the_vegetable_warehouse=0
|
||||
|
||||
if tractor.gas - cost3 - cost5 > 0:
|
||||
will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse=1
|
||||
else:
|
||||
will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse = 0
|
||||
|
||||
is_the_vegetable_warehouse_closed=0
|
||||
is_the_gas_station_closed=0
|
||||
|
||||
|
||||
brr = []
|
||||
brr.append(can_it_get_to_the_next_point)
|
||||
brr.append(will_it_be_able_to_get_to_the_gas_station)
|
||||
brr.append(will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point)
|
||||
brr.append(will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage)
|
||||
brr.append(will_it_be_able_to_get_to_the_vegetable_warehouse)
|
||||
brr.append(will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse)
|
||||
brr.append(is_the_vegetable_warehouse_closed)
|
||||
brr.append(is_the_gas_station_closed)
|
||||
|
||||
def predict(tree):
|
||||
if not isinstance(tree, dict):
|
||||
return tree
|
||||
else:
|
||||
root_node = next(iter(tree))
|
||||
feature_value = brr[arr.index(root_node)]
|
||||
if feature_value in tree[root_node]:
|
||||
return predict(tree[root_node][feature_value])
|
||||
else:
|
||||
return None
|
||||
decision = predict(tree)
|
||||
print(decision)
|
||||
if decision==1:
|
||||
movement(tractor, grid, b1)
|
||||
if decision==2:
|
||||
movement(tractor, grid, b2)
|
||||
if decision==3:
|
||||
movement(tractor, grid, b3)
|
||||
if decision==4:
|
||||
print("waiting")
|
||||
if decision==5:
|
||||
print("GAME OVER")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user