81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
|
import os
|
||
|
import json
|
||
|
from joblib import dump, load
|
||
|
from sklearn import tree
|
||
|
from sklearn.feature_extraction import DictVectorizer
|
||
|
|
||
|
from mines.disarming.mine_parameters import MineParameters
|
||
|
|
||
|
|
||
|
class DecisionTree:
    """Decision-tree model that predicts the ``'wire'`` value of a mine.

    Wraps a scikit-learn ``DecisionTreeClassifier`` together with the
    ``DictVectorizer`` that converts parameter dictionaries into the numeric
    arrays the classifier is fitted on. Both objects are ``None`` until the
    model is either built with :meth:`build` or restored with :meth:`load`.
    """

    def __init__(self, clf_source: str = None, vec_source: str = None):
        """Create a model, optionally restoring it from saved joblib files.

        Args:
            clf_source: Path to a dumped classifier, or ``None``.
            vec_source: Path to a dumped vectorizer, or ``None``.

        The saved model is loaded only when BOTH paths are given; otherwise
        the instance starts untrained (``clf`` and ``vec`` are ``None``).
        """
        if clf_source is not None and vec_source is not None:
            self.load(clf_source, vec_source)
        else:
            self.clf = None
            self.vec = None

    def build(self, training_file: str, depth: int):
        """Fit the vectorizer and the tree on a file of JSON samples.

        Args:
            training_file: File name inside ``../../resources/data`` (the
                path is relative to the current working directory). Each
                non-blank line must be a JSON object containing a ``'wire'``
                key, which is used as the label; the remaining keys are the
                features.
            depth: Maximum depth passed to ``DecisionTreeClassifier``.

        Raises:
            ValueError: If the file contains no samples.
        """
        path = os.path.join("..", "..", "resources", "data", training_file)

        samples = []
        results = []

        # Use a distinct name for the handle so the `training_file`
        # parameter is not shadowed by the open file object.
        with open(path, "r") as source:
            for line in source:
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of failing inside json.loads.
                if not line.strip():
                    continue
                sample, result = self._process_input_line(line)
                samples.append(sample)
                results.append(result)

        if not samples:
            raise ValueError(f"No training samples found in {path}")

        # vec transforms the samples (a list of dictionaries) into numeric
        # arrays for the tree to work on
        self.vec = DictVectorizer()

        # create and run the tree classifier upon the provided data
        self.clf = tree.DecisionTreeClassifier(max_depth=depth)
        self.clf = self.clf.fit(self.vec.fit_transform(samples).toarray(), results)

        # print the tree (not necessary)
        print(tree.export_text(self.clf, feature_names=self._feature_names()))

    def _feature_names(self):
        """Return the fitted vectorizer's feature names as a plain list.

        ``get_feature_names`` was removed in scikit-learn 1.2 in favour of
        ``get_feature_names_out``; prefer the new API and fall back to the
        old one so both old and new installations work.
        """
        if hasattr(self.vec, "get_feature_names_out"):
            return list(self.vec.get_feature_names_out())
        return list(self.vec.get_feature_names())

    def save(self, clf_file: str = 'decision_tree.joblib',
             vec_file: str = 'dict_vectorizer.joblib'):
        """Dump the classifier and the vectorizer with joblib.

        Args:
            clf_file: Destination for the classifier. Defaults to the file
                name this method has always written.
            vec_file: Destination for the vectorizer. Same note as above.
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load(self, clf_file, vec_file):
        """Restore the classifier and the vectorizer from joblib files.

        Args:
            clf_file: Path to the dumped classifier.
            vec_file: Path to the dumped vectorizer.
        """
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only load files from a trusted source.
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def get_answer(self, mine_params):
        """Predict the wire for the given mine parameters.

        Args:
            mine_params: Feature dictionary (without the ``'wire'`` key), in
                the same form as the samples used for training.

        Returns:
            The result of ``clf.predict`` for the vectorized input, i.e. an
            array of predictions rather than a bare label.
        """
        return self.clf.predict(self.vec.transform(mine_params).toarray())

    def test(self, trials: int = 1000):
        """Evaluate the model on freshly generated mines and print accuracy.

        For each trial a new ``MineParameters`` dictionary is generated, its
        ``'wire'`` entry is removed and kept as the expected answer, and the
        prediction for the remaining parameters is compared against it.
        Every mismatch is printed, followed by the overall accuracy as a
        percentage.

        Args:
            trials: Number of generated mines to check (default 1000).

        Raises:
            ValueError: If ``trials`` is not positive.
        """
        if trials <= 0:
            raise ValueError("trials must be a positive integer")

        mistakes = 0
        for _ in range(trials):
            mine_params = MineParameters().jsonifyable_dict()
            correct = mine_params.pop('wire')

            answer = self.get_answer(mine_params)

            if correct != answer:
                print(f"Answer: {answer}\nCorrect: {correct}")
                mistakes += 1

        print(f"Accuracy: {self._accuracy_percent(mistakes, trials)}")

    @staticmethod
    def _accuracy_percent(mistakes, total):
        """Return the share of correct answers as a percentage (0-100)."""
        return 100.0 * (total - mistakes) / total

    @staticmethod
    def _process_input_line(line):
        """Split one JSON line of training data into features and label.

        Args:
            line: A JSON object serialized on a single line; it must contain
                a ``'wire'`` key.

        Returns:
            A ``(sample, result)`` tuple where ``result`` is the value stored
            under ``'wire'`` and ``sample`` is the parsed dictionary with
            that key removed.
        """
        data = json.loads(line.strip())
        result = data.pop('wire')

        return data, result
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: fit a depth-15 tree on params3.txt, print its
    # accuracy on generated mines, then dump the fitted model to disk.
    model = DecisionTree()
    model.build("params3.txt", 15)
    model.test()
    model.save()
|