implement decision tree saving with joblib library

This commit is contained in:
eugenep 2021-05-18 14:32:53 +02:00
parent f0e0d4f4c1
commit 5df7f58b48
4 changed files with 45 additions and 5 deletions

5
.gitignore vendored
View File

@ -3,4 +3,9 @@ __pycache__/
# ignore pdf files # ignore pdf files
*.pdf *.pdf
# ignore data file
data data
# ignore .joblib files
*.joblib

View File

@ -1,11 +1,15 @@
from sklearn import tree from sklearn import tree
import graphviz import graphviz
from joblib import dump, load
class Learning(): class Learning():
X = [] X = []
Y = [] Y = []
clf = tree.DecisionTreeClassifier() clf = tree.DecisionTreeClassifier()
saved_clf = None
s = None
def load_data(self): def load_data(self):
file = open('dane.txt', "r") file = open('dane.txt', "r")
data_str = [] data_str = []
@ -27,11 +31,30 @@ class Learning():
def learn(self): def learn(self):
#clf = tree.DecisionTreeClassifier() #clf = tree.DecisionTreeClassifier()
self.clf = self.clf.fit(self.X, self.Y) self.clf = self.clf.fit(self.X, self.Y)
dump(self.clf, 'decision_tree.joblib')
def load_tree(self):
self.saved_clf = load('decision_tree.joblib')
def draw_tree(self): def draw_tree(self):
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True) dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
graph = graphviz.Source(dot_data) graph = graphviz.Source(dot_data)
graph.render("data") graph.render("data")
def predict_on_saved_tree(self, param_array):
self.load_tree()
print(self.saved_clf.predict([param_array]))
if self.saved_clf.predict([param_array]) == 1:
print("oflagowac")
if self.saved_clf.predict([param_array]) == 2:
print("zdetonowac")
if self.saved_clf.predict([param_array]) == 3:
print("sprzedac na allegro")
if self.saved_clf.predict([param_array]) == 4:
print("sprzedac na czarnym rynku")
if self.saved_clf.predict([param_array]) == 5:
print("obejrzec")
def predict(self, param_array): def predict(self, param_array):
print(self.clf.predict([param_array])) print(self.clf.predict([param_array]))
if self.clf.predict([param_array]) == 1: if self.clf.predict([param_array]) == 1:
@ -44,3 +67,5 @@ class Learning():
print("sprzedac na czarnym rynku") print("sprzedac na czarnym rynku")
if self.clf.predict([param_array]) == 5: if self.clf.predict([param_array]) == 5:
print("obejrzec") print("obejrzec")

View File

@ -18,6 +18,7 @@ from pizza import *
from learning import * from learning import *
class Game: class Game:
def __init__(self): def __init__(self):
pg.init() pg.init()
@ -47,7 +48,7 @@ class Game:
for col, tile in enumerate(tiles): for col, tile in enumerate(tiles):
if tile == '2': if tile == '2':
Mine(self, col, row) Mine(self, col, row)
Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0) Mine.set_parameters(Mine,18,32,6,7,0,0,0,0)
if tile == '3': if tile == '3':
Bomb(self, col, row) Bomb(self, col, row)
if tile == '4': if tile == '4':
@ -130,7 +131,8 @@ class Game:
# Test.run() # Test.run()
if event.key == pg.K_F5: if event.key == pg.K_F5:
print("lol xD") self.player.decision_tree_learning()
#print("lol xD")
pg.event.clear() pg.event.clear()
if event.key == pg.K_F6: if event.key == pg.K_F6:
pg.event.clear() pg.event.clear()

View File

@ -6,6 +6,7 @@ from settings import *
from maze import * from maze import *
from learning import * from learning import *
class Player(pg.sprite.Sprite): class Player(pg.sprite.Sprite):
def __init__(self, game, x, y, direction = 'Right'): def __init__(self, game, x, y, direction = 'Right'):
self.groups = game.all_sprites self.groups = game.all_sprites
@ -24,7 +25,7 @@ class Player(pg.sprite.Sprite):
self.maze = Maze() self.maze = Maze()
self.moves = '' self.moves = ''
self.my_learning = Learning() self.my_learning = Learning()
self.decision_tree_learning() #self.decision_tree_learning()
def set_direction(self, direction): def set_direction(self, direction):
self.direction = direction self.direction = direction
@ -94,6 +95,9 @@ class Player(pg.sprite.Sprite):
self.check_bomb() self.check_bomb()
def check_border(self, dx=0, dy=0): def check_border(self, dx=0, dy=0):
if (self.x + dx) < 0 or (self.y + dy) < 0 or (self.x + dx) >= MAP_SIZE or (self.y + dy) >= MAP_SIZE : if (self.x + dx) < 0 or (self.y + dy) < 0 or (self.x + dx) >= MAP_SIZE or (self.y + dy) >= MAP_SIZE :
return False return False
@ -188,14 +192,18 @@ class Player(pg.sprite.Sprite):
self.my_learning.load_data() self.my_learning.load_data()
self.my_learning.learn() self.my_learning.learn()
self.my_learning.draw_tree() self.my_learning.draw_tree()
print("new decision tree created")
print("restart to use saved decision tree")
#my_learning.predict() #my_learning.predict()
""" sprawdzenie danych miny """ """ sprawdzenie danych miny """
def check_bomb(self): def check_bomb(self):
if self.check_if_on_mine(): if self.check_if_on_mine():
current_mine = self.get_my_mine_object() current_mine = self.get_my_mine_object()
mine_params = current_mine.get_parameters() mine_params = current_mine.get_parameters()
self.my_learning.predict(mine_params) self.my_learning.predict_on_saved_tree(mine_params)
return return