AI-Project/survival/ai/decision_tree/decision_tree.py

61 lines
1.7 KiB
Python
Raw Permalink Normal View History

2021-06-19 18:04:59 +02:00
import json
from joblib import dump, load
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
class DecisionTree:
    """Decision-tree classifier trained on JSON-lines samples.

    Wraps a scikit-learn ``DecisionTreeClassifier`` together with the
    ``DictVectorizer`` that converts feature dicts into the numeric matrix the
    classifier consumes. Both start as ``None`` and are populated by
    :meth:`build` or :meth:`load_model`.
    """

    def __init__(self):
        self.clf = None  # fitted DecisionTreeClassifier, set by build()/load_model()
        self.vec = None  # fitted DictVectorizer, set by build()/load_model()

    def build(self, depth: int, path: str = "tree_data.json"):
        """Train the classifier from a JSON-lines file.

        Args:
            depth: ``max_depth`` passed to ``DecisionTreeClassifier``.
            path: File with one JSON object per line (see
                :meth:`process_input`). Relative paths resolve against the
                current working directory.

        Raises:
            ValueError: if the file contains no samples.
        """
        samples, results = self._read_training_data(path)
        if not samples:
            raise ValueError(f"no training samples found in {path!r}")
        self.vec = DictVectorizer()
        features = self.vec.fit_transform(samples).toarray()
        # fit() returns the estimator itself, so this stores the fitted model.
        self.clf = tree.DecisionTreeClassifier(max_depth=depth).fit(features, results)

    @classmethod
    def _read_training_data(cls, path):
        """Parse *path* into parallel lists of feature dicts and labels."""
        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                sample, result = cls.process_input(line)
                samples.append(sample)
                results.append(result)
        return samples, results

    def save_model(self, clf_file, vec_file):
        """Persist the classifier and vectorizer with ``joblib.dump``."""
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load_model(self, clf_file, vec_file):
        """Restore the classifier and vectorizer written by :meth:`save_model`.

        NOTE: ``joblib.load`` unpickles its input; only load trusted files.
        """
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def predict_answer(self, params):
        """Predict labels for *params*.

        Args:
            params: A feature dict or an iterable of feature dicts, as accepted
                by ``DictVectorizer.transform``.

        Returns:
            The array returned by ``DecisionTreeClassifier.predict``, one
            label per sample.
        """
        return self.clf.predict(self.vec.transform(params).toarray())

    def _feature_names(self):
        """Return the vectorizer's feature names across scikit-learn versions."""
        if hasattr(self.vec, "get_feature_names_out"):  # scikit-learn >= 1.0
            return list(self.vec.get_feature_names_out())
        return list(self.vec.get_feature_names())  # removed in scikit-learn 1.2

    def plot_tree(self, filename="decistion_tree.png"):
        """Render the fitted tree with matplotlib and save it as an image.

        Args:
            filename: Output image path. The default keeps the historical
                (misspelled) name so existing consumers still find the file.
        """
        print('Plotting tree...')
        fig = plt.figure(figsize=(36, 27))
        try:
            tree.plot_tree(self.clf,
                           feature_names=self._feature_names(),
                           filled=True)
            fig.savefig(filename)
        finally:
            # pyplot keeps figures alive until closed; this one is very large.
            plt.close(fig)
        print('Success!')

    @staticmethod
    def process_input(line):
        """Split one JSON line into a feature dict and its label.

        The ``result`` key becomes the label. ``food_result``,
        ``water_result`` and ``wood_result`` are discarded (when present) so
        they are not used as features. All remaining keys form the sample.

        Raises:
            json.JSONDecodeError: if *line* is not valid JSON.
            KeyError: if the object has no ``result`` key.
        """
        sample = json.loads(line.strip())
        result = sample.pop('result')
        for key in ('food_result', 'water_result', 'wood_result'):
            sample.pop(key, None)
        return sample, result