created first Decision Tree; minor changes to project structure

This commit is contained in:
s452645 2021-05-23 13:38:16 +02:00
parent f03899b535
commit c7845a0d1f
11 changed files with 83 additions and 25 deletions

Binary file not shown.

View File

@ -0,0 +1,80 @@
import os
import json
from joblib import dump, load
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
from mines.disarming.mine_parameters import MineParameters
class DecisionTree:
    """Decision-tree classifier that predicts the 'wire' value of a mine.

    Wraps a scikit-learn ``DecisionTreeClassifier`` together with the
    ``DictVectorizer`` that encodes the parameter dictionaries for it.
    Both objects are needed to make a prediction, so they are built,
    saved and loaded as a pair.
    """

    def __init__(self, clf_source: str = None, vec_source: str = None):
        """Create an empty tree, or load one when both sources are given.

        Args:
            clf_source: Path to a dumped classifier (see ``save``).
            vec_source: Path to the dumped vectorizer that goes with it.

        If either argument is missing, nothing is loaded and ``clf`` and
        ``vec`` stay ``None`` until ``build`` or ``load`` is called.
        """
        if clf_source is not None and vec_source is not None:
            self.load(clf_source, vec_source)
        else:
            self.clf = None
            self.vec = None

    def build(self, training_file: str, depth: int):
        """Train the tree on a file of JSON lines and print it as text.

        Args:
            training_file: File name looked up under
                ``../../resources/data`` relative to the current working
                directory. Each non-blank line must be a JSON object that
                contains a ``'wire'`` key.
            depth: Maximum depth passed to the classifier.
        """
        path = os.path.join("..", "..", "resources", "data", training_file)

        samples = []
        results = []

        # The handle gets its own name so the 'training_file' argument is not shadowed.
        with open(path, "r") as data_file:
            for line in data_file:
                # A blank line (e.g. a trailing newline) is not a sample;
                # json.loads would raise on it.
                if not line.strip():
                    continue
                sample, result = self._process_input_line(line)
                samples.append(sample)
                results.append(result)

        # vec transforms the samples (a list of dictionaries) into arrays for the tree to work on
        self.vec = DictVectorizer()

        # create and run the tree classifier upon the provided data
        self.clf = tree.DecisionTreeClassifier(max_depth=depth)
        self.clf = self.clf.fit(self.vec.fit_transform(samples).toarray(), results)

        # print the tree (not necessary)
        print(tree.export_text(self.clf, feature_names=self._feature_names(self.vec)))

    def save(self, clf_file: str = 'decision_tree.joblib', vec_file: str = 'dict_vectorizer.joblib'):
        """Dump the classifier and the vectorizer to two joblib files.

        Args:
            clf_file: Destination of the classifier dump.
            vec_file: Destination of the vectorizer dump.

        The defaults are the file names this method has always written.
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load(self, clf_file, vec_file):
        """Load a classifier and its vectorizer from joblib files.

        NOTE(security): joblib files are pickle-based, so loading one can
        execute arbitrary code. Only load files from a trusted source.
        """
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def get_answer(self, mine_params):
        """Return the classifier's prediction for ``mine_params``.

        The return value is whatever ``clf.predict`` returns (an array of
        predictions), not a bare label.
        """
        return self.clf.predict(self.vec.transform(mine_params).toarray())

    def test(self, samples: int = 1000):
        """Check the tree against freshly generated mines and print accuracy.

        Args:
            samples: Number of ``MineParameters`` instances to generate.

        Raises:
            ValueError: If ``samples`` is not positive.
        """
        mistakes = 0

        for _ in range(samples):
            mine_params = MineParameters().jsonifyable_dict()
            # The expected answer must not be visible to the classifier.
            correct = mine_params.pop('wire')

            answer = self.get_answer(mine_params)

            # get_answer returns an array of predictions; a single sample
            # was passed, so compare against its only element.
            if correct != answer[0]:
                print(f"Answer: {answer}\nCorrect: {correct}")
                mistakes += 1

        print(f"Accuracy: {self._accuracy_percent(mistakes, samples)}")

    @staticmethod
    def _accuracy_percent(mistakes, total):
        """Return the share of correct answers as a percentage (0-100)."""
        if total <= 0:
            raise ValueError("total must be a positive number of samples")
        return 100 * (total - mistakes) / total

    @staticmethod
    def _feature_names(vec):
        """Return the vectorizer's feature names as a list of strings.

        Prefers ``get_feature_names_out`` and falls back to the older
        ``get_feature_names`` when the former is not available.
        """
        getter = getattr(vec, "get_feature_names_out", None)
        if getter is None:
            getter = vec.get_feature_names
        # A plain list of str keeps export_text happy across versions.
        return [str(name) for name in getter()]

    @staticmethod
    def _process_input_line(line):
        """Split one JSON line into ``(sample, result)``.

        ``result`` is the value stored under ``'wire'``; ``sample`` is the
        remaining dictionary with that key removed.
        """
        data = json.loads(line.strip())
        result = data.pop('wire')
        return data, result
if __name__ == "__main__":
    # Script entry point: train a tree of depth at most 15 on params3.txt,
    # print its accuracy, then dump the fitted model to disk.
    model = DecisionTree()
    model.build("params3.txt", 15)
    model.test()
    model.save()

Binary file not shown.

View File

@ -1,22 +0,0 @@
from joblib import dump, load
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer
# X holds the sample dictionaries; Y holds the expected result of each sample
X = []
Y = []

# TODO: load training data

# DictVectorizer turns X (a list of dictionaries of string-string pairs) into binary arrays for the tree to work on
vec = DictVectorizer()

# create the tree classifier (depth capped at 3) and fit it on the vectorized samples
clf = tree.DecisionTreeClassifier(max_depth=3).fit(vec.fit_transform(X).toarray(), Y)

# save the fitted decision tree to a file
dump(clf, 'decision_tree.joblib')

# print the tree as text (not necessary)
print(tree.export_text(clf, feature_names=vec.get_feature_names()))

View File

@ -3,7 +3,7 @@ from random import randint
import project_constants as const import project_constants as const
from display_assets import blit_graphics from display_assets import blit_graphics
from algorithms.searching_algorithms import a_star from algorithms.search import a_star
from minefield import Minefield from minefield import Minefield

View File

@ -1,5 +1,5 @@
import random import random
import hash_function as hf from mines.disarming import hash_function as hf
class MineParameters: class MineParameters:

View File

@ -4,7 +4,7 @@ import os
# this module is self contained, used to generate a json file # this module is self contained, used to generate a json file
DIR_DATA = os.path.join("../resources", "data") DIR_DATA = os.path.join("../../resources", "data")
# just to show, how mine parameters works # just to show, how mine parameters works