AI-Project/survival/decision_tree.py
2021-05-24 17:20:08 +02:00

49 lines
1.4 KiB
Python

import json
import os
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
from survival.components.resource_component import ResourceComponent
class DecisionTree:
    """Decision-tree classifier that predicts a resource label from its attributes.

    Training data is a JSON-lines file: each non-blank line is a JSON object
    whose ``"resource"`` key holds the label and whose remaining keys are the
    features. Features are encoded with a ``DictVectorizer`` and fed to a
    scikit-learn ``DecisionTreeClassifier``.
    """

    # Default training file. Note: a relative path, so it is resolved against
    # the current working directory, not against this module's location.
    DEFAULT_DATA_PATH = os.path.join("..", "data.txt")

    def __init__(self):
        # Both are set by build(); predict_answer() needs them to be fitted.
        self.clf = None
        self.vec = None

    def build(self, depth: int, path: str = DEFAULT_DATA_PATH):
        """Train the classifier from a JSON-lines file.

        Args:
            depth: ``max_depth`` passed to ``DecisionTreeClassifier``.
            path: Training file to read; defaults to ``DEFAULT_DATA_PATH``.
        """
        samples = []
        results = []
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                # Skip blank lines (e.g. an extra trailing newline):
                # json.loads("") would raise JSONDecodeError.
                if not line.strip():
                    continue
                sample, result = self.process_input(line)
                samples.append(sample)
                results.append(result)

        vec = DictVectorizer()
        features = vec.fit_transform(samples).toarray()
        clf = tree.DecisionTreeClassifier(max_depth=depth)
        clf.fit(features, results)
        # Publish the pair only after fitting succeeds, so a failed rebuild
        # never leaves a fitted classifier paired with an unfitted vectorizer.
        self.vec, self.clf = vec, clf

    def predict_answer(self, resource: ResourceComponent):
        """Predict the resource label for *resource*.

        Must be called after build(). The features sent to the model are the
        resource's ``weight``, ``eatable`` and ``toughness`` attributes.

        Returns:
            The result of ``clf.predict`` for this single sample (one entry).
        """
        params = {
            "weight": resource.weight,
            "eatable": resource.eatable,
            "toughness": resource.toughness,
        }
        return self.clf.predict(self.vec.transform(params).toarray())

    @staticmethod
    def process_input(line):
        """Split one JSON-lines record into ``(features, label)``.

        The ``"resource"`` key is removed from the parsed object and returned
        as the label; every other key is kept as a feature.
        """
        sample = json.loads(line.strip())
        result = sample.pop("resource")
        return sample, result