import os

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Training data used when learn() is called without an explicit path. The path
# is relative, so it is resolved against the current working directory.
DEFAULT_DATA_PATH = os.path.join('core', 'resources', 'data', 'data_tree.csv')

# Mushroom attributes fed to the model, in the same order as the leading CSV
# feature columns. Training and prediction must agree on this order.
MUSHROOM_FEATURES = (
    'isRed',
    'hasRing',
    'hasEmptyCore',
    'hasDots',
    'hasLamella',
    'isShiftingColor',
    'hasUnpleasantSmell',
    'hasSlime',
)


class DecisionTree:
    """Gini decision-tree classifier over the eight ``MUSHROOM_FEATURES``.

    Attributes:
        model: The underlying scikit-learn ``DecisionTreeClassifier``.
        learningDataSize: Fraction of the dataset used for training.
        random_state: Seed forwarded to the classifier and the train/test split.
        feature_names: Feature column names seen by the last ``learn`` call
            (``None`` before training).
        x_test, y_test: Held-out split from the last ``learn`` call (``None``
            before training; empty when ``learningDataSize == 1``).
    """

    def __init__(self, precision, *, random_state=None):
        """Create an untrained classifier.

        Args:
            precision: Fraction of the CSV rows used for training. ``1`` trains
                on every row; otherwise ``1 - precision`` is held out and must
                be a test size accepted by ``train_test_split``.
            random_state: Optional seed for both the tree and the train/test
                split, making training reproducible. ``None`` (the default)
                keeps training non-deterministic.
        """
        self.model = DecisionTreeClassifier(criterion='gini', random_state=random_state)
        self.learningDataSize = precision
        self.random_state = random_state
        self.feature_names = None
        self.x_test = None
        self.y_test = None

    def learn(self, data_path=None):
        """Fit the model on a CSV dataset.

        The first ``len(MUSHROOM_FEATURES)`` columns are used as features and
        the column right after them as the label.

        Args:
            data_path: CSV file to train on; defaults to ``DEFAULT_DATA_PATH``.
        """
        n_features = len(MUSHROOM_FEATURES)
        dataset = pd.read_csv(DEFAULT_DATA_PATH if data_path is None else data_path)
        input_x = dataset.iloc[:, 0:n_features]
        input_y = dataset.iloc[:, n_features]

        if self.learningDataSize == 1:
            # train_test_split rejects test_size=0, so training on the whole
            # dataset must bypass it; keep an empty held-out split.
            x_train, y_train = input_x, input_y
            self.x_test, self.y_test = input_x.iloc[0:0], input_y.iloc[0:0]
        else:
            x_train, self.x_test, y_train, self.y_test = train_test_split(
                input_x,
                input_y,
                test_size=1 - self.learningDataSize,
                random_state=self.random_state,
            )

        self.model.fit(x_train, y_train)
        self.feature_names = list(input_x.columns)

    def score(self):
        """Return the model's mean accuracy on the held-out split.

        Raises:
            ValueError: If ``learn`` has not run or no rows were held out.
        """
        if self.x_test is None or len(self.x_test) == 0:
            raise ValueError('no held-out data: call learn() with precision < 1')
        return self.model.score(self.x_test, self.y_test)

    def predict(self, mushroom):
        """Return the model's predicted label for *mushroom*, converted to bool.

        Args:
            mushroom: Object exposing every attribute in ``MUSHROOM_FEATURES``.
        """
        row = [getattr(mushroom, name) for name in MUSHROOM_FEATURES]
        # Reuse the training column names so scikit-learn does not warn that
        # the input lacks the feature names the model was fitted with.
        sample = pd.DataFrame([row], columns=self.feature_names)
        return bool(self.model.predict(sample)[0])