71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
import time
|
|
from sklearn import tree
|
|
import numpy as np
|
|
import graphviz
|
|
from src.decisionTree.datasetGenerator import generateRawDataset
|
|
from src.decisionTree.datasetConverter import convertDataset
|
|
|
|
|
|
class TreeEngine():
    """Build, train and export a decision tree from the generated dataset.

    Constructing an instance regenerates the raw dataset, converts it,
    loads ``out/dataset.csv``, fits the classifier, and writes a text
    rendering (``out/decision_tree.txt``) and a Graphviz rendering
    (``out/decision_tree``) of the fitted tree.
    """

    # Names of the attribute columns of out/dataset.csv, in column order
    # (every column except the last one, which holds the label). Shared by
    # both exporters so the two renderings can never disagree.
    FEATURE_NAMES = (
        "Battery level",
        "Distance between kitchen and table",
        "Customers mood",
        "Basket is empty",
        "Dish is ready",
        "Dish in basket",
        "Table status",
        "Is actual",
    )

    # Human-readable label names for the Graphviz export.
    # NOTE(review): sklearn assigns these positionally to ``clf.classes_``
    # (the sorted distinct label values); confirm the label encoding produced
    # by convertDataset matches this order.
    CLASS_NAMES = (
        'High priority',
        'Low priority',
        'Return to kitchen',
    )

    def __init__(self):
        generateRawDataset()
        convertDataset()

        # Import the dataset from disk, skipping the header row.
        train_data_m = np.genfromtxt(
            "out/dataset.csv", delimiter=",", skip_header=1)

        # Separate the attributes (all but the last column) and labels.
        self.X_train, self.y_train = self._split_features_labels(train_data_m)

        # The entropy criterion gives information-gain splits (ID3-like), but
        # sklearn's tree builder is an optimized CART, not ID3 proper.
        # Alternative: criterion='gini'.
        self.clf = tree.DecisionTreeClassifier(
            criterion='entropy', splitter="best")

        # Train the decision tree on the training data.
        self.clf.fit(self.X_train, self.y_train)

        self.exportText()
        self.exportPdf()

    @staticmethod
    def _split_features_labels(data):
        """Split a numeric table into per-row attribute arrays and labels.

        Args:
            data: Array as returned by ``np.genfromtxt``. A file with a
                single data row yields a 1-D array, so the input is promoted
                to 2-D first; otherwise each "row" would be a scalar and the
                slicing below would fail.

        Returns:
            ``(X, y)`` where ``X`` is a list of row arrays without their last
            column and ``y`` is a list of each row's last value.
        """
        rows = np.atleast_2d(data)
        features = [row[:-1] for row in rows]
        labels = [row[-1] for row in rows]
        return features, labels

    def exportText(self):
        """Write a plain-text rendering of the fitted tree to out/decision_tree.txt."""
        tree_text = tree.export_text(
            self.clf, feature_names=list(self.FEATURE_NAMES))

        # Save the visualization as a text file.
        with open('out/decision_tree.txt', 'w', encoding='utf-8') as f:
            f.write(tree_text)

    def exportPdf(self):
        """Render the fitted tree with Graphviz to out/decision_tree (PDF)."""
        dot_data = tree.export_graphviz(
            self.clf,
            out_file=None,
            feature_names=list(self.FEATURE_NAMES),
            class_names=list(self.CLASS_NAMES),
            filled=True,
            rounded=True,
        )

        graph = graphviz.Source(dot_data)
        # Save the visualization as a PDF file.
        graph.render("out/decision_tree")

    def make_predict(self, dataset):
        """Predict the label of a single sample.

        Args:
            dataset: One sample's attribute values, ordered like
                ``FEATURE_NAMES``.

        Returns:
            The classifier's prediction array for the one-sample batch.
        """
        return self.clf.predict([dataset])
|