Prześlij pliki do 'Sklearn'
This commit is contained in:
parent
70908a51ba
commit
e1336ab5c9
86
Sklearn/Generate.py
Normal file
86
Sklearn/Generate.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Load libraries
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from sklearn import tree, metrics
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.tree import DecisionTreeClassifier, _tree
|
||||
|
||||
|
||||
def tree_to_code(tree, feature_names):
|
||||
# f = open('generatedTree.py', 'w')
|
||||
tree_ = tree.tree_
|
||||
feature_name = [
|
||||
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
|
||||
for i in tree_.feature
|
||||
]
|
||||
# print("def tree({}):".format(", ".join(feature_names)), file=f)
|
||||
print("def tree({}):".format(", ".join(feature_names)))
|
||||
|
||||
def recurse(node, depth):
|
||||
indent = " " * depth
|
||||
if tree_.feature[node] != _tree.TREE_UNDEFINED:
|
||||
name = feature_name[node]
|
||||
threshold = tree_.threshold[node]
|
||||
# print("{}if {} <= {}:".format(indent, name, threshold), file=f)
|
||||
print("{}if {} <= {}:".format(indent, name, threshold))
|
||||
recurse(tree_.children_left[node], depth + 1)
|
||||
# print("{}else: # if {} > {}".format(indent, name, threshold), file=f)
|
||||
print("{}else: # if {} > {}".format(indent, name, threshold))
|
||||
recurse(tree_.children_right[node], depth + 1)
|
||||
else:
|
||||
# print("{}return {}".format(indent, tree_.value[node],), file=f)
|
||||
print("{}return {}".format(indent, tree_.value[node]))
|
||||
|
||||
recurse(0, 1)
|
||||
# f.close()
|
||||
|
||||
|
||||
def loadLearningBase():
|
||||
col_names = ['Warzywo', 'Nawoz', 'Srodek', 'Stan', 'Dzialanie']
|
||||
base = pd.read_csv("Database.csv", header=None, names=col_names)
|
||||
feature_cols = ['Warzywo', 'Nawoz', 'Srodek', 'Stan']
|
||||
|
||||
""" print dataset"""
|
||||
# print(base.head())
|
||||
|
||||
X = base[feature_cols] # Features
|
||||
y = base.Dzialanie # Target variable
|
||||
|
||||
# Split dataset into training set and test set
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
|
||||
random_state=1) # 70% training and 30% test
|
||||
|
||||
data = generateDecisionTree(X_train, X_test, y_train, y_test)
|
||||
|
||||
"""generate data for image"""
|
||||
# tree.export_graphviz(data, out_file='treeData.dot', filled=True, rounded=True, special_characters=True,
|
||||
# feature_names=feature_cols)
|
||||
|
||||
"""Printing if_styled tree to console"""
|
||||
# tree_to_code(data, feature_cols)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def generateDecisionTree(X_train, X_test, y_train, y_test):
|
||||
# Create Decision Tree classifer object
|
||||
clf = DecisionTreeClassifier(criterion="entropy")
|
||||
|
||||
# Train Decision Tree Classifer
|
||||
clf = clf.fit(X_train, y_train)
|
||||
|
||||
# Predict the response for test dataset
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
"""Model Accuracy, how often is the classifier correct """
|
||||
# print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
|
||||
|
||||
return clf
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generated = loadLearningBase()
|
||||
|
||||
# Save generated tree
|
||||
filename = 'decisionTree.sav'
|
||||
pickle.dump(generated, open(filename, 'wb'))
|
36
Sklearn/Test.py
Normal file
36
Sklearn/Test.py
Normal file
@ -0,0 +1,36 @@
|
||||
import pickle
|
||||
|
||||
|
||||
def decision(prediction):
|
||||
if prediction == 0:
|
||||
return "Nie_podejmuj_działania"
|
||||
elif prediction == 1:
|
||||
return "Zastosuj_nawóz"
|
||||
elif prediction == 2:
|
||||
return "Zastosuj_środek"
|
||||
elif prediction == 4:
|
||||
return "Zbierz"
|
||||
elif prediction == 5:
|
||||
return "Roślina_już_zgniła-zbierz_i_wyrzuć"
|
||||
|
||||
|
||||
def test():
|
||||
for n in range(0, 2):
|
||||
if n == 0:
|
||||
print("############# Nie ma nawozu #############")
|
||||
else:
|
||||
print("############# Zastosowano nawóz #############")
|
||||
for s in range(0, 2):
|
||||
if s == 0:
|
||||
print("############# Nie ma środka ochrony #############")
|
||||
else:
|
||||
print("############# Zastosowano środek ochrony #############")
|
||||
for st in range(0, 101):
|
||||
val = tree.predict([[1, n, s, st]])
|
||||
print("Stan roślinki: ", st, " ", decision(val))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
filename = 'decisionTree.sav'
|
||||
tree = pickle.load(open(filename, 'rb'))
|
||||
test()
|
14
Sklearn/injectCode.py
Normal file
14
Sklearn/injectCode.py
Normal file
@ -0,0 +1,14 @@
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
|
||||
def prediction(warzywo, nawoz ,srodek, stan_wzrostu):
|
||||
filename = 'decisionTree.sav'
|
||||
tree = pickle.load(open(filename, 'rb'))
|
||||
print(tree.predict([[warzywo, nawoz, srodek, stan_wzrostu]]))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Map command line arguments to function arguments.
|
||||
prediction(*sys.argv[1:])
|
25
Sklearn/treeData.dot
Normal file
25
Sklearn/treeData.dot
Normal file
@ -0,0 +1,25 @@
|
||||
digraph Tree {
|
||||
node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;
|
||||
edge [fontname=helvetica] ;
|
||||
0 [label=<Stan ≤ 40.5<br/>entropy = 2.092<br/>samples = 280<br/>value = [78, 41, 29, 108, 24]>, fillcolor="#e2e2fb"] ;
|
||||
1 [label=<Nawoz ≤ 0.5<br/>entropy = 0.929<br/>samples = 119<br/>value = [78, 41, 0, 0, 0]>, fillcolor="#f3c3a1"] ;
|
||||
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
|
||||
2 [label=<Stan ≤ 10.5<br/>entropy = 0.796<br/>samples = 54<br/>value = [13, 41, 0, 0, 0]>, fillcolor="#a5ed78"] ;
|
||||
1 -> 2 ;
|
||||
3 [label=<entropy = 0.0<br/>samples = 13<br/>value = [13, 0, 0, 0, 0]>, fillcolor="#e58139"] ;
|
||||
2 -> 3 ;
|
||||
4 [label=<entropy = 0.0<br/>samples = 41<br/>value = [0, 41, 0, 0, 0]>, fillcolor="#7be539"] ;
|
||||
2 -> 4 ;
|
||||
5 [label=<entropy = 0.0<br/>samples = 65<br/>value = [65, 0, 0, 0, 0]>, fillcolor="#e58139"] ;
|
||||
1 -> 5 ;
|
||||
6 [label=<Stan ≤ 80.5<br/>entropy = 1.241<br/>samples = 161<br/>value = [0, 0, 29, 108, 24]>, fillcolor="#8a88ef"] ;
|
||||
0 -> 6 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||
7 [label=<entropy = 0.0<br/>samples = 108<br/>value = [0, 0, 0, 108, 0]>, fillcolor="#3c39e5"] ;
|
||||
6 -> 7 ;
|
||||
8 [label=<Srodek ≤ 0.5<br/>entropy = 0.994<br/>samples = 53<br/>value = [0, 0, 29, 0, 24]>, fillcolor="#ddfbf5"] ;
|
||||
6 -> 8 ;
|
||||
9 [label=<entropy = 0.0<br/>samples = 29<br/>value = [0, 0, 29, 0, 0]>, fillcolor="#39e5c5"] ;
|
||||
8 -> 9 ;
|
||||
10 [label=<entropy = 0.0<br/>samples = 24<br/>value = [0, 0, 0, 0, 24]>, fillcolor="#e539c0"] ;
|
||||
8 -> 10 ;
|
||||
}
|
Loading…
Reference in New Issue
Block a user