Projekt_Sztuczna_Inteligencja/disarming/disarming_handler.py

86 lines
3.2 KiB
Python
Raw Normal View History

2021-06-06 22:00:42 +02:00
import os
import random
from disarming.parameters.hash_function import SeriesHash, SpecificityHash
from objects.mine_models.mine import Mine
from algorithms.learn.decision_tree.decision_tree import DecisionTree
from algorithms.learn.neural_network.neural_network import NNType, NeuralNetwork
# Root directories of the sample images used for neural-network recognition.
# Each one holds a sub-directory per class label (the label is value[2] of a
# SeriesHash / SpecificityHash member).
SERIES_IMAGES_PATH = r"resources/data/neural_network/series/disarm"
SPECIFICITY_IMAGES_PATH = r"resources/data/neural_network/specificity/disarm"

# Reverse lookup tables: class label returned by the neural network
# (member.value[2]) -> parameter value compared with the mine's own
# parameters (member.value[1]).
class_to_series = {
    member.value[2]: member.value[1]
    for member in SeriesHash.__members__.values()
}
class_to_specificity = {
    member.value[2]: member.value[1]
    for member in SpecificityHash.__members__.values()
}
class DisarmingHandler:
    """Run the disarming procedure for a single mine.

    The handler reads the mine's parameters and the wire stored on the mine.
    It picks sample images for the mine's series and specificity, runs the
    neural networks on those images, and asks the decision tree which wire to
    choose. Finally it tries to disarm the mine with that wire.
    """

    # Order in which get_mine_params() returns the mine parameters.
    _PARAM_KEYS = ("mine_type", "weight", "danger_cls", "indicator", "series", "specificity")

    def __init__(self, mine: Mine):
        self.mine = mine
        self.mine_params = dict()
        # Paths of the images picked by pick_series_image / pick_specificity_image.
        self.series_img = None
        self.specificity_img = None
        # Values produced by recognize_series / recognize_specificity.
        self.recognized_series = None
        self.recognized_specificity = None
        # Wire stored on the mine vs. wire selected by choose_wire().
        self.correct_wire = None
        self.chosen_wire = None

        self._set_mine_params()
        self._set_correct_wire()

    def _set_mine_params(self):
        """Store the parameter dict returned by ``mine.investigate()``."""
        self.mine_params = self.mine.investigate()

    def _set_correct_wire(self):
        """Remember the wire attribute of the mine."""
        self.correct_wire = self.mine.wire

    def get_mine_params(self):
        """Return the mine parameters as a list.

        Order: mine_type, weight, danger_cls, indicator, series, specificity.
        """
        return [self.mine_params[key] for key in self._PARAM_KEYS]

    @staticmethod
    def _enum_key(value):
        """Turn a parameter value such as ``"some name"`` into the member name ``"SOME_NAME"``."""
        return value.upper().replace(" ", "_")

    @staticmethod
    def _pick_random_image(root_dir, image_class):
        """Return the path of a randomly chosen regular file in ``root_dir/image_class``.

        Raises:
            FileNotFoundError: if the class directory does not exist (raised by
                ``os.listdir``) or contains no regular files.
        """
        imgs_dir = os.path.join(root_dir, image_class)
        # os.listdir returns entries in arbitrary order; sorting keeps the pick
        # reproducible across machines when random.seed() is used.
        candidates = sorted(
            entry for entry in os.listdir(imgs_dir)
            if os.path.isfile(os.path.join(imgs_dir, entry))
        )
        if not candidates:
            # random.choice([]) would raise an opaque IndexError instead.
            raise FileNotFoundError(f"No image files found in {imgs_dir!r}")
        return os.path.join(imgs_dir, random.choice(candidates))

    def pick_series_image(self):
        """Pick a random sample image for the mine's series.

        Returns:
            Path of the chosen image, also stored in ``self.series_img``.
        """
        series_class = SeriesHash[self._enum_key(self.mine_params["series"])].value[2]
        self.series_img = self._pick_random_image(SERIES_IMAGES_PATH, series_class)
        return self.series_img

    def pick_specificity_image(self):
        """Pick a random sample image for the mine's specificity.

        Returns:
            Path of the chosen image, also stored in ``self.specificity_img``.
        """
        specificity_class = SpecificityHash[self._enum_key(self.mine_params["specificity"])].value[2]
        self.specificity_img = self._pick_random_image(SPECIFICITY_IMAGES_PATH, specificity_class)
        return self.specificity_img

    def recognize_series(self):
        """Classify ``self.series_img`` with the series neural network.

        NOTE: call pick_series_image() first; otherwise ``None`` is passed to
        the network.

        Returns:
            Tuple ``(recognized_series, matches)``. ``matches`` tells whether
            the recognized value equals the mine's ``"series"`` parameter.
        """
        nn = NeuralNetwork(NNType.SERIES, load_from_file=True)
        answer, _confidence = nn.get_answer(self.series_img)
        self.recognized_series = class_to_series[answer]
        return self.recognized_series, self.mine_params["series"] == self.recognized_series

    def recognize_specificity(self):
        """Classify ``self.specificity_img`` with the specificity neural network.

        NOTE: call pick_specificity_image() first; otherwise ``None`` is passed
        to the network.

        Returns:
            Tuple ``(recognized_specificity, matches)``. ``matches`` tells
            whether the recognized value equals the mine's ``"specificity"``
            parameter.
        """
        nn = NeuralNetwork(NNType.SPECIFICITY, load_from_file=True)
        answer, _confidence = nn.get_answer(self.specificity_img)
        self.recognized_specificity = class_to_specificity[answer]
        return self.recognized_specificity, self.mine_params["specificity"] == self.recognized_specificity

    def choose_wire(self):
        """Ask the decision tree (loaded from file) for a wire, given the mine parameters.

        Returns:
            The first element of the tree's answer, also stored in
            ``self.chosen_wire`` (presumably the predicted wire; confirm).
        """
        dt = DecisionTree(load_from_file=True)
        self.chosen_wire = dt.get_answer(self.mine_params)[0]
        return self.chosen_wire

    def defuse(self):
        """Try to disarm the mine with ``self.chosen_wire`` and return ``mine.disarm``'s result.

        NOTE: ``chosen_wire`` is ``None`` unless choose_wire() was called first.
        """
        return self.mine.disarm(self.chosen_wire)