GenericAI_Sweeper/learning.py

47 lines
1.6 KiB
Python
Raw Normal View History

from sklearn import tree
import graphviz
class Learning():
    """Train, visualise and query a decision tree built from a text data file.

    Each training row is a line of whitespace-separated integers: the last
    value is the class label, the values before it are the features.
    """

    # Text printed by ``predict`` for each label (runtime strings, Polish):
    # 1 flag, 2 detonate, 3 sell on Allegro, 4 sell on the black market,
    # 5 look at / inspect.
    ACTIONS = {
        1: "oflagowac",
        2: "zdetonowac",
        3: "sprzedac na allegro",
        4: "sprzedac na czarnym rynku",
        5: "obejrzec",
    }

    def __init__(self):
        # Per-instance state so separate Learning objects never share
        # training rows or the fitted classifier.
        self.X = []
        self.Y = []
        self.clf = tree.DecisionTreeClassifier()

    def load_data(self, path='dane.txt'):
        """Append the training rows stored in *path* to ``self.X`` / ``self.Y``.

        Every non-blank line must contain whitespace-separated integers; the
        last integer is appended to ``self.Y`` and the preceding ones, as a
        list, to ``self.X``. Blank lines are skipped. Rows are only added
        once the whole file has parsed successfully.

        Args:
            path: file to read; defaults to ``'dane.txt'``.

        Raises:
            OSError: if *path* cannot be opened.
            ValueError: if a line contains a token that is not an integer.
        """
        features, labels = [], []
        # Iterate the file exactly once (mixing iteration with readline()
        # skips lines) and let ``with`` close the handle.
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                # split() with no argument tolerates any run of spaces,
                # tabs and the trailing newline.
                values = [int(token) for token in line.split()]
                if not values:
                    continue
                labels.append(values[-1])
                features.append(values[:-1])
        self.X.extend(features)
        self.Y.extend(labels)

    def learn(self):
        """Fit the decision tree on the rows collected by ``load_data``."""
        self.clf = self.clf.fit(self.X, self.Y)

    def draw_tree(self, filename="data"):
        """Render the fitted tree with graphviz and return the output path.

        Args:
            filename: base name passed to ``graphviz.Source.render``;
                defaults to ``"data"``.
        """
        dot_data = tree.export_graphviz(
            self.clf,
            out_file=None,
            filled=True,
            # Names must follow the classifier's own (ascending) class order;
            # a fixed ['1'..'5'] list mislabels nodes when classes differ.
            class_names=[str(label) for label in self.clf.classes_],
            rounded=True,
            special_characters=True,
        )
        return graphviz.Source(dot_data).render(filename)

    def predict(self, param_array):
        """Predict the label of one feature row and print the matching action.

        Prints the raw prediction returned by the classifier, then the entry
        of ``ACTIONS`` for the predicted label (nothing if there is none).

        Args:
            param_array: a single feature row, same layout as a row of ``X``.

        Returns:
            The predicted label.
        """
        # The classifier works on a batch of rows, so wrap the single row;
        # predict once and reuse the result.
        prediction = self.clf.predict([param_array])
        print(prediction)
        label = prediction[0]
        action = self.ACTIONS.get(label)
        if action is not None:
            print(action)
        return label