SIP/core/decision_tree.py

33 lines
1.0 KiB
Python
Raw Permalink Normal View History

2021-06-10 21:23:58 +02:00
import os
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
class DecisionTree:
    """Gini decision-tree classifier over eight mushroom attributes.

    The model is trained from a CSV whose first eight columns are the
    features and whose ninth column is the label; ``predict`` returns the
    predicted label coerced to ``bool``.
    """

    # Mushroom attributes fed to the model, in model column order.
    # NOTE(review): assumes the first eight columns of data_tree.csv follow
    # this same order — confirm against the data file.
    _FEATURE_ATTRIBUTES = (
        "isRed",
        "hasRing",
        "hasEmptyCore",
        "hasDots",
        "hasLamella",
        "isShiftingColor",
        "hasUnpleasantSmell",
        "hasSlime",
    )

    def __init__(self, precision):
        """Create an untrained classifier.

        Args:
            precision: Fraction of the dataset, in (0, 1], used for training
                when ``learn`` is called; the remaining rows are discarded.
        """
        self.model = DecisionTreeClassifier(criterion='gini')
        # Attribute name kept as-is for callers that may read it.
        self.learningDataSize = precision

    @staticmethod
    def _default_data_path():
        """Return the absolute path of the bundled training CSV.

        Resolved relative to this module (``<module dir>/resources/data/
        data_tree.csv``) rather than the current working directory, so
        training does not depend on where the program was launched from.
        """
        return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'resources', 'data', 'data_tree.csv')

    @classmethod
    def _feature_vector(cls, mushroom):
        """Return *mushroom*'s feature values in model column order."""
        return [getattr(mushroom, name) for name in cls._FEATURE_ATTRIBUTES]

    def learn(self, data_path=None):
        """Fit the tree on a random subset of the training CSV.

        A fresh random subset is drawn on every call (no fixed random state),
        so repeated training can yield different trees.

        Args:
            data_path: Optional path to the training CSV. Defaults to the
                bundled ``resources/data/data_tree.csv`` next to this module.

        Raises:
            ValueError: If ``precision`` is not in the interval (0, 1].
        """
        fraction = self.learningDataSize
        if not 0 < fraction <= 1:
            raise ValueError(f"precision must be in (0, 1], got {fraction!r}")
        if data_path is None:
            data_path = self._default_data_path()

        dataset = pd.read_csv(data_path)
        n_features = len(self._FEATURE_ATTRIBUTES)
        input_x = dataset.iloc[:, :n_features]
        input_y = dataset.iloc[:, n_features]

        if fraction < 1:
            # Only the training part is kept; the held-out rows are dropped,
            # so `precision` limits how much data the tree sees. train_size
            # is passed directly to avoid float error from `1 - fraction`.
            x_train, _, y_train, _ = train_test_split(
                input_x, input_y, train_size=fraction)
        else:
            # sklearn rejects train_size=1.0 / test_size=0.0, so training on
            # the whole dataset has to bypass the split.
            x_train, y_train = input_x, input_y

        # Fit on plain arrays: predict() passes plain lists, and a model
        # fitted on a DataFrame warns about missing feature names on every
        # prediction.
        self.model.fit(x_train.to_numpy(), y_train.to_numpy())

    def predict(self, mushroom):
        """Classify a single mushroom.

        Args:
            mushroom: Object exposing every attribute named in
                ``_FEATURE_ATTRIBUTES``.

        Returns:
            The predicted label for this one sample, coerced to ``bool``.
            NOTE(review): what ``True`` means depends on the CSV's label
            column — confirm against data_tree.csv.
        """
        return bool(self.model.predict([self._feature_vector(mushroom)])[0])