implement decision tree saving with joblib library

This commit is contained in:
eugenep 2021-05-18 14:32:53 +02:00
parent f0e0d4f4c1
commit 5df7f58b48
4 changed files with 45 additions and 5 deletions

7
.gitignore vendored
View File

@ -3,4 +3,9 @@ __pycache__/
# ignore pdf files
*.pdf
data
# ignore data file
data
# ignore .joblib files
*.joblib

View File

@ -1,11 +1,15 @@
from sklearn import tree
import graphviz
from joblib import dump, load
class Learning():
X = []
Y = []
clf = tree.DecisionTreeClassifier()
saved_clf = None
s = None
def load_data(self):
file = open('dane.txt', "r")
data_str = []
@ -27,11 +31,30 @@ class Learning():
def learn(self):
#clf = tree.DecisionTreeClassifier()
self.clf = self.clf.fit(self.X, self.Y)
dump(self.clf, 'decision_tree.joblib')
def load_tree(self):
self.saved_clf = load('decision_tree.joblib')
def draw_tree(self):
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("data")
def predict_on_saved_tree(self, param_array):
self.load_tree()
print(self.saved_clf.predict([param_array]))
if self.saved_clf.predict([param_array]) == 1:
print("oflagowac")
if self.saved_clf.predict([param_array]) == 2:
print("zdetonowac")
if self.saved_clf.predict([param_array]) == 3:
print("sprzedac na allegro")
if self.saved_clf.predict([param_array]) == 4:
print("sprzedac na czarnym rynku")
if self.saved_clf.predict([param_array]) == 5:
print("obejrzec")
def predict(self, param_array):
print(self.clf.predict([param_array]))
if self.clf.predict([param_array]) == 1:
@ -44,3 +67,5 @@ class Learning():
print("sprzedac na czarnym rynku")
if self.clf.predict([param_array]) == 5:
print("obejrzec")

View File

@ -18,6 +18,7 @@ from pizza import *
from learning import *
class Game:
def __init__(self):
pg.init()
@ -47,7 +48,7 @@ class Game:
for col, tile in enumerate(tiles):
if tile == '2':
Mine(self, col, row)
Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0)
Mine.set_parameters(Mine,18,32,6,7,0,0,0,0)
if tile == '3':
Bomb(self, col, row)
if tile == '4':
@ -130,7 +131,8 @@ class Game:
# Test.run()
if event.key == pg.K_F5:
print("lol xD")
self.player.decision_tree_learning()
#print("lol xD")
pg.event.clear()
if event.key == pg.K_F6:
pg.event.clear()

View File

@ -6,6 +6,7 @@ from settings import *
from maze import *
from learning import *
class Player(pg.sprite.Sprite):
def __init__(self, game, x, y, direction = 'Right'):
self.groups = game.all_sprites
@ -24,7 +25,7 @@ class Player(pg.sprite.Sprite):
self.maze = Maze()
self.moves = ''
self.my_learning = Learning()
self.decision_tree_learning()
#self.decision_tree_learning()
def set_direction(self, direction):
self.direction = direction
@ -92,6 +93,9 @@ class Player(pg.sprite.Sprite):
print("My direction is: " + str(self.direction))
self.check_bomb()
def check_border(self, dx=0, dy=0):
@ -188,14 +192,18 @@ class Player(pg.sprite.Sprite):
self.my_learning.load_data()
self.my_learning.learn()
self.my_learning.draw_tree()
print("new decision tree created")
print("restart to use saved decision tree")
#my_learning.predict()
""" sprawdzenie danych miny """
def check_bomb(self):
if self.check_if_on_mine():
current_mine = self.get_my_mine_object()
mine_params = current_mine.get_parameters()
self.my_learning.predict(mine_params)
self.my_learning.predict_on_saved_tree(mine_params)
return