# SI_projekt_smieciarka/tree.py
# 2021-06-23 11:09:17 +02:00
#
# 57 lines
# 1.5 KiB
# Python

import joblib
import matplotlib.pyplot as plt
import pandas
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
# Feature columns read from the training CSV, in the order in which
# making_decision supplies the corresponding values.
attributes = "dis_dump dis_trash mass space trash_mass trash_space".split()
# Target column(s); kept as a list so that dataset[decisions] selects a DataFrame.
decisions = ["decision"]
def learning_tree(dataset_path='./tree_dataset.csv', random_state=None):
    """Train a decision tree classifier from the CSV training set.

    The columns listed in ``attributes`` are used as features and the
    column(s) listed in ``decisions`` as the target.

    :param dataset_path: path to the training CSV. Defaults to
        ``./tree_dataset.csv``, resolved against the current working directory.
    :param random_state: forwarded to ``DecisionTreeClassifier``. Pass an int
        to make training reproducible. With the default ``None``, ties between
        equally good splits may be broken differently on each run.
    :return: the fitted ``DecisionTreeClassifier``.
    :raises FileNotFoundError: if ``dataset_path`` does not exist.
    :raises KeyError: if the CSV lacks one of the expected columns.
    """
    dataset = pandas.read_csv(dataset_path)
    x = dataset[attributes]
    y = dataset[decisions]
    decision_tree = DecisionTreeClassifier(random_state=random_state)
    # fit() returns the estimator itself, already trained.
    return decision_tree.fit(x, y)
def making_decision(decision_tree, distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass,
                    trash_space):
    """Predict the decision for a single situation with a trained tree.

    The six values are passed to the model in the same order as the
    ``attributes`` columns it was trained on.

    :param decision_tree: a fitted classifier, e.g. from ``learning_tree``.
    :return: the array returned by ``decision_tree.predict``, holding the
        predicted label for the single sample.
    """
    sample = [[distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass, trash_space]]
    # learning_tree fits on a pandas DataFrame, so newer scikit-learn stores
    # the column names in feature_names_in_ and warns whenever predict()
    # receives an unnamed list. In that case, label the sample with the same
    # columns. Models fitted without feature names keep getting the plain list,
    # which avoids the opposite warning.
    if hasattr(decision_tree, "feature_names_in_"):
        sample = pandas.DataFrame(sample, columns=attributes)
    return decision_tree.predict(sample)
def save_all(decision_tree):
    """Export the trained tree in every supported format.

    In order, this writes the PNG plot, the plain-text rules and the
    joblib-serialized model by delegating to the ``save_tree_to_*`` helpers.
    """
    for exporter in (save_tree_to_png, save_tree_to_txt, save_tree_to_structure):
        exporter(decision_tree)
def save_tree_to_txt(decision_tree, path='./tree_in_txt.txt'):
    """Write the tree's decision rules to a text file.

    :param decision_tree: a fitted ``DecisionTreeClassifier``.
    :param path: output file; defaults to ``./tree_in_txt.txt``.
    """
    # Label the splits with the real column names, as the PNG export does,
    # instead of export_text's generic "feature_0", "feature_1", ... names.
    # The rules are built before the file is opened, so a failure here does
    # not leave an empty or truncated file behind.
    rules = tree.export_text(decision_tree, feature_names=attributes)
    with open(path, "w", encoding="utf-8") as file:
        file.write(rules)
def save_tree_to_png(decision_tree, path='./decision_tree.png'):
    """Render the tree with matplotlib and save it as an image.

    :param decision_tree: a fitted ``DecisionTreeClassifier``.
    :param path: output image file; defaults to ``./decision_tree.png``.
    """
    fig = plt.figure(figsize=(25, 20))
    try:
        # plot_tree draws on the current axes, i.e. the figure created above.
        tree.plot_tree(decision_tree, feature_names=attributes, filled=True)
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each call leaks a 25x20-inch figure, and matplotlib warns once
        # more than 20 figures are open.
        plt.close(fig)
def save_tree_to_structure(decision_tree, path='./tree_model'):
    """Serialize the fitted tree to disk with ``joblib.dump``.

    :param decision_tree: a fitted ``DecisionTreeClassifier``.
    :param path: destination file; defaults to ``./tree_model``. The file can
        be read back with ``load_tree_from_structure``.
    """
    joblib.dump(decision_tree, path)
def load_tree_from_structure(file='./tree_model'):
    """Load a tree previously written by ``save_tree_to_structure``.

    :param file: path to the joblib file. Defaults to ``./tree_model``, the
        location ``save_tree_to_structure`` writes to by default.
    :return: the deserialized object, presumably a fitted
        ``DecisionTreeClassifier``. The type is not checked here.

    Security: ``joblib.load`` is pickle-based and can execute arbitrary code
    while loading. Only load model files you created yourself or otherwise
    trust.
    """
    return joblib.load(file)
if __name__ == '__main__':
    # Train on the CSV dataset, then export the model in every format.
    save_all(learning_tree())