44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import matplotlib.image as pltimg
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
import pandas
|
|
import pydotplus
|
|
from sklearn import tree
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
|
|
|
|
|
|
def treelearn():
    """Train a decision tree on ``./resources/dataset.csv`` and display it.

    The dataset is loaded with pandas, separated into eight feature columns
    and one target column, and used to fit a ``DecisionTreeClassifier``
    created with default settings. The tree is exported to DOT, and the image
    at ``resources/decision_tree.png`` is shown with matplotlib.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    attributes = ["label", "size", "weight", "urgent", "weekend",
                  "payment_method", "international", "delayed"]

    raw = pandas.read_csv(
        './resources/dataset.csv',
        delimiter=';',
        header=None,
        skiprows=1,
        names=attributes + ["sector"],
    )

    # NOTE(review): the first column is split on ';' a second time, so each
    # record presumably reaches pandas as one unsplit field (e.g. a quoted
    # line) -- confirm against the actual dataset file.
    fields = raw['label'].str.split(';', n=8, expand=True)
    x = fields.iloc[:, 0:8]   # first eight pieces: the features
    y = fields.iloc[:, -1:]   # last piece: the target, kept as a one-column frame
    print(attributes)

    decision_tree = DecisionTreeClassifier()
    decision_tree = decision_tree.fit(x, y)

    # Visualize and display the decision tree.
    data = tree.export_graphviz(decision_tree, out_file=None, feature_names=attributes)
    graph = pydotplus.graph_from_dot_data(data)

    image_path = os.path.join('resources', 'decision_tree.png')
    # Render the PNG only when it is missing, so a first run no longer fails
    # in imread. An existing file is shown as-is and may therefore not match
    # the tree that was just trained; delete it to force a fresh render.
    if not os.path.exists(image_path):
        graph.write_png(image_path)

    img = pltimg.imread(image_path)
    plt.imshow(img)
    plt.show()

    return decision_tree
|
|
|
|
|
|
def make_decision(tree, label, size, weight, urgent, weekend, payment_method, international,
                  delayed):
    """Return the prediction of ``tree`` for a single sample.

    The eight values after ``tree`` are packed, in signature order, into one
    row; that row is wrapped in a list and handed to ``tree.predict``. The
    result of that call is returned unchanged.
    """
    sample = [label, size, weight, urgent, weekend, payment_method, international, delayed]
    return tree.predict([sample])
|