29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
from sklearn import tree as skltree
|
|
import pandas,os
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Names of the CSV columns used as model features. Training uses the bare
# value matrix of these columns and prediction receives plain value lists,
# so feature values must always be supplied in exactly this order.
atributes = [
    'plant_water_level',
    'growth',
    'disease',
    'fertility',
    'tractor_water_level',
    'temperature',
    'rain',
    'season',
    'current_time',
]
|
|
class Drzewo:
    """Decision-tree classifier that decides whether to water ("Tak"/"Nie").

    The tree is trained in the constructor from ``Data/dataTree.csv`` and
    kept in ``self.tree``.
    """

    def __init__(self):
        """Train the classifier and store it in ``self.tree``."""
        self.tree = self.treeLearn()

    def treeLearn(self):
        """Fit a ``DecisionTreeClassifier`` on the CSV training data.

        Reads ``Data/dataTree.csv``, takes the columns listed in the
        module-level ``atributes`` as features and the ``action`` column as
        the target.

        Returns:
            The fitted classifier. It is also stored in ``self.tree`` so that
            calling this method directly (re)trains the instance.
        """
        csvdata = pandas.read_csv('Data/dataTree.csv')
        # An alternative dataset, 'Data/dataTree2.csv', was previously
        # referenced here in a commented-out line.
        x = csvdata[atributes]
        decision = csvdata['action']

        classifier = skltree.DecisionTreeClassifier()
        # Fit on the bare value matrix (no column names) because
        # makeDecision() later predicts from plain value lists.
        classifier.fit(x.values, decision)

        # Returning the classifier is essential: __init__ assigns this
        # method's result to self.tree, so an implicit None return would
        # discard the trained model.
        self.tree = classifier
        return classifier

    def plotTree(self):
        """Render the trained tree and save the picture as ``tree.png``.

        The figure is left open after saving, so a caller may still invoke
        ``plt.show()`` to display it interactively.
        """
        plt.figure(figsize=(20, 30))
        skltree.plot_tree(self.tree, filled=True, feature_names=atributes)
        plt.title("Drzewo decyzyjne wytrenowane na przygotowanych danych: ")
        plt.savefig('tree.png')

    def makeDecision(self, values):
        """Predict the action for a single sample.

        Args:
            values: Feature values for one sample, in the order given by the
                module-level ``atributes`` list.

        Returns:
            ``"Nie"`` when the predicted label is 0 (do not water),
            ``"Tak"`` when it is 1 (water), and ``None`` for any other label.
        """
        # predict() takes a batch of samples; wrap the single sample and take
        # its one predicted label.
        label = self.tree.predict([values])[0]  # 0 - do not water, 1 - water
        if label == 0:
            return "Nie"
        if label == 1:
            return "Tak"
        return None