Improve performance of decision tree

This commit is contained in:
egotd 2022-05-19 12:26:48 +02:00
parent 89bef876b4
commit 135ab3e9c6
5 changed files with 19 additions and 11 deletions

View File

@ -29,7 +29,7 @@ class AI:
pass pass
#co ma zrobić przy każdym ruchu <------------------------- najważniejsze #co ma zrobić przy każdym ruchu <------------------------- najważniejsze
def updateTile(self): def updateTile(self, model):
#aktualne pola (do debugu) #aktualne pola (do debugu)
sensor = self.saper.sensor() sensor = self.saper.sensor()
@ -40,7 +40,7 @@ class AI:
#podniesienie bomby jeśli jest jakaś na tym polu #podniesienie bomby jeśli jest jakaś na tym polu
self.saper.pick_up() self.saper.pick_up(model)
#poruszenie się #poruszenie się
if self.user_controlled: if self.user_controlled:

View File

@ -164,7 +164,7 @@ class BFS:
# jesli tmp node to goaltest # jesli tmp node to goaltest
if tmp_node_position[:2] == goaltest: if tmp_node_position[:2] == goaltest:
print('Find') print('Find\n')
while tmp_node[1].get_parent() is not None: while tmp_node[1].get_parent() is not None:
final_action_list.append(tmp_node[1].get_action()) final_action_list.append(tmp_node[1].get_action())
@ -177,7 +177,7 @@ class BFS:
explored.append(tmp_node[1]) # add node to array of visited nodes explored.append(tmp_node[1]) # add node to array of visited nodes
neighbours_list_of_our_node = self.successor(tmp_node_position) # lista możliwych akcij neighbours_list_of_our_node = self.successor(tmp_node_position) # lista możliwych akcij
print(neighbours_list_of_our_node) # print(neighbours_list_of_our_node)
for node_ in neighbours_list_of_our_node: for node_ in neighbours_list_of_our_node:
# node_ is tuple(action, [x, y, gdzie_patczy], cost) # node_ is tuple(action, [x, y, gdzie_patczy], cost)

View File

@ -8,7 +8,11 @@ from numpy import random
############### ###############
class DecisionTrees: class DecisionTrees:
def return_predict(self): def create_model(self):
model = chef.fit(pd.read_csv("D:\\1 Python projects\Saper\data\db.txt"), {'algorithm': 'ID3'})
return model
def return_predict(self, mod):
# read data # read data
df = pd.read_csv("D:\\1 Python projects\Saper\data\db.txt") df = pd.read_csv("D:\\1 Python projects\Saper\data\db.txt")
@ -24,7 +28,6 @@ class DecisionTrees:
# ID3 config # ID3 config
config = {'algorithm': 'ID3'} config = {'algorithm': 'ID3'}
# create decision tree # create decision tree
model = chef.fit(df, config)
# print predict # print predict
# print(chef.predict(model, [1, 2022, 0, 0, 0, 10])) # print(chef.predict(model, [1, 2022, 0, 0, 0, 10]))
@ -52,4 +55,4 @@ class DecisionTrees:
cnt += 1 cnt += 1
# return prediction # return prediction
return chef.predict(model, mine_characteristics) return chef.predict(mod, mine_characteristics)

View File

@ -389,15 +389,15 @@ class Minesweeper:
pygame.mixer.Channel(2).set_volume(0.5) pygame.mixer.Channel(2).set_volume(0.5)
pygame.mixer.Channel(2).play(pygame.mixer.Sound("assets/sounds/collision.wav")) pygame.mixer.Channel(2).play(pygame.mixer.Sound("assets/sounds/collision.wav"))
def pick_up(self): def pick_up(self, model):
if self.offset_x != 0 or self.offset_y != 0: if self.offset_x != 0 or self.offset_y != 0:
return return
for mine in self.current_map.mines: for mine in self.current_map.mines:
if (self.position_x, self.position_y) == (mine.position_x, mine.position_y): if (self.position_x, self.position_y) == (mine.position_x, mine.position_y):
tree = decisionTrees.DecisionTrees() tree = decisionTrees.DecisionTrees()
decision = tree.return_predict() decision = tree.return_predict(model)
print("Decision : ", decision) print("Decision : ", decision, "\n")
self.current_map.mines.remove(mine) self.current_map.mines.remove(mine)
pygame.mixer.Channel(3).set_volume(0.7) pygame.mixer.Channel(3).set_volume(0.7)

View File

@ -58,6 +58,11 @@ def main():
# główna pętla # główna pętla
game_loop = True game_loop = True
clock = pygame.time.Clock() clock = pygame.time.Clock()
# create decision tree
tree = decisionTrees.DecisionTrees()
model = tree.create_model()
while game_loop: while game_loop:
# wdrożenie FPS, delta - czas od ostatniej klatki # wdrożenie FPS, delta - czas od ostatniej klatki
delta = clock.tick(FPS) delta = clock.tick(FPS)
@ -66,7 +71,7 @@ def main():
AI.updateFPS() AI.updateFPS()
if saper.offset_x == 0 and saper.offset_y == 0: if saper.offset_x == 0 and saper.offset_y == 0:
AI.updateTile() AI.updateTile(model)
# narysowanie terenu i obiektów # narysowanie terenu i obiektów
map.draw_tiles() map.draw_tiles()