import pandas as pd import joblib import pydotplus from IPython.display import Image from sklearn import tree from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score import os os.environ["PATH"] += os.pathsep + r'C:\Program Files (x86)\Graphviz2.38\bin' # import danych def dataImport(): dataset = pd.read_csv('learnData4.csv', sep=',', header=None) # print(dataset) return dataset # zamiana dataset na dwie tablice zawierające cechy i wyniki def splitDataSet(dataset): X = dataset.values[:, 0:5] Y = dataset.values[:, 5] # podział na zbiory do nauki i testowe x_train, x_test, y_train, y_test = train_test_split( X, Y, test_size=0.3, random_state=100) return X, Y, x_train, x_test, y_train, y_test def main(): data = dataImport() X, Y, x_train, x_test, y_train, y_test = splitDataSet(data) # utworzenie modelu drzewa decyzyjnego model = tree.DecisionTreeClassifier() model2 = tree.DecisionTreeClassifier(criterion="entropy") # przeprowadzenie "nauki" na zbiorach do nauki model.fit(x_train, y_train) model2.fit(x_train, y_train) # generowanie wyników dla zbioru testowego pred = model.predict(x_test) pred2 = model2.predict(x_test) # porównanie z faktycznymi wynikami = wyliczenie dokładności acc = accuracy_score(y_test, pred) * 100 # aprox. 77.78% acc2 = accuracy_score(y_test, pred2) * 100 # aprox. 83.33% print("akuratnosc dla modelu Gini: " + str(acc)) print("akuratnosc dla modelu Entropy: " + str(acc2)) # przekazanie wygenerowanego modelu do pliku filename = 'finalized_model.sav' joblib.dump(model2, filename) kluseczka = [22, 4, 4, 1, 0] wynik = model.predict([kluseczka]) print(wynik) wynik2 = model2.predict([kluseczka]) print(wynik2) # wygenerowanie i zapisanie graficznej reprezentacji drzewa # dot_data = tree.export_graphviz(model, out_file=None, # feature_names=["age", "fat", "fiber", "sex", "spicy"], # class_names=["0", "1"], # filled=True, rounded=True, # special_characters=True) dot_data2 = tree.export_graphviz(model2, out_file=None, feature_names=["age", "fat", "fiber", "sex", "spicy"], class_names=["easy to digest", "hard to digest"], filled=True, rounded=True, special_characters=True) # wygenerowanie grafiki przedstawiającej drzewo # graph = pydotplus.graph_from_dot_data(dot_data) # Image(graph.create_png()) # graph.write_png("digest.png") graph2 = pydotplus.graph_from_dot_data(dot_data2) Image(graph2.create_png()) graph2.write_png("digest_entropy.png") return main()