created first Decision Tree; minor changes to project structure
This commit is contained in:
parent
f03899b535
commit
c7845a0d1f
BIN
algorithms/learn/decision_tree.joblib
Normal file
BIN
algorithms/learn/decision_tree.joblib
Normal file
Binary file not shown.
80
algorithms/learn/decision_tree.py
Normal file
80
algorithms/learn/decision_tree.py
Normal file
@ -0,0 +1,80 @@
|
||||
import os
|
||||
import json
|
||||
from joblib import dump, load
|
||||
from sklearn import tree
|
||||
from sklearn.feature_extraction import DictVectorizer
|
||||
|
||||
from mines.disarming.mine_parameters import MineParameters
|
||||
|
||||
|
||||
class DecisionTree:
|
||||
def __init__(self, clf_source: str = None, vec_source: str = None):
|
||||
if clf_source is not None and vec_source is not None:
|
||||
self.load(clf_source, vec_source)
|
||||
else:
|
||||
self.clf = None
|
||||
self.vec = None
|
||||
|
||||
def build(self, training_file: str, depth: int):
|
||||
path = os.path.join("..", "..", "resources", "data", training_file)
|
||||
|
||||
samples = list()
|
||||
results = list()
|
||||
|
||||
with open(path, "r") as training_file:
|
||||
for sample in training_file:
|
||||
s, r = self._process_input_line(sample)
|
||||
samples.append(s)
|
||||
results.append(r)
|
||||
|
||||
# vec transforms X (a list of dictionaries of string-string pairs) to binary arrays for tree to work on
|
||||
self.vec = DictVectorizer()
|
||||
|
||||
# create and run Tree Clasifier upon provided data
|
||||
self.clf = tree.DecisionTreeClassifier(max_depth=depth)
|
||||
self.clf = self.clf.fit(self.vec.fit_transform(samples).toarray(), results)
|
||||
|
||||
# print a tree (not necessary)
|
||||
print(tree.export_text(self.clf, feature_names=self.vec.get_feature_names()))
|
||||
|
||||
def save(self):
|
||||
dump(self.clf, 'decision_tree.joblib')
|
||||
dump(self.vec, 'dict_vectorizer.joblib')
|
||||
|
||||
def load(self, clf_file, vec_file):
|
||||
self.clf = load(clf_file)
|
||||
self.vec = load(vec_file)
|
||||
|
||||
def get_answer(self, mine_params):
|
||||
return self.clf.predict(self.vec.transform(mine_params).toarray())
|
||||
|
||||
def test(self):
|
||||
mistakes = 0
|
||||
for _ in range(1000):
|
||||
mine_params = MineParameters().jsonifyable_dict()
|
||||
correct = mine_params['wire']
|
||||
del mine_params['wire']
|
||||
|
||||
answer = self.get_answer(mine_params)
|
||||
|
||||
if correct != answer:
|
||||
print(f"Answer: {answer}\nCorrect: {correct}")
|
||||
mistakes += 1
|
||||
|
||||
print(f"Accuracy: {100 - (mistakes / 1000)}")
|
||||
|
||||
@staticmethod
|
||||
def _process_input_line(line):
|
||||
data = json.loads(line.strip())
|
||||
result = data['wire']
|
||||
del data['wire']
|
||||
sample = data
|
||||
|
||||
return sample, result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
decision_tree = DecisionTree()
|
||||
decision_tree.build("params3.txt", 15)
|
||||
decision_tree.test()
|
||||
decision_tree.save()
|
BIN
algorithms/learn/dict_vectorizer.joblib
Normal file
BIN
algorithms/learn/dict_vectorizer.joblib
Normal file
Binary file not shown.
@ -1,22 +0,0 @@
|
||||
from joblib import dump, load
|
||||
from sklearn import tree
|
||||
from sklearn.feature_extraction import DictVectorizer
|
||||
|
||||
# X is a list of dictionaries with samples, Y is a list of samples' results
|
||||
X = list()
|
||||
Y = list()
|
||||
|
||||
# TODO: load training data
|
||||
|
||||
# vec transforms X (a list of dictionaries of string-string pairs) to binary arrays for tree to work on
|
||||
vec = DictVectorizer()
|
||||
|
||||
# create and run Tree Clasifier upon provided data
|
||||
clf = tree.DecisionTreeClassifier(max_depth=3)
|
||||
clf = clf.fit(vec.fit_transform(X).toarray(), Y)
|
||||
|
||||
# save decision tree to file
|
||||
dump(clf, 'decision_tree.joblib')
|
||||
|
||||
# print a tree (not necessary)
|
||||
print(tree.export_text(clf, feature_names=vec.get_feature_names()))
|
2
game.py
2
game.py
@ -3,7 +3,7 @@ from random import randint
|
||||
import project_constants as const
|
||||
|
||||
from display_assets import blit_graphics
|
||||
from algorithms.searching_algorithms import a_star
|
||||
from algorithms.search import a_star
|
||||
|
||||
from minefield import Minefield
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import random
|
||||
import hash_function as hf
|
||||
from mines.disarming import hash_function as hf
|
||||
|
||||
|
||||
class MineParameters:
|
@ -4,7 +4,7 @@ import os
|
||||
|
||||
|
||||
# this module is self contained, used to generate a json file
|
||||
DIR_DATA = os.path.join("../resources", "data")
|
||||
DIR_DATA = os.path.join("../../resources", "data")
|
||||
|
||||
|
||||
# just to show, how mine parameters works
|
Loading…
Reference in New Issue
Block a user