From 5df7f58b48fc7fd5cc93c21e1313aaca935b4f42 Mon Sep 17 00:00:00 2001 From: eugenep Date: Tue, 18 May 2021 14:32:53 +0200 Subject: [PATCH] implement decision tree saving with joblib library --- .gitignore | 7 ++++++- learning.py | 25 +++++++++++++++++++++++++ main.py | 6 ++++-- sprites.py | 12 ++++++++++-- 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index a6cc740..732a48e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,9 @@ __pycache__/ # ignore pdf files *.pdf -data \ No newline at end of file + +# ignore data file +data + +# ignore .joblib files +*.joblib \ No newline at end of file diff --git a/learning.py b/learning.py index e988eee..d6c3b9e 100644 --- a/learning.py +++ b/learning.py @@ -1,11 +1,15 @@ from sklearn import tree import graphviz +from joblib import dump, load class Learning(): X = [] Y = [] clf = tree.DecisionTreeClassifier() + saved_clf = None + s = None + def load_data(self): file = open('dane.txt', "r") data_str = [] @@ -27,11 +31,30 @@ class Learning(): def learn(self): #clf = tree.DecisionTreeClassifier() self.clf = self.clf.fit(self.X, self.Y) + dump(self.clf, 'decision_tree.joblib') + + def load_tree(self): + self.saved_clf = load('decision_tree.joblib') + def draw_tree(self): dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True) graph = graphviz.Source(dot_data) graph.render("data") + def predict_on_saved_tree(self, param_array): + self.load_tree() + print(self.saved_clf.predict([param_array])) + if self.saved_clf.predict([param_array]) == 1: + print("oflagowac") + if self.saved_clf.predict([param_array]) == 2: + print("zdetonowac") + if self.saved_clf.predict([param_array]) == 3: + print("sprzedac na allegro") + if self.saved_clf.predict([param_array]) == 4: + print("sprzedac na czarnym rynku") + if self.saved_clf.predict([param_array]) == 5: + print("obejrzec") + def predict(self, param_array): print(self.clf.predict([param_array])) if self.clf.predict([param_array]) == 1: @@ -44,3 +67,5 @@ class Learning(): print("sprzedac na czarnym rynku") if self.clf.predict([param_array]) == 5: print("obejrzec") + + diff --git a/main.py b/main.py index 6ff026a..4054cfa 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,7 @@ from pizza import * from learning import * + class Game: def __init__(self): pg.init() @@ -47,7 +48,7 @@ class Game: for col, tile in enumerate(tiles): if tile == '2': Mine(self, col, row) - Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0) + Mine.set_parameters(Mine,18,32,6,7,0,0,0,0) if tile == '3': Bomb(self, col, row) if tile == '4': @@ -130,7 +131,8 @@ class Game: # Test.run() if event.key == pg.K_F5: - print("lol xD") + self.player.decision_tree_learning() + #print("lol xD") pg.event.clear() if event.key == pg.K_F6: pg.event.clear() diff --git a/sprites.py b/sprites.py index 2263ca9..c81590f 100644 --- a/sprites.py +++ b/sprites.py @@ -6,6 +6,7 @@ from settings import * from maze import * from learning import * + class Player(pg.sprite.Sprite): def __init__(self, game, x, y, direction = 'Right'): self.groups = game.all_sprites @@ -24,7 +25,7 @@ class Player(pg.sprite.Sprite): self.maze = Maze() self.moves = '' self.my_learning = Learning() - self.decision_tree_learning() + #self.decision_tree_learning() def set_direction(self, direction): self.direction = direction @@ -92,6 +93,9 @@ class Player(pg.sprite.Sprite): print("My direction is: " + str(self.direction)) self.check_bomb() + + + def check_border(self, dx=0, dy=0): @@ -188,14 +192,18 @@ class Player(pg.sprite.Sprite): self.my_learning.load_data() self.my_learning.learn() self.my_learning.draw_tree() + print("new decision tree created") + print("restart to use saved decision tree") #my_learning.predict() + + """ sprawdzenie danych miny """ def check_bomb(self): if self.check_if_on_mine(): current_mine = self.get_my_mine_object() mine_params = current_mine.get_parameters() - self.my_learning.predict(mine_params) + self.my_learning.predict_on_saved_tree(mine_params) return