import json

from joblib import dump, load
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer


class DecisionTree:
    """Wrap a scikit-learn decision tree trained on JSON-lines samples.

    Each training line is a JSON object. Its ``'result'`` value is the class
    label; the keys listed in ``_EXCLUDED_KEYS`` are discarded; every other
    key/value pair becomes a feature via a ``DictVectorizer``.
    """

    # Per-sample keys removed before vectorising, so they never become features.
    # NOTE(review): presumably these are outcome fields that would leak the
    # label into the features — confirm against the data producer.
    _EXCLUDED_KEYS = ("food_result", "water_result", "wood_result")

    def __init__(self):
        self.clf = None  # fitted DecisionTreeClassifier; set by build()/load_model()
        self.vec = None  # fitted DictVectorizer; set by build()/load_model()

    def build(self, depth: int, path: str = "tree_data.json"):
        """Fit a new vectoriser and tree from a JSON-lines training file.

        Args:
            depth: ``max_depth`` passed to ``DecisionTreeClassifier``.
            path: Training file, one JSON object per line. Defaults to the
                previously hard-coded ``"tree_data.json"``.

        Raises:
            ValueError: if the file contains no samples.
            KeyError: if a sample has no ``'result'`` key.
        """
        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                # Blank lines (e.g. extra trailing newlines) are not valid
                # JSON; skip them instead of aborting the whole build.
                if not line.strip():
                    continue
                sample, result = self.process_input(line)
                samples.append(sample)
                results.append(result)
        if not samples:
            raise ValueError(f"no training samples found in {path!r}")

        vec = DictVectorizer()
        clf = tree.DecisionTreeClassifier(max_depth=depth)
        clf.fit(vec.fit_transform(samples).toarray(), results)
        # Publish the model only after fitting succeeded, so a failed build
        # does not leave an unfitted vectoriser/classifier on the instance.
        self.vec = vec
        self.clf = clf

    def save_model(self, clf_file, vec_file):
        """Persist the classifier and the vectoriser with ``joblib.dump``.

        Args:
            clf_file: Destination for the classifier.
            vec_file: Destination for the vectoriser.
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load_model(self, clf_file, vec_file):
        """Restore a classifier and vectoriser written by ``save_model``.

        Security: ``joblib.load`` unpickles its input and can execute
        arbitrary code — only load files from a trusted source.
        """
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def predict_answer(self, params):
        """Predict class label(s) for *params*.

        Requires ``build()`` or ``load_model()`` to have been called first.

        Args:
            params: Feature data in the same key/value form as the training
                samples (without the result keys), as accepted by
                ``DictVectorizer.transform``.

        Returns:
            The array returned by ``DecisionTreeClassifier.predict``.
        """
        return self.clf.predict(self.vec.transform(params).toarray())

    def plot_tree(self, output_path="decistion_tree.png"):
        """Render the fitted tree and save it as an image file.

        Args:
            output_path: Image destination. The default keeps the historical
                (misspelt) file name so existing consumers still find it.
        """
        print('Plotting tree...')
        fig = plt.figure(figsize=(36, 27))
        try:
            tree.plot_tree(self.clf, feature_names=self._feature_names(), filled=True)
            fig.savefig(output_path)
        finally:
            # pyplot keeps every figure alive until closed; release this large one.
            plt.close(fig)
        print('Success!')

    def _feature_names(self):
        """Return the fitted vectoriser's feature names as a list of str."""
        # get_feature_names() was removed in scikit-learn 1.2; prefer its
        # replacement and fall back only for releases that predate it.
        if hasattr(self.vec, "get_feature_names_out"):
            return [str(name) for name in self.vec.get_feature_names_out()]
        return self.vec.get_feature_names()

    @staticmethod
    def process_input(line):
        """Split one JSON-lines record into a feature dict and its label.

        Args:
            line: A single JSON object as text (surrounding whitespace allowed).

        Returns:
            ``(sample, result)``: ``result`` is the record's ``'result'``
            value; ``sample`` is the remaining dict minus the excluded keys.

        Raises:
            KeyError: if the record has no ``'result'`` key.
            json.JSONDecodeError: if *line* is not valid JSON.
        """
        data = json.loads(line.strip())
        result = data.pop('result')
        for key in DecisionTree._EXCLUDED_KEYS:
            # Tolerate records lacking an auxiliary field rather than crashing.
            data.pop(key, None)
        return data, result