61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
import json
|
|
|
|
from joblib import dump, load
|
|
from matplotlib import pyplot as plt
|
|
from sklearn import tree
|
|
from sklearn.feature_extraction import DictVectorizer
|
|
|
|
|
|
class DecisionTree:
    """Decision-tree classifier trained on JSON-lines samples.

    Wraps a scikit-learn ``DecisionTreeClassifier`` together with the
    ``DictVectorizer`` that turns each sample dict into a feature vector, and
    provides helpers to train, persist, reload, query and plot the model.
    """

    # Keys stripped from every record (in addition to the 'result' label) so
    # they are not fed to the vectorizer as features.
    _DROPPED_KEYS = ('food_result', 'water_result', 'wood_result')

    def __init__(self):
        # Both are populated by build() or load_model(); None until then.
        self.clf = None  # fitted tree.DecisionTreeClassifier
        self.vec = None  # fitted DictVectorizer

    def build(self, depth: int, path: str = "tree_data.json"):
        """Train the classifier from a JSON-lines file.

        Args:
            depth: ``max_depth`` passed to ``DecisionTreeClassifier``.
            path: File containing one JSON object per line. Defaults to
                ``"tree_data.json"`` relative to the working directory.

        Raises:
            OSError: if ``path`` cannot be opened.
            json.JSONDecodeError, KeyError: propagated from
                :meth:`process_input` for malformed records.
        """
        samples = []
        results = []

        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                # json.loads("") raises, so a blank line (e.g. a stray empty
                # line at the end of the file) is skipped rather than fatal.
                if not line.strip():
                    continue
                sample, result = self.process_input(line)
                samples.append(sample)
                results.append(result)

        self.vec = DictVectorizer()
        self.clf = tree.DecisionTreeClassifier(max_depth=depth)
        # The vectorizer output is densified with toarray(); predict_answer()
        # does the same so training and prediction inputs match.
        self.clf = self.clf.fit(self.vec.fit_transform(samples).toarray(), results)

    def save_model(self, clf_file, vec_file):
        """Persist the classifier and the vectorizer with ``joblib.dump``.

        Args:
            clf_file: Destination (filename or file object) for the classifier.
            vec_file: Destination (filename or file object) for the vectorizer.
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load_model(self, clf_file, vec_file):
        """Restore a classifier and vectorizer written by :meth:`save_model`.

        SECURITY: ``joblib.load`` is pickle-based and can execute arbitrary
        code; only load files from a trusted source.

        Args:
            clf_file: Source (filename or file object) of the classifier.
            vec_file: Source (filename or file object) of the vectorizer.
        """
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def predict_answer(self, params):
        """Predict labels for feature dict(s).

        Args:
            params: Whatever ``DictVectorizer.transform`` accepts —
                presumably a feature dict or a list of them (same keys as the
                training samples minus the stripped result fields).

        Returns:
            The array returned by ``DecisionTreeClassifier.predict``.
        """
        return self.clf.predict(self.vec.transform(params).toarray())

    def plot_tree(self, filename="decistion_tree.png", figsize=(36, 27)):
        """Render the fitted tree and save it as an image file.

        Args:
            filename: Output path. The default keeps the historical
                (misspelled) name so existing consumers still find the file.
            figsize: Matplotlib figure size in inches.
        """
        print('Plotting tree...')
        fig = plt.figure(figsize=figsize)
        try:
            tree.plot_tree(self.clf,
                           feature_names=self._feature_names(),
                           filled=True)
            fig.savefig(filename)
        finally:
            # pyplot keeps every created figure alive until it is closed.
            plt.close(fig)
        print('Success!')

    def _feature_names(self):
        """Return the vectorizer's feature names across scikit-learn versions."""
        # get_feature_names() was deprecated in scikit-learn 1.0 and removed
        # in 1.2; get_feature_names_out() replaces it. Fall back for old
        # installs or old pickled vectorizers.
        getter = getattr(self.vec, 'get_feature_names_out', None)
        if getter is None:
            getter = self.vec.get_feature_names
        return list(getter())

    @staticmethod
    def process_input(line):
        """Split one JSON-lines record into a ``(sample, result)`` pair.

        Args:
            line: A JSON object as text; surrounding whitespace is ignored.

        Returns:
            ``(sample, result)`` where ``result`` is the record's ``'result'``
            value and ``sample`` is the remaining dict with ``'result'``,
            ``'food_result'``, ``'water_result'`` and ``'wood_result'`` removed.

        Raises:
            json.JSONDecodeError: if ``line`` is not valid JSON.
            KeyError: if any of those four keys is missing.
        """
        data = json.loads(line.strip())
        result = data.pop('result')
        for key in DecisionTree._DROPPED_KEYS:
            del data[key]
        return data, result
|