implement decision tree saving with joblib library
This commit is contained in:
parent
f0e0d4f4c1
commit
5df7f58b48
5
.gitignore
vendored
5
.gitignore
vendored
@ -3,4 +3,9 @@ __pycache__/
|
|||||||
|
|
||||||
# ignore pdf files
|
# ignore pdf files
|
||||||
*.pdf
|
*.pdf
|
||||||
|
|
||||||
|
# ignore data file
|
||||||
data
|
data
|
||||||
|
|
||||||
|
# ignore .joblib files
|
||||||
|
*.joblib
|
25
learning.py
25
learning.py
@ -1,11 +1,15 @@
|
|||||||
from sklearn import tree
|
from sklearn import tree
|
||||||
import graphviz
|
import graphviz
|
||||||
|
from joblib import dump, load
|
||||||
|
|
||||||
class Learning():
|
class Learning():
|
||||||
|
|
||||||
X = []
|
X = []
|
||||||
Y = []
|
Y = []
|
||||||
clf = tree.DecisionTreeClassifier()
|
clf = tree.DecisionTreeClassifier()
|
||||||
|
saved_clf = None
|
||||||
|
s = None
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
file = open('dane.txt', "r")
|
file = open('dane.txt', "r")
|
||||||
data_str = []
|
data_str = []
|
||||||
@ -27,11 +31,30 @@ class Learning():
|
|||||||
def learn(self):
|
def learn(self):
|
||||||
#clf = tree.DecisionTreeClassifier()
|
#clf = tree.DecisionTreeClassifier()
|
||||||
self.clf = self.clf.fit(self.X, self.Y)
|
self.clf = self.clf.fit(self.X, self.Y)
|
||||||
|
dump(self.clf, 'decision_tree.joblib')
|
||||||
|
|
||||||
|
def load_tree(self):
|
||||||
|
self.saved_clf = load('decision_tree.joblib')
|
||||||
|
|
||||||
def draw_tree(self):
|
def draw_tree(self):
|
||||||
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
|
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
|
||||||
graph = graphviz.Source(dot_data)
|
graph = graphviz.Source(dot_data)
|
||||||
graph.render("data")
|
graph.render("data")
|
||||||
|
|
||||||
|
def predict_on_saved_tree(self, param_array):
|
||||||
|
self.load_tree()
|
||||||
|
print(self.saved_clf.predict([param_array]))
|
||||||
|
if self.saved_clf.predict([param_array]) == 1:
|
||||||
|
print("oflagowac")
|
||||||
|
if self.saved_clf.predict([param_array]) == 2:
|
||||||
|
print("zdetonowac")
|
||||||
|
if self.saved_clf.predict([param_array]) == 3:
|
||||||
|
print("sprzedac na allegro")
|
||||||
|
if self.saved_clf.predict([param_array]) == 4:
|
||||||
|
print("sprzedac na czarnym rynku")
|
||||||
|
if self.saved_clf.predict([param_array]) == 5:
|
||||||
|
print("obejrzec")
|
||||||
|
|
||||||
def predict(self, param_array):
|
def predict(self, param_array):
|
||||||
print(self.clf.predict([param_array]))
|
print(self.clf.predict([param_array]))
|
||||||
if self.clf.predict([param_array]) == 1:
|
if self.clf.predict([param_array]) == 1:
|
||||||
@ -44,3 +67,5 @@ class Learning():
|
|||||||
print("sprzedac na czarnym rynku")
|
print("sprzedac na czarnym rynku")
|
||||||
if self.clf.predict([param_array]) == 5:
|
if self.clf.predict([param_array]) == 5:
|
||||||
print("obejrzec")
|
print("obejrzec")
|
||||||
|
|
||||||
|
|
||||||
|
6
main.py
6
main.py
@ -18,6 +18,7 @@ from pizza import *
|
|||||||
from learning import *
|
from learning import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pg.init()
|
pg.init()
|
||||||
@ -47,7 +48,7 @@ class Game:
|
|||||||
for col, tile in enumerate(tiles):
|
for col, tile in enumerate(tiles):
|
||||||
if tile == '2':
|
if tile == '2':
|
||||||
Mine(self, col, row)
|
Mine(self, col, row)
|
||||||
Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0)
|
Mine.set_parameters(Mine,18,32,6,7,0,0,0,0)
|
||||||
if tile == '3':
|
if tile == '3':
|
||||||
Bomb(self, col, row)
|
Bomb(self, col, row)
|
||||||
if tile == '4':
|
if tile == '4':
|
||||||
@ -130,7 +131,8 @@ class Game:
|
|||||||
|
|
||||||
# Test.run()
|
# Test.run()
|
||||||
if event.key == pg.K_F5:
|
if event.key == pg.K_F5:
|
||||||
print("lol xD")
|
self.player.decision_tree_learning()
|
||||||
|
#print("lol xD")
|
||||||
pg.event.clear()
|
pg.event.clear()
|
||||||
if event.key == pg.K_F6:
|
if event.key == pg.K_F6:
|
||||||
pg.event.clear()
|
pg.event.clear()
|
||||||
|
12
sprites.py
12
sprites.py
@ -6,6 +6,7 @@ from settings import *
|
|||||||
from maze import *
|
from maze import *
|
||||||
from learning import *
|
from learning import *
|
||||||
|
|
||||||
|
|
||||||
class Player(pg.sprite.Sprite):
|
class Player(pg.sprite.Sprite):
|
||||||
def __init__(self, game, x, y, direction = 'Right'):
|
def __init__(self, game, x, y, direction = 'Right'):
|
||||||
self.groups = game.all_sprites
|
self.groups = game.all_sprites
|
||||||
@ -24,7 +25,7 @@ class Player(pg.sprite.Sprite):
|
|||||||
self.maze = Maze()
|
self.maze = Maze()
|
||||||
self.moves = ''
|
self.moves = ''
|
||||||
self.my_learning = Learning()
|
self.my_learning = Learning()
|
||||||
self.decision_tree_learning()
|
#self.decision_tree_learning()
|
||||||
|
|
||||||
def set_direction(self, direction):
|
def set_direction(self, direction):
|
||||||
self.direction = direction
|
self.direction = direction
|
||||||
@ -94,6 +95,9 @@ class Player(pg.sprite.Sprite):
|
|||||||
self.check_bomb()
|
self.check_bomb()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def check_border(self, dx=0, dy=0):
|
def check_border(self, dx=0, dy=0):
|
||||||
if (self.x + dx) < 0 or (self.y + dy) < 0 or (self.x + dx) >= MAP_SIZE or (self.y + dy) >= MAP_SIZE :
|
if (self.x + dx) < 0 or (self.y + dy) < 0 or (self.x + dx) >= MAP_SIZE or (self.y + dy) >= MAP_SIZE :
|
||||||
return False
|
return False
|
||||||
@ -188,14 +192,18 @@ class Player(pg.sprite.Sprite):
|
|||||||
self.my_learning.load_data()
|
self.my_learning.load_data()
|
||||||
self.my_learning.learn()
|
self.my_learning.learn()
|
||||||
self.my_learning.draw_tree()
|
self.my_learning.draw_tree()
|
||||||
|
print("new decision tree created")
|
||||||
|
print("restart to use saved decision tree")
|
||||||
#my_learning.predict()
|
#my_learning.predict()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
""" sprawdzenie danych miny """
|
""" sprawdzenie danych miny """
|
||||||
def check_bomb(self):
|
def check_bomb(self):
|
||||||
if self.check_if_on_mine():
|
if self.check_if_on_mine():
|
||||||
current_mine = self.get_my_mine_object()
|
current_mine = self.get_my_mine_object()
|
||||||
mine_params = current_mine.get_parameters()
|
mine_params = current_mine.get_parameters()
|
||||||
self.my_learning.predict(mine_params)
|
self.my_learning.predict_on_saved_tree(mine_params)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user