# AI_PROJECT/Drzewo.py
#
# 29 lines
# 1.1 KiB
# Python

from sklearn import tree as skltree
import pandas,os
import matplotlib.pyplot as plt
# Names of the feature columns, in the order the decision tree is trained on.
# The columns are selected from the CSV file by name; feature vectors passed
# to prediction must list their values in this same order.
atributes=['plant_water_level','growth','disease','fertility','tractor_water_level','temperature','rain','season','current_time']
class Drzewo:
    """Decision-tree classifier that answers whether to water a plant.

    The model is trained once, at construction time, on the rows of
    ``Data/dataTree2.csv``: the columns named in the module-level
    ``atributes`` list are the features and the ``action`` column is the
    target (0 - do not water, 1 - water).

    Attributes:
        tree: The fitted ``DecisionTreeClassifier``.
    """

    def __init__(self):
        """Train the classifier and keep it on ``self.tree``."""
        self.tree = self.treeLearn()

    def treeLearn(self):
        """Fit a decision tree on the CSV training data.

        The CSV path is relative to the current working directory.

        Returns:
            The fitted classifier. It is also stored on ``self.tree`` so
            that calling this method directly retrains the instance.
        """
        csvdata = pandas.read_csv('Data/dataTree2.csv')
        x = csvdata[atributes]
        decision = csvdata['action']
        # Train on the raw array (x.values) so that predict() can later be
        # called with plain lists of feature values.
        self.tree = skltree.DecisionTreeClassifier().fit(x.values, decision)
        # Returning the model matters: __init__ assigns this method's result
        # to self.tree, so returning nothing would overwrite it with None.
        return self.tree

    def plotTree(self):
        """Draw the trained tree and save the figure as ``tree.png``."""
        plt.figure(figsize=(20, 30))
        skltree.plot_tree(self.tree, filled=True, feature_names=atributes)
        plt.title("Drzewo decyzyjne wytrenowane na przygotowanych danych: ")
        plt.savefig('tree.png')

    def makeDecision(self, values):
        """Predict whether to water for a single set of feature values.

        Args:
            values: Feature values for one sample, in the same order as
                the ``atributes`` list.

        Returns:
            ``"Nie"`` when the predicted action is 0 (do not water),
            ``"Tak"`` when it is 1 (water), and ``None`` for any other
            predicted label.
        """
        # predict() expects a 2-D input, so wrap the single sample in a list.
        action = self.tree.predict([values])
        label = action[0]
        if label == 0:
            return "Nie"
        if label == 1:
            return "Tak"
        return None