wozek-projekt/decision_tree.py

44 lines
1.6 KiB
Python

import matplotlib.image as pltimg
import matplotlib.pyplot as plt
import os
import pandas
import pydotplus
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
def _show_tree(decision_tree, feature_names):
    """Render *decision_tree* to resources/decision_tree.png and display it with matplotlib."""
    import warnings

    image_path = os.path.join('resources', 'decision_tree.png')
    dot_data = tree.export_graphviz(decision_tree, out_file=None, feature_names=feature_names)
    graph = pydotplus.graph_from_dot_data(dot_data)
    try:
        # Render the tree that was just trained so the displayed image is never stale.
        graph.write_png(image_path)
    except pydotplus.graphviz.InvocationException as err:
        # Graphviz executables are not installed / not runnable: fall back to
        # whatever image already exists, but say so instead of silently
        # showing a possibly outdated tree.
        warnings.warn(
            f"Could not render decision tree with Graphviz ({err}); "
            f"displaying existing {image_path}, which may be stale."
        )
    img = pltimg.imread(image_path)
    plt.imshow(img)
    plt.show()


def treelearn(*, show=True):
    """Train a DecisionTreeClassifier on ./resources/dataset.csv.

    The first line of the CSV is skipped (treated as a header). Each row is
    then taken from the ``label`` column and split on ``;`` into at most 9
    fields: the first 8 are the features (named by ``attributes``) and the
    last one is the target (``sector``). The split produces strings;
    sklearn casts features to float when fitting, so they must be numeric text.

    Args:
        show: keyword-only. When True (the default, matching the original
            behaviour), render the trained tree to
            ``resources/decision_tree.png`` and display it with matplotlib.
            ``plt.show()`` may block until the window is closed.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    attributes = ["label", "size", "weight", "urgent", "weekend", "payment_method",
                  "international", "delayed"]
    dataframe = pandas.read_csv('./resources/dataset.csv', delimiter=';', header=None,
                                skiprows=1, names=attributes + ["sector"])
    # NOTE(review): the ';'-delimited read apparently leaves each whole row in
    # the "label" column (e.g. lines wrapped in quotes), hence this re-split
    # into 8 features + 1 target. Confirm against the dataset file.
    fields = dataframe['label'].str.split(';', n=8, expand=True)
    x = fields.iloc[:, 0:8]
    y = fields.iloc[:, -1]  # 1-D target Series (last field)

    decision_tree = DecisionTreeClassifier().fit(x, y)
    if show:
        _show_tree(decision_tree, attributes)
    return decision_tree
def make_decision(tree, label, size, weight, urgent, weekend, payment_method, international,
                  delayed):
    """Classify one package with a fitted classifier.

    Note: the ``tree`` parameter (kept for caller compatibility) shadows the
    module-level ``sklearn.tree`` import inside this function.

    Args:
        tree: a fitted model exposing ``predict`` (e.g. the result of ``treelearn``).
        label, size, weight, urgent, weekend, payment_method, international, delayed:
            the eight feature values, in the order the model was trained on.

    Returns:
        Whatever ``tree.predict`` returns for a single-row batch.
    """
    sample = [label, size, weight, urgent, weekend, payment_method, international, delayed]
    batch = [sample]  # predict() expects a 2-D batch, so wrap the one row
    return tree.predict(batch)