86 lines
3.2 KiB
Python
86 lines
3.2 KiB
Python
import os
|
|
import random
|
|
|
|
from disarming.parameters.hash_function import SeriesHash, SpecificityHash
|
|
from objects.mine_models.mine import Mine
|
|
from algorithms.learn.decision_tree.decision_tree import DecisionTree
|
|
from algorithms.learn.neural_network.neural_network import NNType, NeuralNetwork
|
|
|
|
# Root directories of the images fed to the recognition networks; each one is
# expected to contain one sub-directory per class (named after value[2] of the
# corresponding hash member, see DisarmingHandler.pick_*_image).
SERIES_IMAGES_PATH = r"resources/data/neural_network/series/disarm"
SPECIFICITY_IMAGES_PATH = r"resources/data/neural_network/specificity/disarm"

# Lookup tables mapping value[2] of every hash member to its value[1].
# NOTE(review): value[2] appears to be the network class / directory name and
# value[1] the label stored in the mine parameters -- confirm against
# disarming.parameters.hash_function.
# Iterating __members__.values() visits the same members (aliases included) as
# looking each name up again, without the redundant per-name lookup.
class_to_series = {
    member.value[2]: member.value[1]
    for member in SeriesHash.__members__.values()
}

class_to_specificity = {
    member.value[2]: member.value[1]
    for member in SpecificityHash.__members__.values()
}
|
|
|
|
|
|
class DisarmingHandler:
    """Walk a single mine through the disarming procedure.

    On construction the handler reads the mine's parameters and its correct
    wire. The remaining methods are separate steps that store their result on
    the instance: picking a series / specificity image, recognising them with
    the neural networks, choosing a wire with the decision tree, and finally
    disarming the mine with the chosen wire.
    """

    # Order in which get_mine_params() reports the mine parameters.
    _PARAM_ORDER = ("mine_type", "weight", "danger_cls", "indicator", "series", "specificity")

    def __init__(self, mine: Mine):
        """Bind the handler to ``mine`` and read its parameters and correct wire."""
        self.mine = mine
        self.mine_params = dict()

        # Paths of the images selected by pick_series_image / pick_specificity_image.
        self.series_img = None
        self.specificity_img = None

        # Labels produced by recognize_series / recognize_specificity.
        self.recognized_series = None
        self.recognized_specificity = None

        # Wire stored on the mine vs. the wire picked by choose_wire.
        self.correct_wire = None
        self.chosen_wire = None

        self._set_mine_params()
        self._set_correct_wire()

    def _set_mine_params(self):
        """Store the result of ``mine.investigate()`` as the mine parameters."""
        self.mine_params = self.mine.investigate()

    def _set_correct_wire(self):
        """Store the mine's ``wire`` attribute as the correct wire."""
        self.correct_wire = self.mine.wire

    @staticmethod
    def _hash_member_name(label):
        """Convert a parameter label into the upper-case, underscore-joined member name."""
        return label.upper().replace(" ", "_")

    @staticmethod
    def _random_image_from(imgs_dir):
        """Return the path of a randomly chosen regular file inside ``imgs_dir``.

        Raises:
            IndexError: if ``imgs_dir`` contains no regular files (same
                exception type ``random.choice`` raises for an empty sequence,
                but with the offending directory in the message).
        """
        candidates = [
            entry for entry in os.listdir(imgs_dir)
            if os.path.isfile(os.path.join(imgs_dir, entry))
        ]
        if not candidates:
            raise IndexError("no image files found in {!r}".format(imgs_dir))
        return os.path.join(imgs_dir, random.choice(candidates))

    def get_mine_params(self):
        """Return the mine parameters as a list.

        The order is: mine_type, weight, danger_cls, indicator, series,
        specificity.
        """
        return [self.mine_params[key] for key in self._PARAM_ORDER]

    def pick_series_image(self):
        """Pick a random image for the mine's series, store it and return its path."""
        member_name = self._hash_member_name(self.mine_params["series"])
        series_class = SeriesHash[member_name].value[2]
        imgs_dir = os.path.join(SERIES_IMAGES_PATH, series_class)

        self.series_img = self._random_image_from(imgs_dir)
        return self.series_img

    def pick_specificity_image(self):
        """Pick a random image for the mine's specificity, store it and return its path."""
        member_name = self._hash_member_name(self.mine_params["specificity"])
        specificity_class = SpecificityHash[member_name].value[2]
        imgs_dir = os.path.join(SPECIFICITY_IMAGES_PATH, specificity_class)

        self.specificity_img = self._random_image_from(imgs_dir)
        return self.specificity_img

    def recognize_series(self):
        """Recognise the series from the image stored in ``series_img``.

        Returns:
            A ``(recognized_series, is_correct)`` pair, where ``is_correct``
            tells whether the recognised label equals the mine's "series"
            parameter.
        """
        nn = NeuralNetwork(NNType.SERIES, load_from_file=True)
        # The second element of the answer (confidence) is not used here.
        answer, _confidence = nn.get_answer(self.series_img)
        self.recognized_series = class_to_series[answer]

        return self.recognized_series, self.mine_params["series"] == self.recognized_series

    def recognize_specificity(self):
        """Recognise the specificity from the image stored in ``specificity_img``.

        Returns:
            A ``(recognized_specificity, is_correct)`` pair, where
            ``is_correct`` tells whether the recognised label equals the
            mine's "specificity" parameter.
        """
        nn = NeuralNetwork(NNType.SPECIFICITY, load_from_file=True)
        # The second element of the answer (confidence) is not used here.
        answer, _confidence = nn.get_answer(self.specificity_img)
        self.recognized_specificity = class_to_specificity[answer]

        return self.recognized_specificity, self.mine_params["specificity"] == self.recognized_specificity

    def choose_wire(self):
        """Ask the decision tree for a wire, store it as ``chosen_wire`` and return it."""
        dt = DecisionTree(load_from_file=True)
        # Only the first element of the decision tree's answer is kept.
        self.chosen_wire = dt.get_answer(self.mine_params)[0]

        return self.chosen_wire

    def defuse(self):
        """Disarm the mine with ``chosen_wire`` and return what ``mine.disarm`` returns."""
        return self.mine.disarm(self.chosen_wire)
|