33 lines
1.0 KiB
Python
33 lines
1.0 KiB
Python
import os
|
|
|
|
import pandas as pd
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
class DecisionTree:
    """Gini decision-tree classifier trained on mushroom features read from a CSV file.

    The first ``len(_FEATURE_ATTRIBUTES)`` (eight) CSV columns are used as features and
    the column right after them as the target label. ``predict`` feeds the mushroom
    attributes listed in ``_FEATURE_ATTRIBUTES`` in that same order.
    NOTE(review): this presumes the CSV feature columns are ordered exactly like
    ``_FEATURE_ATTRIBUTES`` — confirm against data_tree.csv.
    """

    # Mushroom attributes fed to the model, in training-column order.
    _FEATURE_ATTRIBUTES = (
        "isRed",
        "hasRing",
        "hasEmptyCore",
        "hasDots",
        "hasLamella",
        "isShiftingColor",
        "hasUnpleasantSmell",
        "hasSlime",
    )

    # Default training data location, resolved against the current working directory.
    DEFAULT_DATA_PATH = os.path.join("core", "resources", "data", "data_tree.csv")

    def __init__(self, precision, random_state=None):
        """Create an unfitted gini decision tree.

        Args:
            precision: Fraction of dataset rows used for training; must be in (0, 1]
                (checked in ``learn``). Despite its name this is the training-set size,
                not an accuracy target.
            random_state: Optional seed forwarded to the tree and to the train split so
                results are reproducible. ``None`` (default) keeps both unseeded.
        """
        self.model = DecisionTreeClassifier(criterion='gini', random_state=random_state)
        self.learningDataSize = precision
        self.randomState = random_state

    def learn(self, csv_path=None):
        """Fit the tree on a random subset of the rows of the training CSV.

        Args:
            csv_path: CSV file to read. Defaults to ``DEFAULT_DATA_PATH``.

        Raises:
            ValueError: If ``self.learningDataSize`` is not in (0, 1].
        """
        train_fraction = float(self.learningDataSize)
        # Validate before any I/O so a bad fraction fails fast with a clear message.
        if not 0.0 < train_fraction <= 1.0:
            raise ValueError(
                "precision must be in (0, 1], got {!r}".format(self.learningDataSize)
            )

        dataset = pd.read_csv(csv_path if csv_path is not None else self.DEFAULT_DATA_PATH)
        n_features = len(self._FEATURE_ATTRIBUTES)
        input_x = dataset.iloc[:, 0:n_features]
        input_y = dataset.iloc[:, n_features]

        if train_fraction == 1.0:
            # sklearn rejects train_size=1.0, so use every row directly.
            x_train, y_train = input_x, input_y
        else:
            # Pass train_size rather than test_size=1-precision: the subtraction adds
            # float error (1 - 0.7 == 0.30000000000000004) that sklearn's ceil() turns
            # into an extra test row, silently shrinking the training set.
            # The held-out rows are never evaluated here, so they are discarded.
            x_train, _, y_train, _ = train_test_split(
                input_x,
                input_y,
                train_size=train_fraction,
                # getattr keeps instances created without __init__ (e.g. old pickles) working.
                random_state=getattr(self, "randomState", None),
            )

        self.model.fit(x_train, y_train)

    def predict(self, mushroom):
        """Classify a single mushroom with the fitted tree.

        Args:
            mushroom: Object exposing every attribute named in ``_FEATURE_ATTRIBUTES``.

        Returns:
            ``bool`` of the predicted label.
            NOTE(review): assumes the label column is numeric/boolean (e.g. 0/1); a
            string label would always convert to True — confirm against the data.
        """
        row = [getattr(mushroom, name) for name in self._FEATURE_ATTRIBUTES]
        # ``learn`` fits on a DataFrame, so sklearn records column names and warns when
        # predicting on a bare list. Reuse the fitted names when they exist.
        feature_names = getattr(self.model, "feature_names_in_", None)
        if feature_names is None:
            sample = [row]
        else:
            sample = pd.DataFrame([row], columns=feature_names)
        return bool(self.model.predict(sample)[0])