78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
import numpy as np
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz, DecisionTreeRegressor

try:
    from sklearn.externals.six import StringIO
except ImportError:
    # sklearn.externals.six was removed in scikit-learn 0.23; the standard
    # library's io.StringIO is a drop-in replacement for how it is used here.
    from io import StringIO
# Per package category: the placement quality that a tile's predicted score
# must strictly exceed before PackageLocationClassifier.check_if_can_place
# accepts the tile for that package.
PACKAGE_PLACE_TRESHOLD = dict(
    normal=0.8,
    freezed=0.85,
    fragile=0.85,
    flammable=0.9,
    keep_dry=0.8,
)
class PackageLocationClassifier():
    """Regression tree that scores how suitable a tile is for a package.

    Constructing an instance does the following:
      * trains a DecisionTreeRegressor from CSV files on disk;
      * prints the model's mean absolute error on a test CSV;
      * writes a rendering of the fitted tree to ``Drzewo.png``.

    check_if_can_place then uses the fitted model for single placement
    decisions.
    """

    def __init__(self):
        """Train the tree, report its test-set error and export it as a PNG.

        Reads these files relative to the current working directory:
          * ``package_location_classifier/trainset/trainset.csv``
          * ``package_location_classifier/testset/testset.csv``
        """
        cols_names = ["product", "category", "temperature", "humidity", "chance_of_survive", "place_here"]
        feature_cols = ["category", "temperature", "humidity"]

        # The training CSV's header row is replaced by cols_names; the test
        # CSV is read as headerless.
        products = pd.read_csv("package_location_classifier/trainset/trainset.csv", header=0, sep=",", names=cols_names)
        testset = pd.read_csv("package_location_classifier/testset/testset.csv", header=None, sep=",", names=cols_names)

        # Round the target to one decimal place and scale it by 10.
        # check_if_can_place divides predictions by 10 again.
        products = products.round({"chance_of_survive": 1})
        testset = testset.round({"chance_of_survive": 1})
        products.chance_of_survive *= 10
        testset.chance_of_survive *= 10

        products = products.sample(frac=1)  # shuffle the training rows
        X_train = pd.get_dummies(products[feature_cols])
        y_train = products.chance_of_survive
        # Exact column layout the model is trained on. It is reused to shape
        # every later prediction input.
        self._feature_names = X_train.columns.tolist()

        # One-hot encode the test features onto the *training* columns, so a
        # category that is missing from (or extra in) the test set cannot
        # shift or add columns.
        test_X = pd.get_dummies(testset[feature_cols]).reindex(columns=self._feature_names, fill_value=0)
        test_y = testset.chance_of_survive

        clf = DecisionTreeRegressor(ccp_alpha=0.02, min_samples_leaf=5, max_depth=5)
        self.predictor = clf.fit(X_train, y_train)

        y_pred = self.predictor.predict(test_X)
        evaluation = pd.DataFrame({
            'category': testset.category,
            'temperature': testset.temperature,
            'humid': testset.humidity,
            'Actual': test_y,
            'Predicted': y_pred,
        })
        evaluation['Prediction_diff'] = (evaluation['Actual'] - evaluation['Predicted']).abs()
        print("Prediction differs from actual value by average {}".format(round(evaluation['Prediction_diff'].mean(), 2)))

        # Render the fitted tree through Graphviz into a PNG on disk.
        data = StringIO()
        export_graphviz(clf, out_file=data, filled=True, rounded=True, special_characters=True, feature_names=self._feature_names)
        graph = pydotplus.graph_from_dot_data(data.getvalue())
        graph.write_png('Drzewo.png')

    @staticmethod
    def _feature_row(category, temperature, humidity, feature_names):
        """Build a one-row model input laid out exactly like the training matrix.

        The row is one-hot encoded the same way as the training data. It is
        then reindexed onto ``feature_names``: dummy columns for other
        categories are filled with 0, and a category not in ``feature_names``
        yields all-zero dummies. Values are cast to float.
        """
        row = pd.DataFrame({"category": [category], "temperature": [temperature], "humidity": [humidity]})
        return pd.get_dummies(row).reindex(columns=feature_names, fill_value=0).astype(float)

    def check_if_can_place(self, package, tile):
        """Return True if the model rates ``tile`` good enough for ``package``.

        ``package.category`` selects the threshold from
        PACKAGE_PLACE_TRESHOLD; an unknown category raises KeyError. The
        model receives the category together with ``tile.air_temperature``
        and ``tile.humidity``. Its prediction is divided by 10 and rounded to
        two decimals, then compared strictly (``>``) against the threshold.
        """
        category = package.category
        cat_treshold = PACKAGE_PLACE_TRESHOLD[category]
        # Named, correctly ordered features: immune to get_dummies column
        # order and free of sklearn's missing-feature-names warning.
        fields = self._feature_row(category, tile.air_temperature, tile.humidity, self._feature_names)
        quality_of_place = round(self.predictor.predict(fields)[0] / 10, 2)
        return bool(quality_of_place > cat_treshold)
if __name__ == '__main__':
    # Constructing the classifier trains it, prints its test-set error and
    # writes the tree image; the instance itself is not needed afterwards.
    PackageLocationClassifier()