47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
|
from sklearn import tree
|
||
|
import graphviz
|
||
|
|
||
|
class Learning():
    """Train a scikit-learn decision tree on rows of integers read from a file.

    Each data row is a whitespace-separated list of integers; the last
    column is the class label and the preceding columns are the features
    (the original data format has 8 features followed by the label).

    Attributes:
        X: feature rows collected by ``load_data``.
        Y: class labels, one per row in ``X``.
        clf: the ``tree.DecisionTreeClassifier`` fitted by ``learn``.
    """

    # Message printed by ``predict`` for each predicted class label.
    _ACTIONS = {
        1: "oflagowac",
        2: "zdetonowac",
        3: "sprzedac na allegro",
        4: "sprzedac na czarnym rynku",
        5: "obejrzec",
    }

    # Class names handed to ``tree.export_graphviz`` in ``draw_tree``.
    _CLASS_NAMES = ['1', '2', '3', '4', '5']

    def __init__(self):
        # Per-instance state: as class attributes these lists and the
        # estimator would be shared (and mutated) across every instance.
        self.X = []
        self.Y = []
        self.clf = tree.DecisionTreeClassifier()

    def load_data(self, path='dane.txt'):
        """Append the samples stored in *path* to ``self.X`` / ``self.Y``.

        Args:
            path: text file with one sample per line. Defaults to
                ``'dane.txt'``, the file name the class always used.

        Raises:
            ValueError: if a token is not an integer or a non-blank line
                holds fewer than two values (no room for features + label).
            OSError: if the file cannot be opened.
        """
        with open(path, "r") as data_file:
            # Iterate the file directly; combining iteration with
            # readline() consumed two lines per step and lost half the data.
            for line_number, raw_line in enumerate(data_file, start=1):
                tokens = raw_line.split()
                if not tokens:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                values = [int(token) for token in tokens]
                if len(values) < 2:
                    raise ValueError(
                        "line %d of %r needs at least one feature and a label"
                        % (line_number, path)
                    )
                *features, label = values
                self.X.append(features)
                self.Y.append(label)

    def learn(self):
        """Fit ``self.clf`` on the samples currently held in ``X`` and ``Y``."""
        self.clf = self.clf.fit(self.X, self.Y)

    def draw_tree(self):
        """Export the fitted tree to DOT and render it with graphviz.

        The graph is rendered under the file name ``"data"``; the output
        format is whatever the graphviz package uses by default.
        """
        dot_data = tree.export_graphviz(
            self.clf,
            out_file=None,
            filled=True,
            class_names=self._CLASS_NAMES,
            rounded=True,
            special_characters=True,
        )
        graph = graphviz.Source(dot_data)
        graph.render("data")

    def predict(self, param_array):
        """Print the predicted class for one sample and its action message.

        Args:
            param_array: feature values of a single sample, in the same
                column order as the rows loaded by ``load_data``.
        """
        # Run the classifier once instead of once per comparison.
        prediction = self.clf.predict([param_array])
        print(prediction)
        action = self._ACTIONS.get(int(prediction[0]))
        if action is not None:
            print(action)
|