"""Train a scikit-learn decision tree from a CSV file, query it, and persist it.

The classifier is trained on the columns listed in ``attributes`` and predicts
the columns listed in ``decisions``. Helper functions export the fitted tree
as a PNG image, as plain text, and as a joblib-serialized object that
``load_tree_from_structure`` can read back.
"""
import joblib
import matplotlib.pyplot as plt
import pandas
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

# Feature columns read from the dataset, in the order the model is fitted on
# and the order ``making_decision`` passes its arguments.
attributes = ["dis_dump", "dis_trash", "mass", "space",
              "trash_mass", "trash_space"]
# Target column(s) the classifier is trained to predict.
decisions = ["decision"]


def learning_tree(dataset_path='./tree_dataset.csv', random_state=None):
    """Fit a decision tree on the CSV dataset at *dataset_path*.

    :param dataset_path: CSV file that must contain every column in
        ``attributes`` and ``decisions``. Defaults to ``./tree_dataset.csv``.
    :param random_state: forwarded to ``DecisionTreeClassifier``; pass an int
        to make the fitted tree reproducible across runs. ``None`` (default)
        keeps scikit-learn's default, non-seeded behavior.
    :return: the fitted ``DecisionTreeClassifier``.
    """
    dataset = pandas.read_csv(dataset_path)
    features = dataset[attributes]
    targets = dataset[decisions]
    classifier = DecisionTreeClassifier(random_state=random_state)
    # ``fit`` returns the estimator itself.
    return classifier.fit(features, targets)


def making_decision(decision_tree, distance_to_bin, distance_to_trash,
                    filling_mass, filling_space, trash_mass, trash_space):
    """Predict the decision for a single sample.

    The six values are passed in the same order as ``attributes``.

    :return: the array returned by ``decision_tree.predict`` for one sample.
    """
    row = [[distance_to_bin, distance_to_trash, filling_mass, filling_space,
            trash_mass, trash_space]]
    if hasattr(decision_tree, "feature_names_in_"):
        # The model was fitted on a DataFrame (as ``learning_tree`` does).
        # Supplying the same column names avoids scikit-learn's
        # "X does not have valid feature names" warning on every call.
        row = pandas.DataFrame(row, columns=attributes)
    return decision_tree.predict(row)


def save_all(decision_tree):
    """Export *decision_tree* as PNG, as text and as a joblib file, using default paths."""
    save_tree_to_png(decision_tree)
    save_tree_to_txt(decision_tree)
    save_tree_to_structure(decision_tree)


def save_tree_to_txt(decision_tree, path='./tree_in_txt.txt'):
    """Write a textual rendering of *decision_tree* to *path* (UTF-8)."""
    # Use the real column names so the text export matches the PNG export
    # instead of showing generic ``feature_0``, ``feature_1``... labels.
    report = tree.export_text(decision_tree, feature_names=attributes)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(report)


def save_tree_to_png(decision_tree, path='./decision_tree.png'):
    """Render *decision_tree* with ``sklearn.tree.plot_tree`` and save it to *path*."""
    fig = plt.figure(figsize=(25, 20))
    try:
        # plot_tree draws on the current axes, i.e. on the figure just created.
        tree.plot_tree(decision_tree, feature_names=attributes, filled=True)
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


def save_tree_to_structure(decision_tree, path='./tree_model'):
    """Serialize *decision_tree* to *path* with ``joblib.dump``."""
    joblib.dump(decision_tree, path)


def load_tree_from_structure(file):
    """Load an object previously saved with ``save_tree_to_structure``.

    WARNING: ``joblib.load`` is pickle-based and can execute arbitrary code.
    Only load files from a trusted source.
    """
    return joblib.load(file)


if __name__ == '__main__':
    trained_tree = learning_tree()
    save_all(trained_tree)