# Listing metadata: 72 lines, 2.3 KiB, Python.
from sklearn import tree
|
|
import graphviz
|
|
from joblib import dump, load
|
|
|
|
class Learning():
    """Train, persist, draw and query a decision-tree classifier.

    Training samples are read from a text file of whitespace-separated
    integers, one sample per line.  The value at index ``LABEL_INDEX`` of
    each row is the class label; the values before it are the features.

    Attributes:
        X: Feature rows collected by ``load_data``.
        Y: Class labels, parallel to ``X``.
        clf: The ``tree.DecisionTreeClassifier`` fitted by ``learn``.
        saved_clf: Classifier read back from disk by ``load_tree``
            (``None`` until then).
    """

    # Default input file and the file the fitted model is persisted to.
    DATA_PATH = 'dane.txt'
    MODEL_PATH = 'decision_tree.joblib'

    # Position of the class label inside a parsed data row.
    LABEL_INDEX = 8

    # Message printed for each predicted class (Polish, transliterated):
    # flag it / detonate it / sell on Allegro / sell on the black market /
    # take a look at it.
    ACTIONS = {
        1: "oflagowac",
        2: "zdetonowac",
        3: "sprzedac na allegro",
        4: "sprzedac na czarnym rynku",
        5: "obejrzec",
    }

    saved_clf = None
    # Not read by any method of this class; kept so existing attribute
    # access (``Learning.s``) keeps working.
    s = None

    def __init__(self):
        # Per-instance state: as class attributes these mutable objects
        # would be shared by every Learning instance.
        self.X = []
        self.Y = []
        self.clf = tree.DecisionTreeClassifier()
        self.saved_clf = None

    def load_data(self, path=DATA_PATH):
        """Append the samples stored in *path* to ``self.X`` / ``self.Y``.

        Args:
            path: Text file with one sample per line, integers separated
                by whitespace.  Defaults to ``'dane.txt'``.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If a field is not an integer literal.
            IndexError: If a non-blank row has fewer than
                ``LABEL_INDEX + 1`` values.
        """
        features = []
        labels = []
        with open(path, "r") as data_file:
            # Iterate the file directly; calling readline() inside the
            # loop would consume and discard every other line.
            for raw_line in data_file:
                fields = raw_line.split()
                if not fields:
                    # Blank lines carry no sample.
                    continue
                values = [int(field) for field in fields]
                labels.append(values[self.LABEL_INDEX])
                features.append(values[:self.LABEL_INDEX])
        # Commit only after the whole file parsed, so a malformed row
        # does not leave X and Y partially extended.
        self.X.extend(features)
        self.Y.extend(labels)

    def learn(self):
        """Fit ``self.clf`` on ``self.X`` / ``self.Y`` and dump it to disk."""
        self.clf = self.clf.fit(self.X, self.Y)
        dump(self.clf, self.MODEL_PATH)

    def load_tree(self):
        """Load the persisted classifier from disk into ``self.saved_clf``."""
        # NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code; only load model files from a trusted source.
        self.saved_clf = load(self.MODEL_PATH)

    def draw_tree(self):
        """Export ``self.clf`` to Graphviz DOT and render it as ``data``."""
        dot_data = tree.export_graphviz(
            self.clf,
            out_file=None,
            filled=True,
            class_names=['1', '2', '3', '4', '5'],
            rounded=True,
            special_characters=True,
        )
        graph = graphviz.Source(dot_data)
        graph.render("data")

    def _report(self, classifier, param_array):
        """Print *classifier*'s prediction for one sample and its action."""
        # Run the model once and reuse the result for every check.
        prediction = classifier.predict([param_array])
        print(prediction)
        action = self.ACTIONS.get(int(prediction[0]))
        if action is not None:
            print(action)

    def predict_on_saved_tree(self, param_array):
        """Reload the persisted classifier and report on *param_array*.

        Args:
            param_array: Feature values of a single sample.
        """
        self.load_tree()
        self._report(self.saved_clf, param_array)

    def predict(self, param_array):
        """Report the in-memory classifier's prediction for *param_array*.

        Args:
            param_array: Feature values of a single sample.
        """
        self._report(self.clf, param_array)
|
|
|
|
|