# GenericAI_Sweeper/learning.py
#
# 72 lines
# 2.3 KiB
# Python

from sklearn import tree
import graphviz
from joblib import dump, load
class Learning():
    """Train, persist, visualise and query a decision-tree classifier.

    Training samples are read from a whitespace-separated text file of
    integers (``dane.txt`` by default).  The value in column ``LABEL_COLUMN``
    (zero-based) is the class label; the remaining values are the features.
    The fitted tree is saved with joblib so it can be reloaded later.
    """

    # Default location of the training data.
    DATA_FILE = 'dane.txt'
    # Default location of the persisted classifier.
    TREE_FILE = 'decision_tree.joblib'
    # Zero-based column that holds the class label in each data row.
    LABEL_COLUMN = 8

    # Text printed for each predicted class label (runtime strings, Polish):
    ACTIONS = {
        1: "oflagowac",                  # "flag it"
        2: "zdetonowac",                 # "detonate it"
        3: "sprzedac na allegro",        # "sell it on Allegro"
        4: "sprzedac na czarnym rynku",  # "sell it on the black market"
        5: "obejrzec",                   # "inspect it"
    }

    # Classifier loaded from disk by load_tree(); None until then.
    saved_clf = None
    # Not used by this class; retained only for backward compatibility.
    s = None

    def __init__(self):
        # Per-instance state.  These used to be class attributes, which made
        # every Learning instance share (and keep appending to) one dataset
        # and one classifier.
        self.X = []
        self.Y = []
        self.clf = tree.DecisionTreeClassifier()

    def load_data(self, path=DATA_FILE):
        """Append the samples stored in *path* to ``self.X`` and ``self.Y``.

        Each non-blank line must hold whitespace-separated integers.  The value
        at ``LABEL_COLUMN`` is appended to ``self.Y``; the other values, in
        their original order, form one feature row appended to ``self.X``.

        Args:
            path: text file to read (defaults to ``dane.txt``).

        Raises:
            OSError: if *path* cannot be opened.
            ValueError: if a field is not an integer.
            IndexError: if a row is too short to contain the label column.
        """
        with open(path, "r") as handle:
            # Iterate the file directly: the previous version also called
            # readline() inside this loop and so skipped every other line.
            for line in handle:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing one
                row = [int(field) for field in fields]
                self.Y.append(row.pop(self.LABEL_COLUMN))
                self.X.append(row)

    def learn(self, path=TREE_FILE):
        """Fit ``self.clf`` on ``self.X``/``self.Y`` and save it to *path*."""
        self.clf = self.clf.fit(self.X, self.Y)
        dump(self.clf, path)

    def load_tree(self, path=TREE_FILE):
        """Load a classifier saved by learn() into ``self.saved_clf``.

        NOTE: joblib.load unpickles the file, so only load trusted files.
        """
        self.saved_clf = load(path)

    def draw_tree(self):
        """Render ``self.clf`` with Graphviz under the file name ``data``.

        Class names come from the fitted classifier's ``classes_`` so they
        always match the labels it was actually trained on.
        """
        dot_data = tree.export_graphviz(
            self.clf,
            out_file=None,
            filled=True,
            class_names=[str(label) for label in self.clf.classes_],
            rounded=True,
            special_characters=True,
        )
        graphviz.Source(dot_data).render("data")

    def _report(self, classifier, param_array):
        """Predict one sample, print the raw prediction and its action text.

        Returns the predicted label; nothing beyond the raw prediction is
        printed when the label has no entry in ``ACTIONS``.
        """
        prediction = classifier.predict([param_array])
        print(prediction)
        label = prediction[0]
        action = self.ACTIONS.get(label)
        if action is not None:
            print(action)
        return label

    def predict_on_saved_tree(self, param_array, path=TREE_FILE):
        """Reload the tree from *path*, then predict and report one sample.

        Returns the predicted label.
        """
        self.load_tree(path)
        return self._report(self.saved_clf, param_array)

    def predict(self, param_array):
        """Predict and report one sample using the in-memory ``self.clf``.

        Returns the predicted label.
        """
        return self._report(self.clf, param_array)