49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import json
|
|
import os
|
|
|
|
from sklearn import tree
|
|
from sklearn.feature_extraction import DictVectorizer
|
|
|
|
from survival.components.resource_component import ResourceComponent
|
|
|
|
|
|
class DecisionTree:
    """Decision-tree classifier that predicts a resource label from resource features.

    Training data is a JSON-lines file: every non-blank line is a JSON object
    whose ``"resource"`` key holds the label and whose remaining keys are the
    features. Call :meth:`build` before :meth:`predict_answer`.
    """

    # Default training file. It is a relative path, so it is resolved against
    # the process's current working directory, not against this module.
    DATA_PATH = os.path.join("..", "data.txt")

    def __init__(self):
        # Both attributes are populated by build(); until then they are None
        # and predict_answer() cannot be used.
        self.clf = None
        self.vec = None

    def build(self, depth: int, path: str = DATA_PATH):
        """Train the classifier on the samples stored in *path*.

        :param depth: maximum depth of the decision tree (``max_depth``).
        :param path: JSON-lines training file; defaults to ``DATA_PATH``.
        :raises OSError: if *path* cannot be opened.
        :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
        :raises KeyError: if a record has no ``"resource"`` key.
        """
        samples, results = self._load_samples(path)

        # Build into locals and publish only after fitting succeeds, so a
        # failed rebuild does not leave the instance half-updated.
        vec = DictVectorizer()
        features = vec.fit_transform(samples).toarray()
        clf = tree.DecisionTreeClassifier(max_depth=depth).fit(features, results)

        self.vec = vec
        self.clf = clf

    def _load_samples(self, path):
        """Read *path* and return parallel lists ``(feature_dicts, labels)``."""
        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                # Blank lines (e.g. a trailing newline at EOF) are not JSON.
                if not line.strip():
                    continue
                sample, result = self.process_input(line)
                samples.append(sample)
                results.append(result)
        return samples, results

    def predict_answer(self, resource: "ResourceComponent"):
        """Predict the label for *resource*.

        Only ``weight``, ``eatable`` and ``toughness`` are read from
        *resource*. Returns the classifier's prediction array, which holds
        one entry because a single sample is passed. Requires a prior
        :meth:`build`.
        """
        params = {
            "weight": resource.weight,
            "eatable": resource.eatable,
            "toughness": resource.toughness,
        }
        return self.clf.predict(self.vec.transform(params).toarray())

    @staticmethod
    def process_input(line):
        """Parse one JSON-lines record into ``(features, label)``.

        The ``"resource"`` entry is removed from the decoded object and
        returned as the label. The remaining keys form the feature dict.

        :raises KeyError: if the record has no ``"resource"`` key.
        """
        sample = json.loads(line.strip())
        result = sample.pop("resource")
        return sample, result
|