From 4c9682729e8b6226cbd7820bc3fbb177d28539c2 Mon Sep 17 00:00:00 2001 From: s473554 Date: Thu, 25 May 2023 15:46:08 +0200 Subject: [PATCH] add decision tree implementation --- 01.csv | 4 +- IC3.py | 15 ++++-- field.py | 143 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 152 insertions(+), 10 deletions(-) diff --git a/01.csv b/01.csv index 8ccbdc5..1882101 100644 --- a/01.csv +++ b/01.csv @@ -171,8 +171,8 @@ can it get to the next point,will it be able to get to the gas station,will it b 1,0,1,0,1,0,0,1,5 1,0,1,0,1,0,1,0,5 1,0,1,0,1,0,1,1,5 -1,0,1,0,1,1,0,0,3 -1,0,1,0,1,1,0,1,3 +1,0,1,0,1,1,0,0,5 +1,0,1,0,1,1,0,1,5 1,0,1,0,1,1,1,0,4 1,0,1,0,1,1,1,1,4 1,0,1,1,0,0,0,0,1 diff --git a/IC3.py b/IC3.py index bbaa741..b4f6329 100644 --- a/IC3.py +++ b/IC3.py @@ -117,16 +117,21 @@ def predict(tree, instance): else: root_node = next(iter(tree)) feature_value = instance[root_node] + print(root_node) + print(feature_value) if feature_value in tree[root_node]: return predict(tree[root_node][feature_value], instance) else: return None + def evaluate(tree, test_data_m, label): correct_preditct = 0 wrong_preditct = 0 for index, row in test_data_m.iterrows(): + print(test_data_m.iloc[index]) result = predict(tree, test_data_m.iloc[index]) + print() if result == test_data_m[label].iloc[index]: correct_preditct += 1 else: @@ -147,8 +152,10 @@ class NpEncoder(json.JSONEncoder): tree = id3(train_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER') # print(tree) -json_str = json.dumps(tree, indent=2, cls=NpEncoder) -print(json_str) +# json_str = json.dumps(tree, indent=2, cls=NpEncoder) +# print(json_str) + +accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER') +# print(accuracy) + -accuracy = evaluate(tree, test_data_m, 'go to: 1)next veget. 2)gas station 3)warehouse 4)sleep 5)GAME OVER') #evaluating the test dataset -print(accuracy) diff --git a/field.py b/field.py index 7223942..bfdc107 100644 --- a/field.py +++ b/field.py @@ -4,6 +4,7 @@ from heapq import * from enum import Enum, IntEnum from queue import PriorityQueue from collections import deque +from IC3 import tree import pygame @@ -108,9 +109,10 @@ def draw_interface(): elif event.type == pygame.MOUSEBUTTONDOWN: startpoint = (tractor.x, tractor.y, tractor.direction) endpoint = get_click_mouse_pos() - a, c = graph1.a_star(startpoint, endpoint) - b = getRoad(startpoint, c, a) - movement(tractor, grid, b) + decisionTree(startpoint, endpoint, tractor, grid, graph1) + # a, c = graph1.a_star(startpoint, endpoint) + # b = getRoad(startpoint, c, a) + # movement(tractor, grid, b) updateDisplay(tractor, grid) @@ -289,7 +291,28 @@ def movement(tractor: Tractor, grid: Grid, road): else: tractor.rot_center(Direction.LEFT) updateDisplay(tractor, grid) - +def getCost(tractor: Tractor, grid: Grid, road): + n = len(road) + cost = 0 + for i in range(n - 1): + aA = road[i] + bB = road[i + 1] + if aA[0] != bB[0]: + if grid.grid[bB[0]][bB[1]] == types.ROCK: + cost += 12 + else: + cost += 2 + if aA[1] != bB[1]: + if grid.grid[bB[0]][bB[1]] == types.ROCK: + cost += 12 + else: + cost += 2 + if aA[2] != bB[2]: + if (bB[2].value - aA[2].value == 1) or (bB[2].value - aA[2].value == -3): + cost +=1 + else: + cost +=1 + return cost def getRoad(start, goal, visited): arr = [] @@ -340,3 +363,115 @@ def updateDisplay(tractor: Tractor, grid: Grid): pygame.time.Clock().tick(60) +def decisionTree(startpoint, endpoint, tractor, grid, graph1): + + one="can it get to the next point" + two="will it be able to get to the gas station" + three="will it be able to get to the gas station after arriving at the next point" + four="will it be able to take the next vegetable to the tractor storage" + five="will it be able to get to the vegetable warehouse" + six="will it be able to get to the gas station after it arrives at the vegetable warehouse" + seven="is the vegetable warehouse closed" + eight="is the gas station closed" + arr = [] + arr.append(one) + arr.append(two) + arr.append(three) + arr.append(four) + arr.append(five) + arr.append(six) + arr.append(seven) + arr.append(eight) + + + a1, c1 = graph1.a_star(startpoint, endpoint) + b1 = getRoad(startpoint, c1, a1) + cost1 = getCost(tractor, grid, b1) + + a2, c2 = graph1.a_star(startpoint, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT)) + b2 = getRoad(startpoint, c2, a2) + cost2 = getCost(tractor, grid, b2) + + a3, c3 = graph1.a_star(startpoint, (SKLEP_POINT[0], SKLEP_POINT[1], Direction.RIGHT)) + b3 = getRoad(startpoint, c3, a3) + cost3 = getCost(tractor, grid, b3) + + a4, c4 = graph1.a_star(c1, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT)) + b4 = getRoad(c1, c4, a4) + cost4 = getCost(tractor, grid, b4) + + a5, c5 = graph1.a_star(c3, (SPAWN_POINT[0], SPAWN_POINT[1], Direction.RIGHT)) + b5 = getRoad(c3, c5, a5) + cost5 = getCost(tractor, grid, b5) + + + if tractor.gas - cost1 > 0: + can_it_get_to_the_next_point = 1 + else: + can_it_get_to_the_next_point = 0 + + if tractor.gas - cost2 > 0: + will_it_be_able_to_get_to_the_gas_station = 1 + else: + will_it_be_able_to_get_to_the_gas_station = 0 + + if tractor.gas - cost1 - cost4 > 0: + will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 1 + else: + will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point = 0 + + if grid.grid[endpoint[0]][endpoint[1]] in vegetables: + if tractor.collected_vegetables[grid.grid[endpoint[0]][endpoint[1]]] < 5: + will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1 + else: + will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 0 + else: + will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage = 1 + + if tractor.gas - cost3 > 0: + will_it_be_able_to_get_to_the_vegetable_warehouse=1 + else: + will_it_be_able_to_get_to_the_vegetable_warehouse=0 + + if tractor.gas - cost3 - cost5 > 0: + will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse=1 + else: + will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse = 0 + + is_the_vegetable_warehouse_closed=0 + is_the_gas_station_closed=0 + + + brr = [] + brr.append(can_it_get_to_the_next_point) + brr.append(will_it_be_able_to_get_to_the_gas_station) + brr.append(will_it_be_able_to_get_to_the_gas_station_after_arriving_at_the_next_point) + brr.append(will_it_be_able_to_take_the_next_vegetable_to_the_tractor_storage) + brr.append(will_it_be_able_to_get_to_the_vegetable_warehouse) + brr.append(will_it_be_able_to_get_to_the_gas_station_after_it_arrives_at_the_vegetable_warehouse) + brr.append(is_the_vegetable_warehouse_closed) + brr.append(is_the_gas_station_closed) + + def predict(tree): + if not isinstance(tree, dict): + return tree + else: + root_node = next(iter(tree)) + feature_value = brr[arr.index(root_node)] + if feature_value in tree[root_node]: + return predict(tree[root_node][feature_value]) + else: + return None + decision = predict(tree) + print(decision) + if decision==1: + movement(tractor, grid, b1) + if decision==2: + movement(tractor, grid, b2) + if decision==3: + movement(tractor, grid, b3) + if decision==4: + print("waiting") + if decision==5: + print("GAME OVER") +