AI_PROJECT/Drzewo.py

25 lines
956 B
Python

from sklearn import tree as skltree
import pandas,os
import matplotlib.pyplot as plt
# Names of the feature columns used by the decision tree. The training data is
# selected from the CSV by these names, so this list fixes the feature order
# the model is trained on; value lists passed to Drzewo.makeDecision must
# follow the same order.
atributes = [
    'plant_water_level',
    'growth',
    'disease',
    'fertility',
    'tractor_water_level',
    'temperature',
    'rain',
    'season',
    'current_time',
]
class Drzewo:
    """Decision-tree classifier trained from a CSV file.

    The model is fitted at construction time on the columns named in the
    module-level ``atributes`` list, with the ``action`` column as the label.
    """

    def __init__(self, csv_path='Data/dataTree.csv'):
        """Train the classifier and keep it on ``self.tree``.

        Args:
            csv_path: Location of the training CSV. Defaults to the path the
                class has always used (resolved against the working
                directory).
        """
        self.tree = self.treeLearn(csv_path)

    def treeLearn(self, csv_path='Data/dataTree.csv'):
        """Fit a new decision tree on the training CSV.

        Args:
            csv_path: Location of the training CSV; it must contain every
                column named in ``atributes`` plus an ``action`` column.

        Returns:
            The fitted classifier, which is also stored on ``self.tree``.
        """
        csvdata = pandas.read_csv(csv_path)
        x = csvdata[atributes]
        decision = csvdata['action']
        # Train on the bare value array, matching the plain list of values
        # that makeDecision later hands to predict().
        self.tree = skltree.DecisionTreeClassifier().fit(x.values, decision)
        # The model must be returned: __init__ assigns this method's result
        # to self.tree, so returning nothing would replace the freshly
        # trained classifier with None.
        return self.tree

    def plotTree(self):
        """Draw the trained tree, save it as ``tree.png`` and display it."""
        plt.figure()
        skltree.plot_tree(self.tree, filled=True, feature_names=atributes)
        # Title text: "Decision tree trained on the prepared data".
        plt.title("Drzewo decyzyjne wytrenowane na przygotowanych danych")
        plt.savefig('tree.png')
        plt.show()

    def makeDecision(self, values):
        """Predict the action for a single set of feature values.

        Args:
            values: Feature values, ordered like ``atributes``.

        Returns:
            Whatever ``predict`` returns for the one-sample batch. According
            to the original author's note, 0 means "do not water" and 1 means
            "water".
        """
        # predict() takes a batch of samples, so wrap the single sample.
        action = self.tree.predict([values])
        return action