# File-viewer metadata (not part of the program): 37 lines, 1.2 KiB, Python
import numpy as np
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz

try:
    # Vendored six was removed in scikit-learn 0.23; keep it for old installs.
    from sklearn.externals.six import StringIO
except ImportError:
    from io import StringIO

# Train a decision tree that predicts the "place_here" column from the
# category/temperature/humidity features, print its accuracy on a 30 % hold-out
# split, and render the fitted tree to a PNG file.

cols_names = ["product", "category", "temperature", "humidity", "chance_of_survive", "place_here"]

# header=0 combined with names= makes pandas discard the file's own header row
# and label the columns with cols_names instead.
products = pd.read_csv("trainset/trainset.csv", header=0, sep=",", names=cols_names)
feature_cols = ["category", "temperature", "humidity"]

# get_dummies one-hot encodes non-numeric feature columns and passes numeric
# ones through unchanged, so the resulting column count depends on the data.
X = pd.get_dummies(products[feature_cols])
y = products.place_here
dummies_names = X.columns.tolist()

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=1, shuffle=True
)

# Clamp max_features to the number of encoded columns: a fixed 4 is rejected by
# scikit-learn versions that validate max_features <= n_features when the
# encoded frame has fewer columns.
# NOTE(review): no random_state is passed to the classifier, so the per-split
# feature sampling (and therefore the printed accuracy) may vary between runs.
clf = DecisionTreeClassifier(
    criterion="entropy",
    ccp_alpha=0.02,
    max_features=min(4, X.shape[1]),
)
clf = clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy_pct = round(100 * metrics.accuracy_score(y_test, y_pred), 3)
print("Correctly placed in : {} % of cases".format(accuracy_pct))

# With out_file=None export_graphviz returns the DOT source as a string, so no
# intermediate StringIO buffer is needed.
dot_source = export_graphviz(
    clf,
    out_file=None,
    filled=True,
    rounded=True,
    special_characters=True,
    feature_names=dummies_names,
)
graph = pydotplus.graph_from_dot_data(dot_source)

# Render once and reuse the bytes for both the file on disk and the inline
# image, instead of invoking Graphviz twice.
png_bytes = graph.create_png()
with open('Ułożenie.png', 'wb') as png_file:
    png_file.write(png_bytes)

# Bare expression: shown inline only when evaluated as the last expression of
# an IPython/Jupyter cell; in a plain script run the object is discarded.
Image(png_bytes)