90 lines
2.8 KiB
Python
90 lines
2.8 KiB
Python
|
import pandas as pd
|
||
|
import joblib
|
||
|
import pydotplus
|
||
|
from IPython.display import Image
|
||
|
from sklearn import tree
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from sklearn.metrics import accuracy_score
|
||
|
import os
|
||
|
# Append the (Windows-specific) Graphviz install directory to this process's
# PATH, presumably so pydotplus can locate the Graphviz `dot` executable when
# rendering the tree image. Only the current process (and its children) see it.
_GRAPHVIZ_BIN = r'C:\Program Files (x86)\Graphviz2.38\bin'
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + _GRAPHVIZ_BIN
|
||
|
|
||
|
|
||
|
# Data import.
def dataImport(path='learnData4.csv'):
    """Load the data set from a comma-separated file that has no header row.

    Args:
        path: CSV file to read. Defaults to ``'learnData4.csv'``, resolved
            relative to the current working directory (the original
            hard-coded file).

    Returns:
        pandas.DataFrame whose columns are labelled 0..n-1, because
        ``header=None`` makes pandas treat the first line as data.
    """
    return pd.read_csv(path, sep=',', header=None)
|
||
|
|
||
|
|
||
|
# Convert the data set into feature/label arrays and split them into
# training and test subsets.
def splitDataSet(dataset, *, n_features=5, test_size=0.3, random_state=100):
    """Split a DataFrame into features/labels and train/test partitions.

    The first ``n_features`` columns are taken as features and column
    ``n_features`` as the label; any columns after the label are ignored.
    The defaults reproduce the original hard-coded behaviour (5 feature
    columns, label in column 5, 30% test split, seed 100).

    Args:
        dataset: pandas.DataFrame holding features followed by the label.
        n_features: number of leading feature columns.
        test_size: fraction of rows placed in the test subset.
        random_state: seed passed to ``train_test_split`` so the split is
            reproducible.

    Returns:
        Tuple ``(X, Y, x_train, x_test, y_train, y_test)`` where ``X``/``Y``
        are the full feature/label arrays and the rest are the split parts.
    """
    values = dataset.to_numpy()
    X = values[:, :n_features]
    Y = values[:, n_features]

    # Split into training and test subsets.
    x_train, x_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, random_state=random_state)

    return X, Y, x_train, x_test, y_train, y_test
|
||
|
|
||
|
|
||
|
def _export_tree_png(model, path):
    """Render a fitted decision tree to a PNG file using Graphviz.

    Args:
        model: a fitted ``tree.DecisionTreeClassifier`` trained on the five
            features age, fat, fiber, sex, spicy (in that column order).
        path: destination PNG file.
    """
    # NOTE(review): export_graphviz pairs class_names with model.classes_ in
    # sorted order; this assumes the label column holds exactly two classes
    # whose lower value means "easy to digest" — confirm against the data.
    dot_data = tree.export_graphviz(
        model, out_file=None,
        feature_names=["age", "fat", "fiber", "sex", "spicy"],
        class_names=["easy to digest", "hard to digest"],
        filled=True, rounded=True,
        special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png(path)


def main():
    """Train Gini and entropy decision trees and report/save the results.

    Steps: load the CSV data set, split it into train/test subsets, fit one
    Gini and one entropy ``DecisionTreeClassifier``, print both test-set
    accuracies (as percentages), save the entropy model to
    ``finalized_model.sav`` with joblib, print both models' predictions for a
    single hand-written sample, and write the entropy tree to
    ``digest_entropy.png``.
    """
    data = dataImport()
    _, _, x_train, x_test, y_train, y_test = splitDataSet(data)

    # Build the decision-tree models (default criterion is Gini). The seed
    # fixes sklearn's random feature permutation so accuracies are
    # reproducible across runs, matching the already-seeded train/test split.
    model = tree.DecisionTreeClassifier(random_state=100)
    model2 = tree.DecisionTreeClassifier(criterion="entropy", random_state=100)

    # Train both models on the training subset.
    model.fit(x_train, y_train)
    model2.fit(x_train, y_train)

    # Predict the test subset.
    pred = model.predict(x_test)
    pred2 = model2.predict(x_test)

    # Compare with the true labels, i.e. compute accuracy in percent.
    acc = accuracy_score(y_test, pred) * 100
    acc2 = accuracy_score(y_test, pred2) * 100
    print("akuratnosc dla modelu Gini: " + str(acc))
    print("akuratnosc dla modelu Entropy: " + str(acc2))

    # Persist the entropy model to disk.
    filename = 'finalized_model.sav'
    joblib.dump(model2, filename)

    # A single sample in feature order (age, fat, fiber, sex, spicy).
    sample = [22, 4, 4, 1, 0]
    print(model.predict([sample]))
    print(model2.predict([sample]))

    # Render the entropy tree. (The former `Image(graph2.create_png())` call
    # was dropped: inside a function its result is discarded, so it never
    # displayed anything and only ran Graphviz a second time.)
    _export_tree_png(model2, "digest_entropy.png")
|
||
|
|
||
|
# Run the training pipeline only when executed as a script, so importing this
# module (e.g. to reuse dataImport/splitDataSet) does not train models or
# write files.
if __name__ == "__main__":
    main()
|