DSZI_2020_Projekt/Restaurant/Natalia/__init__.py

87 lines
2.9 KiB
Python
Raw Normal View History

2020-05-24 18:02:53 +02:00
import pandas as pd
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
from io import StringIO
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.tree import export_graphviz
import joblib
from IPython.display import Image
import pydotplus
import os
# Make the Graphviz executables visible to this process, presumably so that
# pydotplus can find the `dot` binary when rendering PNGs in main().
# NOTE(review): this is a hard-coded Windows install location; on other
# systems Graphviz is expected to already be on PATH — confirm.
_GRAPHVIZ_BIN = r'C:\Program Files (x86)\Graphviz2.38\bin'
# Use .get() so a missing PATH variable does not raise KeyError.
_current_path = os.environ.get("PATH", "")
# Append only once, so re-executing this module does not grow PATH.
if _GRAPHVIZ_BIN not in _current_path.split(os.pathsep):
    # Avoid creating a leading empty PATH entry (which POSIX shells treat as
    # the current directory) when PATH was empty or unset.
    os.environ["PATH"] = (
        _current_path + os.pathsep + _GRAPHVIZ_BIN if _current_path else _GRAPHVIZ_BIN
    )
def _train_and_export(clf, X_train, X_test, y_train, y_test, feature_cols, png_path):
    """Fit ``clf``, print its test accuracy, render it to ``png_path`` and return it.

    Args:
        clf: An unfitted ``DecisionTreeClassifier``.
        X_train, X_test: Feature frames for training / evaluation.
        y_train, y_test: Label series for training / evaluation.
        feature_cols: Feature names used to label the split nodes.
        png_path: Output path for the rendered tree image.

    Returns:
        The fitted classifier (``fit`` returns the estimator itself).
    """
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print("Accuracy:", metrics.accuracy_score(y_test, y_pred))

    dot_data = StringIO()
    export_graphviz(
        clf,
        out_file=dot_data,
        filled=True,
        rounded=True,
        special_characters=True,
        feature_names=feature_cols,
        # export_graphviz picks class_names[i] for the i-th entry of
        # clf.classes_, so derive the names from the fitted classes instead of
        # a hard-coded '1'..'7' list that mislabels any other label set.
        class_names=[str(c) for c in clf.classes_],
    )
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_png(png_path)
    return clf


def main(data_path="Nowy.csv", model_path="final_model.sav"):
    """Train two decision-tree classifiers, render them, and save the second.

    Reads ``data_path`` as a header-less CSV with the columns
    age, sex, fat, fiber, spicy, number. The first five are the features and
    ``number`` is the label (presumably a recommendation/dish id — confirm).
    Rows are split 70% train / 30% test with a fixed seed. Two trees are
    then fitted:

    * a tree with the default split criterion, rendered to 'polecanie_1.png';
    * a tree using the entropy criterion, rendered to
      'polecanie_2_entropia.png'.

    Each tree's test-set accuracy is printed. The entropy tree is then
    persisted with joblib.

    Args:
        data_path: Input CSV path (relative paths resolve against the
            current working directory). Defaults to the original 'Nowy.csv'.
        model_path: Where to dump the fitted entropy model. Defaults to the
            original 'final_model.sav'.
    """
    col_names = ['age', 'sex', 'fat', 'fiber', 'spicy', 'number']
    feature_cols = ['age', 'sex', 'fat', 'fiber', 'spicy']

    # Load the data and separate features from the label column.
    model_tree = pd.read_csv(data_path, header=None, names=col_names)
    X = model_tree[feature_cols]
    y = model_tree['number']

    # 70% training / 30% test; fixed seed for a reproducible split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=1
    )
    split = (X_train, X_test, y_train, y_test)

    # Tree with sklearn's default criterion.
    _train_and_export(DecisionTreeClassifier(), *split, feature_cols, 'polecanie_1.png')

    # Tree with the entropy criterion; this is the model that gets saved.
    clf = _train_and_export(
        DecisionTreeClassifier(criterion="entropy"),
        *split,
        feature_cols,
        'polecanie_2_entropia.png',
    )
    joblib.dump(clf, model_path)
# NOTE(review): this call runs unconditionally whenever the module is
# executed or imported. The path header suggests this file is a package
# `__init__.py`; if so, merely importing the package reads 'Nowy.csv',
# retrains both trees, and writes the PNGs and 'final_model.sav'.
# Consider guarding it with `if __name__ == "__main__":` once it is
# confirmed that no caller relies on this import-time side effect.
main()