DSZI_2020_Projekt/Restaurant/Marta/DecisionTreeGenerate.py

90 lines
2.8 KiB
Python
Raw Normal View History

2020-05-10 18:42:31 +02:00
import pandas as pd
import joblib
import pydotplus
from IPython.display import Image
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import os
os.environ["PATH"] += os.pathsep + r'C:\Program Files (x86)\Graphviz2.38\bin'
# import danych
def dataImport():
dataset = pd.read_csv('learnData4.csv', sep=',', header=None)
# print(dataset)
return dataset
# zamiana dataset na dwie tablice zawierające cechy i wyniki
def splitDataSet(dataset):
X = dataset.values[:, 0:5]
Y = dataset.values[:, 5]
# podział na zbiory do nauki i testowe
x_train, x_test, y_train, y_test = train_test_split(
X, Y, test_size=0.3, random_state=100)
return X, Y, x_train, x_test, y_train, y_test
def main():
data = dataImport()
X, Y, x_train, x_test, y_train, y_test = splitDataSet(data)
# utworzenie modelu drzewa decyzyjnego
model = tree.DecisionTreeClassifier()
model2 = tree.DecisionTreeClassifier(criterion="entropy")
# przeprowadzenie "nauki" na zbiorach do nauki
model.fit(x_train, y_train)
model2.fit(x_train, y_train)
# generowanie wyników dla zbioru testowego
pred = model.predict(x_test)
pred2 = model2.predict(x_test)
# porównanie z faktycznymi wynikami = wyliczenie dokładności
acc = accuracy_score(y_test, pred) * 100 # aprox. 77.78%
acc2 = accuracy_score(y_test, pred2) * 100 # aprox. 83.33%
print("akuratnosc dla modelu Gini: " + str(acc))
print("akuratnosc dla modelu Entropy: " + str(acc2))
# przekazanie wygenerowanego modelu do pliku
filename = 'finalized_model.sav'
joblib.dump(model2, filename)
kluseczka = [22, 4, 4, 1, 0]
wynik = model.predict([kluseczka])
print(wynik)
wynik2 = model2.predict([kluseczka])
print(wynik2)
# wygenerowanie i zapisanie graficznej reprezentacji drzewa
# dot_data = tree.export_graphviz(model, out_file=None,
# feature_names=["age", "fat", "fiber", "sex", "spicy"],
# class_names=["0", "1"],
# filled=True, rounded=True,
# special_characters=True)
dot_data2 = tree.export_graphviz(model2, out_file=None,
feature_names=["age", "fat", "fiber", "sex", "spicy"],
class_names=["easy to digest", "hard to digest"],
filled=True, rounded=True,
special_characters=True)
# wygenerowanie grafiki przedstawiającej drzewo
# graph = pydotplus.graph_from_dot_data(dot_data)
# Image(graph.create_png())
# graph.write_png("digest.png")
graph2 = pydotplus.graph_from_dot_data(dot_data2)
Image(graph2.create_png())
graph2.write_png("digest_entropy.png")
return
main()