from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23

import numpy as np
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz, DecisionTreeRegressor

# Minimum predicted chance of survival (after dividing the model output by 10)
# that a tile must strictly exceed for a package of the given category.
PACKAGE_PLACE_TRESHOLD = {
    "normal": 0.8,
    "freezed": 0.85,
    "fragile": 0.85,
    "flammable": 0.9,
    "keep_dry": 0.8,
}

_COLUMN_NAMES = ["product", "category", "temperature", "humidity", "chance_of_survive", "place_here"]
_FEATURE_COLUMNS = ["category", "temperature", "humidity"]
_TRAINSET_PATH = "package_location_classifier/trainset/trainset.csv"
_TESTSET_PATH = "package_location_classifier/testset/testset.csv"


def _load_dataset(path, header):
    """Read a package CSV and rescale its ``chance_of_survive`` target.

    The target is rounded to one decimal place and then multiplied by 10,
    which is the scale the regressor is trained and evaluated on.

    :param path: path of the comma-separated file to read.
    :param header: passed straight to ``pandas.read_csv``. ``0`` replaces the
        file's first row with ``_COLUMN_NAMES``; ``None`` treats every row as data.
    :return: the loaded ``DataFrame``.
    """
    dataset = pd.read_csv(path, header=header, sep=",", names=_COLUMN_NAMES)
    dataset = dataset.round({"chance_of_survive": 1})
    dataset["chance_of_survive"] *= 10
    return dataset


class PackageLocationClassifier():
    """Decision-tree regressor that scores how suitable a tile is for a package.

    The model predicts ``chance_of_survive`` from a package's category and a
    tile's temperature and humidity. ``check_if_can_place`` compares that
    prediction with the per-category threshold in ``PACKAGE_PLACE_TRESHOLD``.
    """

    def __init__(self, trainset_path=_TRAINSET_PATH, testset_path=_TESTSET_PATH, random_state=None):
        """Train the regressor on the training CSV and report its error on the test CSV.

        :param trainset_path: CSV with training rows. Its first line is replaced by
            the expected column names.
        :param testset_path: CSV with evaluation rows.
        :param random_state: optional seed for the row shuffle and the tree, making
            runs reproducible. ``None`` keeps the previous unseeded behaviour.
        """
        products = _load_dataset(trainset_path, header=0)
        # NOTE(review): the test set is read with header=None while the training set
        # uses header=0. Confirm testset.csv really has no header row; otherwise it
        # is parsed as a data row.
        testset = _load_dataset(testset_path, header=None)

        products = products.sample(frac=1, random_state=random_state)
        X_train = pd.get_dummies(products[_FEATURE_COLUMNS])
        y_train = products.chance_of_survive
        # Remember the exact one-hot column layout so prediction inputs can be
        # aligned to it (see check_if_can_place / export_tree_png).
        self.feature_names = X_train.columns.tolist()

        # Encode the test set with the *training* columns. A category missing from
        # (or extra in) the test set would otherwise shift or drop columns.
        test_X = pd.get_dummies(testset[_FEATURE_COLUMNS]).reindex(columns=self.feature_names, fill_value=0)
        test_y = testset.chance_of_survive

        clf = DecisionTreeRegressor(ccp_alpha=0.02, min_samples_leaf=5, max_depth=5, random_state=random_state)
        self.predictor = clf.fit(X_train, y_train)

        y_pred = self.predictor.predict(test_X)
        evaluation = pd.DataFrame({'category': testset.category, 'temperature': testset.temperature,
                                   'humid': testset.humidity, 'Actual': test_y, 'Predicted': y_pred})
        evaluation = evaluation.round({'Actual': 3, 'Predicted': 3})
        evaluation['Prediction_diff'] = (evaluation['Actual'] - evaluation['Predicted']).abs()
        print("Prediction differs from actual value by average {}".format(round(evaluation['Prediction_diff'].mean(), 2)))

    def export_tree_png(self, path="Drzewo.png"):
        """Render the fitted tree with Graphviz, write it as a PNG and return it for display.

        :param path: output PNG file path.
        :return: an ``IPython.display.Image`` of the rendered tree.
        """
        dot_data = StringIO()
        export_graphviz(self.predictor, out_file=dot_data, filled=True, rounded=True,
                        special_characters=True, feature_names=self.feature_names)
        graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
        graph.write_png(path)
        return Image(graph.create_png())

    def check_if_can_place(self, package, tile):
        """Decide whether ``package`` may be placed on ``tile``.

        :param package: object with a ``category`` attribute that must be a key of
            ``PACKAGE_PLACE_TRESHOLD``.
        :param tile: object with ``air_temperature`` and ``humidity`` attributes.
        :return: ``True`` when the predicted chance of survival, divided by 10 and
            rounded to 2 decimals, strictly exceeds the category threshold.
        :raises KeyError: if ``package.category`` has no threshold.
        """
        category = package.category
        cat_treshold = PACKAGE_PLACE_TRESHOLD[category]
        # Build a named one-row frame and align it to the training dummy columns,
        # so the encoding matches get_dummies regardless of which categories the
        # training data contained.
        fields = pd.DataFrame([{
            "temperature": tile.air_temperature,
            "humidity": tile.humidity,
            "category_{}".format(category): 1,
        }]).reindex(columns=self.feature_names, fill_value=0)
        quality_of_place = round(self.predictor.predict(fields)[0] / 10, 2)
        return quality_of_place > cat_treshold


if __name__ == '__main__':
    cfer = PackageLocationClassifier()