SI_2020/package_location_classifier/classifier.py

75 lines
2.8 KiB
Python
Raw Normal View History

from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz, DecisionTreeRegressor
from sklearn.externals.six import StringIO
from IPython.display import Image
import pandas as pd
import numpy as np
import pydotplus
from sklearn.model_selection import train_test_split
from sklearn import metrics
# Minimum predicted chance of survival (0-1 scale) that a tile must exceed for
# each package category. The misspelled name ("TRESHOLD") is kept because other
# modules may import it.
PACKAGE_PLACE_TRESHOLD = {
    "normal": 0.8,
    "freezed": 0.85,
    "fragile": 0.85,
    "flammable": 0.9,
    "keep_dry": 0.8
}


class PackageLocationClassifier:
    """Decision-tree model predicting how well a package survives on a tile.

    The regressor is trained on (category, temperature, humidity) with the
    category one-hot encoded, against ``chance_of_survive`` rounded to one
    decimal and multiplied by 10 (i.e. a 0-10 scale).
    """

    # Column names assigned to every CSV row (files are read with ``names=``).
    _COLUMN_NAMES = ["product", "category", "temperature", "humidity",
                     "chance_of_survive", "place_here"]
    # Raw columns fed to pd.get_dummies to form the model input.
    _FEATURE_COLS = ["category", "temperature", "humidity"]

    def __init__(self,
                 trainset_path="package_location_classifier/trainset/trainset.csv",
                 testset_path="package_location_classifier/testset/testset.csv",
                 random_state=None):
        """Train the regressor and print its mean absolute error on the test set.

        Args:
            trainset_path: CSV with the training rows. Its first line is treated
                as a header and replaced by the fixed column names. Relative
                paths resolve against the current working directory.
            testset_path: CSV with the evaluation rows, read without a header.
            random_state: seed forwarded to the training-row shuffle and to the
                decision tree. ``None`` (the default) keeps the unseeded
                behavior.
        """
        products = self._load_dataset(trainset_path, header=0)
        testset = self._load_dataset(testset_path, header=None)

        # Shuffle the training rows before fitting.
        products = products.sample(frac=1, random_state=random_state)
        X_train = pd.get_dummies(products[self._FEATURE_COLS])
        y_train = products.chance_of_survive
        # One-hot column layout the model was fitted on. Every later input
        # (test set, check_if_can_place rows) is aligned to it by name.
        self.feature_names = X_train.columns.tolist()

        clf = DecisionTreeRegressor(ccp_alpha=0.02, min_samples_leaf=5,
                                    max_depth=5, random_state=random_state)
        self.predictor = clf.fit(X_train, y_train)
        # The fitted tree can be visualized with sklearn's export_graphviz,
        # passing feature_names=self.feature_names.

        self._report_test_error(testset)

    @classmethod
    def _load_dataset(cls, path, header):
        """Read a dataset CSV and put chance_of_survive on the 0-10 training scale."""
        frame = pd.read_csv(path, header=header, sep=",", names=cls._COLUMN_NAMES)
        frame = frame.round({"chance_of_survive": 1})
        frame["chance_of_survive"] = frame["chance_of_survive"] * 10
        return frame

    def _report_test_error(self, testset):
        """Print the mean absolute prediction error on ``testset`` (0-10 scale)."""
        # A category missing from the test set would otherwise leave out a
        # training column and break predict(). Align to the training layout by
        # name; columns for categories unseen in training are dropped.
        test_X = pd.get_dummies(testset[self._FEATURE_COLS]).reindex(
            columns=self.feature_names, fill_value=0)
        y_pred = self.predictor.predict(test_X)
        evaluation = pd.DataFrame({
            'category': testset.category,
            'temperature': testset.temperature,
            'humid': testset.humidity,
            'Actual': testset.chance_of_survive,
            'Predicted': y_pred,
        })
        evaluation = evaluation.round({'Actual': 3, 'Predicted': 3})
        evaluation['Prediction_diff'] = (evaluation['Actual'] - evaluation['Predicted']).abs()
        print("Prediction differs from actual value by average {}".format(
            round(evaluation['Prediction_diff'].mean(), 2)))

    def check_if_can_place(self, package, tile):
        """Decide whether ``package`` may be stored on ``tile``.

        The model predicts the survival chance on its 0-10 training scale. The
        prediction is divided by 10, rounded to two decimals, and compared with
        the category's threshold in PACKAGE_PLACE_TRESHOLD.

        Args:
            package: object with a ``category`` attribute.
            tile: object with ``air_temperature`` and ``humidity`` attributes.

        Returns:
            True when the predicted chance is strictly above the threshold,
            otherwise False.

        Raises:
            KeyError: the category has no entry in PACKAGE_PLACE_TRESHOLD.
            ValueError: the category never appeared in the training data, so
                the model has no feature for it.
        """
        category = package.category
        cat_treshold = PACKAGE_PLACE_TRESHOLD[category]
        category_column = "category_{}".format(category)
        if category_column not in self.feature_names:
            raise ValueError(
                "Category {!r} was not present in the training data".format(category))
        # Build the row by column name rather than by position, so it always
        # matches the pd.get_dummies layout used in training. The other
        # category flags are filled with 0.
        row = pd.DataFrame([{
            "temperature": tile.air_temperature,
            "humidity": tile.humidity,
            category_column: 1,
        }]).reindex(columns=self.feature_names, fill_value=0)
        quality_of_place = round(self.predictor.predict(row)[0] / 10, 2)
        # bool() so callers get a plain Python bool even for numpy scalars.
        return bool(quality_of_place > cat_treshold)
if __name__ == "__main__":
    # Script entry point: constructing the classifier trains the model and
    # prints its error on the test set.
    cfer = PackageLocationClassifier()