import os
from io import StringIO

import joblib
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

# Make the Graphviz binaries (used by pydotplus to render PNGs) reachable via
# PATH. NOTE(review): hard-coded Windows install location — adjust per machine.
os.environ["PATH"] += os.pathsep + r'C:\Program Files (x86)\Graphviz2.38\bin'


def _export_tree_png(clf, feature_cols, png_path):
    """Render the fitted decision tree *clf* to a PNG file at *png_path*."""
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data,
                    filled=True, rounded=True,
                    special_characters=True,
                    feature_names=feature_cols,
                    # export_graphviz pairs class_names with clf.classes_
                    # (sorted labels); deriving them from the model keeps the
                    # labels correct whatever classes occur in the data.
                    class_names=[str(c) for c in clf.classes_])
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_png(png_path)


def _train_and_report(X_train, X_test, y_train, y_test, feature_cols,
                      png_path, criterion="gini", random_state=1):
    """Fit a decision tree, print its test accuracy and render it to a PNG.

    Args:
        X_train, X_test, y_train, y_test: train/test split of features/labels.
        feature_cols: feature names, used to label the rendered tree.
        png_path: output path of the rendered tree image.
        criterion: split-quality criterion passed to DecisionTreeClassifier
            ("gini" or "entropy").
        random_state: seed for the tree's feature permutation so that tie
            breaking between equally good splits is reproducible.

    Returns:
        The fitted DecisionTreeClassifier.
    """
    clf = DecisionTreeClassifier(criterion=criterion, random_state=random_state)
    clf.fit(X_train, y_train)

    # Accuracy of the model on the held-out test set.
    y_pred = clf.predict(X_test)
    print("Accuracy:", metrics.accuracy_score(y_test, y_pred))

    _export_tree_png(clf, feature_cols, png_path)
    return clf


def main(csv_path="Nowy.csv", model_path='final_model.sav'):
    """Train gini and entropy decision trees on *csv_path* and report them.

    Reads a header-less CSV with columns age, sex, fat, fiber, spicy, number;
    the first five are features and ``number`` is the target label. The data
    is split 70% train / 30% test. For each criterion the test accuracy is
    printed and the tree is rendered to a PNG; the gini model is also saved
    to *model_path* with joblib.
    """
    col_names = ['age', 'sex', 'fat', 'fiber', 'spicy', 'number']
    feature_cols = ['age', 'sex', 'fat', 'fiber', 'spicy']

    # Load the data.
    model_tree = pd.read_csv(csv_path, header=None, names=col_names)

    # Separate features from the target label.
    X = model_tree[feature_cols]
    y = model_tree.number

    # Split into training and test sets: 70% training, 30% test.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=1)

    # Decision tree with the default (gini) criterion; this model is saved.
    clf = _train_and_report(X_train, X_test, y_train, y_test, feature_cols,
                            'polecanie_1.png')
    joblib.dump(clf, model_path)

    # Decision tree using the entropy criterion, for comparison.
    _train_and_report(X_train, X_test, y_train, y_test, feature_cols,
                      'polecanie_2_entropia.png', criterion="entropy")


if __name__ == "__main__":
    main()