SI_2020/package_location_classifier/classifier.py

37 lines
1.2 KiB
Python
Raw Normal View History

import numpy as np
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz

# ``sklearn.externals.six`` no longer exists in newer scikit-learn releases;
# fall back to the standard-library StringIO so the script still imports.
try:
    from sklearn.externals.six import StringIO
except ImportError:
    from io import StringIO
# Train a decision tree that predicts the ``place_here`` column of the
# training CSV from category, temperature and humidity, print the accuracy on
# a held-out 30 % split, and export the fitted tree as a PNG diagram.

# One seed shared by the train/test split and the tree so runs are repeatable.
RANDOM_SEED = 1

# In-memory text buffer that receives the Graphviz DOT source of the tree.
data = StringIO()
cols_names = ["product", "category", "temperature", "humidity", "chance_of_survive", "place_here"]

# header=0 consumes the file's own header row; names= replaces it with
# cols_names so the columns can be addressed by the labels used below.
products = pd.read_csv("trainset/trainset.csv", header=0, sep=",", names=cols_names)

feature_cols = ["category", "temperature", "humidity"]
# get_dummies one-hot encodes the non-numeric feature columns and leaves
# numeric ones unchanged, giving the tree a purely numeric design matrix.
X = pd.get_dummies(products[feature_cols])
y = products["place_here"]
# Keep the encoded column names so the exported diagram can label its splits.
dummies_names = X.columns.tolist()

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=RANDOM_SEED, shuffle=True
)

clf = DecisionTreeClassifier(
    criterion="entropy",
    ccp_alpha=0.02,
    # Never request more features per split than the encoded frame holds;
    # fit() rejects a max_features larger than the number of columns.
    max_features=min(4, X.shape[1]),
    # Features considered at each split are drawn at random, so the tree must
    # be seeded as well for the reported accuracy to be reproducible.
    random_state=RANDOM_SEED,
)
clf = clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
print("Correctly placed in : {} % of cases".format(
    round(100 * metrics.accuracy_score(y_test, y_pred), 3)
)
)

export_graphviz(
    clf,
    out_file=data,
    filled=True,
    rounded=True,
    special_characters=True,
    feature_names=dummies_names,
)
graph = pydotplus.graph_from_dot_data(data.getvalue())

# Render the PNG once and reuse the bytes for both the file on disk and the
# inline image, instead of invoking the renderer a second time.
png_bytes = graph.create_png()
with open('Ułożenie.png', 'wb') as png_file:
    png_file.write(png_bytes)
# As the final expression this displays the tree when run in a notebook cell;
# when run as a plain script the value is simply discarded.
Image(png_bytes)