AI_PROJECT/Drzewo.py

25 lines
916 B
Python
Raw Normal View History

from sklearn import tree as skltree
import pandas,os
import matplotlib.pyplot as plt
# Feature columns used to train the decision tree. Columns in the CSV file
# should be in the same order; this list also fixes the feature order the
# classifier is fitted on.
atributes = [
    'plant_water_level',
    'tractor_water_level',
    'temperature',
    'rain',
    'season',
    'current_time',
    'growth',
    'disease',
    'fertility',
]
class Drzewo:
    """Decision-tree classifier trained on CSV data.

    The classifier is fitted on the feature columns listed in the
    module-level ``atributes`` list, with the ``action`` column as the
    target. Predicted actions are documented as 0 - do not water,
    1 - water.
    """

    def __init__(self, csv_path='Data/dataTree.csv'):
        """Train the decision tree immediately on construction.

        Args:
            csv_path: CSV file with the training data. Defaults to the
                relative path ``Data/dataTree.csv``.
        """
        # treeLearn() must return the fitted classifier; otherwise this
        # assignment would reset self.tree to None and discard the model.
        self.tree = self.treeLearn(csv_path)

    def treeLearn(self, csv_path='Data/dataTree.csv'):
        """Fit a DecisionTreeClassifier on the CSV training data.

        Reads ``csv_path`` with pandas, uses the ``atributes`` columns as
        features and the ``action`` column as the target.

        Args:
            csv_path: CSV file with the training data. Defaults to the
                relative path ``Data/dataTree.csv``.

        Returns:
            The fitted classifier, which is also stored on ``self.tree``.
        """
        csvdata = pandas.read_csv(csv_path)
        features = csvdata[atributes]
        decision = csvdata['action']
        # DecisionTreeClassifier.fit returns the fitted estimator itself.
        self.tree = skltree.DecisionTreeClassifier().fit(features, decision)
        return self.tree

    def plotTree(self):
        """Draw the trained tree in a new matplotlib figure and show it."""
        plt.figure()
        skltree.plot_tree(self.tree, filled=True, feature_names=atributes)
        plt.title("Drzewo decyzyjne wytrenowane na przygotowanych danych")
        plt.show()

    def makeDecision(self, values):
        """Predict an action for each sample in ``values``.

        Args:
            values: samples to classify. Presumably a 2-D array-like or
                DataFrame whose columns follow the ``atributes`` order,
                since scikit-learn's ``predict`` expects 2-D input.
                TODO confirm with callers.

        Returns:
            The predicted actions: 0 - do not water, 1 - water.
        """
        action = self.tree.predict(values)
        return action