implement decision tree saving with joblib library
This commit is contained in:
parent
f0e0d4f4c1
commit
5df7f58b48
5
.gitignore
vendored
5
.gitignore
vendored
@ -3,4 +3,9 @@ __pycache__/
|
||||
|
||||
# ignore pdf files
|
||||
*.pdf
|
||||
|
||||
# ignore data file
|
||||
data
|
||||
|
||||
# ignore .joblib files
|
||||
*.joblib
|
25
learning.py
25
learning.py
@ -1,11 +1,15 @@
|
||||
from sklearn import tree
|
||||
import graphviz
|
||||
from joblib import dump, load
|
||||
|
||||
class Learning():
|
||||
|
||||
X = []
|
||||
Y = []
|
||||
clf = tree.DecisionTreeClassifier()
|
||||
saved_clf = None
|
||||
s = None
|
||||
|
||||
def load_data(self):
|
||||
file = open('dane.txt', "r")
|
||||
data_str = []
|
||||
@ -27,11 +31,30 @@ class Learning():
|
||||
def learn(self):
|
||||
#clf = tree.DecisionTreeClassifier()
|
||||
self.clf = self.clf.fit(self.X, self.Y)
|
||||
dump(self.clf, 'decision_tree.joblib')
|
||||
|
||||
def load_tree(self):
|
||||
self.saved_clf = load('decision_tree.joblib')
|
||||
|
||||
def draw_tree(self):
|
||||
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
|
||||
graph = graphviz.Source(dot_data)
|
||||
graph.render("data")
|
||||
|
||||
def predict_on_saved_tree(self, param_array):
|
||||
self.load_tree()
|
||||
print(self.saved_clf.predict([param_array]))
|
||||
if self.saved_clf.predict([param_array]) == 1:
|
||||
print("oflagowac")
|
||||
if self.saved_clf.predict([param_array]) == 2:
|
||||
print("zdetonowac")
|
||||
if self.saved_clf.predict([param_array]) == 3:
|
||||
print("sprzedac na allegro")
|
||||
if self.saved_clf.predict([param_array]) == 4:
|
||||
print("sprzedac na czarnym rynku")
|
||||
if self.saved_clf.predict([param_array]) == 5:
|
||||
print("obejrzec")
|
||||
|
||||
def predict(self, param_array):
|
||||
print(self.clf.predict([param_array]))
|
||||
if self.clf.predict([param_array]) == 1:
|
||||
@ -44,3 +67,5 @@ class Learning():
|
||||
print("sprzedac na czarnym rynku")
|
||||
if self.clf.predict([param_array]) == 5:
|
||||
print("obejrzec")
|
||||
|
||||
|
||||
|
6
main.py
6
main.py
@ -18,6 +18,7 @@ from pizza import *
|
||||
from learning import *
|
||||
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
pg.init()
|
||||
@ -47,7 +48,7 @@ class Game:
|
||||
for col, tile in enumerate(tiles):
|
||||
if tile == '2':
|
||||
Mine(self, col, row)
|
||||
Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0)
|
||||
Mine.set_parameters(Mine,18,32,6,7,0,0,0,0)
|
||||
if tile == '3':
|
||||
Bomb(self, col, row)
|
||||
if tile == '4':
|
||||
@ -130,7 +131,8 @@ class Game:
|
||||
|
||||
# Test.run()
|
||||
if event.key == pg.K_F5:
|
||||
print("lol xD")
|
||||
self.player.decision_tree_learning()
|
||||
#print("lol xD")
|
||||
pg.event.clear()
|
||||
if event.key == pg.K_F6:
|
||||
pg.event.clear()
|
||||
|
12
sprites.py
12
sprites.py
@ -6,6 +6,7 @@ from settings import *
|
||||
from maze import *
|
||||
from learning import *
|
||||
|
||||
|
||||
class Player(pg.sprite.Sprite):
|
||||
def __init__(self, game, x, y, direction = 'Right'):
|
||||
self.groups = game.all_sprites
|
||||
@ -24,7 +25,7 @@ class Player(pg.sprite.Sprite):
|
||||
self.maze = Maze()
|
||||
self.moves = ''
|
||||
self.my_learning = Learning()
|
||||
self.decision_tree_learning()
|
||||
#self.decision_tree_learning()
|
||||
|
||||
def set_direction(self, direction):
|
||||
self.direction = direction
|
||||
@ -94,6 +95,9 @@ class Player(pg.sprite.Sprite):
|
||||
self.check_bomb()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def check_border(self, dx=0, dy=0):
|
||||
if (self.x + dx) < 0 or (self.y + dy) < 0 or (self.x + dx) >= MAP_SIZE or (self.y + dy) >= MAP_SIZE :
|
||||
return False
|
||||
@ -188,14 +192,18 @@ class Player(pg.sprite.Sprite):
|
||||
self.my_learning.load_data()
|
||||
self.my_learning.learn()
|
||||
self.my_learning.draw_tree()
|
||||
print("new decision tree created")
|
||||
print("restart to use saved decision tree")
|
||||
#my_learning.predict()
|
||||
|
||||
|
||||
|
||||
""" sprawdzenie danych miny """
|
||||
def check_bomb(self):
|
||||
if self.check_if_on_mine():
|
||||
current_mine = self.get_my_mine_object()
|
||||
mine_params = current_mine.get_parameters()
|
||||
self.my_learning.predict(mine_params)
|
||||
self.my_learning.predict_on_saved_tree(mine_params)
|
||||
return
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user