49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
|
import json
|
||
|
import os
|
||
|
|
||
|
from sklearn import tree
|
||
|
from sklearn.feature_extraction import DictVectorizer
|
||
|
|
||
|
from survival.components.resource_component import ResourceComponent
|
||
|
|
||
|
|
||
|
class DecisionTree:
    """Decision-tree classifier that predicts a resource label from its attributes.

    Training data is a JSON-lines file: one JSON object per line, where the
    ``"resource"`` value is the target label and every other key is a feature.
    Features are vectorized with a ``DictVectorizer`` before fitting a
    ``DecisionTreeClassifier``.
    """

    # Default training-data location. A relative path resolves against the
    # current working directory, not this module's directory.
    DEFAULT_DATA_PATH = os.path.join("..", "data.txt")

    def __init__(self):
        # Both are set by build(); None means the model has not been trained.
        self.clf = None
        self.vec = None

    def build(self, depth: int, path: str = DEFAULT_DATA_PATH):
        """Train the classifier on the samples stored in *path*.

        Args:
            depth: ``max_depth`` passed to ``DecisionTreeClassifier``.
            path: JSON-lines training file. Defaults to ``../data.txt``
                relative to the current working directory.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record has no ``"resource"`` key.
            ValueError: if the file contains no samples.
        """
        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as training_file:
            for line in training_file:
                # Blank lines (e.g. trailing empty lines) are not valid JSON.
                if not line.strip():
                    continue
                sample, result = self.process_input(line)
                samples.append(sample)
                results.append(result)

        if not samples:
            raise ValueError(f"no training samples found in {path!r}")

        # Fit into locals first so a failure leaves any previous model intact.
        vec = DictVectorizer()
        features = vec.fit_transform(samples).toarray()
        clf = tree.DecisionTreeClassifier(max_depth=depth).fit(features, results)
        self.vec, self.clf = vec, clf
        # Inspect the learned rules with:
        # tree.export_text(self.clf, feature_names=list(self.vec.get_feature_names_out()))

    def predict_answer(self, resource: "ResourceComponent"):
        """Predict the label for *resource* from its weight, eatable and toughness.

        Features not seen during training are ignored by the vectorizer.

        Returns:
            The array returned by ``DecisionTreeClassifier.predict`` for a
            single sample.

        Raises:
            RuntimeError: if called before build().
        """
        if self.clf is None or self.vec is None:
            raise RuntimeError(
                "DecisionTree.build() must be called before predict_answer()"
            )
        params = {
            "weight": resource.weight,
            "eatable": resource.eatable,
            "toughness": resource.toughness,
        }
        return self.clf.predict(self.vec.transform(params).toarray())

    @staticmethod
    def process_input(line):
        """Split one JSON-lines record into ``(features, label)``.

        The ``"resource"`` entry is removed and returned as the label; all
        remaining key/value pairs form the feature dict.

        Raises:
            json.JSONDecodeError: if *line* is not valid JSON.
            KeyError: if the record has no ``"resource"`` key.
        """
        data = json.loads(line.strip())
        result = data.pop("resource")
        return data, result
|
||
|
|