implemented new disarming in auto mode
This commit is contained in:
parent
985c7c0d77
commit
70fb44f69e
Binary file not shown.
@ -1,17 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
from matplotlib import pyplot
|
||||
from joblib import dump, load
|
||||
from sklearn import tree
|
||||
from sklearn.feature_extraction import DictVectorizer
|
||||
|
||||
from objects.mines.disarming.mine_parameters import MineParameters
|
||||
from objects.mines.disarming.parameter_json import generate_data
|
||||
from disarming.parameters.mine_parameters import MineParameters
|
||||
from disarming.parameters.parameter_json import generate_data
|
||||
|
||||
|
||||
class DecisionTree:
|
||||
def __init__(self, clf_source: str = None, vec_source: str = None):
|
||||
if clf_source is not None and vec_source is not None:
|
||||
def __init__(self, load_from_file: bool = False):
|
||||
if load_from_file:
|
||||
clf_source = r"algorithms/learn/decision_tree/decision_tree.joblib"
|
||||
vec_source = r"algorithms/learn/decision_tree/dict_vectorizer.joblib"
|
||||
self.load(clf_source, vec_source)
|
||||
else:
|
||||
self.clf = None
|
||||
@ -48,8 +49,8 @@ class DecisionTree:
|
||||
# fig.savefig("decistion_tree.png")
|
||||
|
||||
def save(self):
|
||||
dump(self.clf, 'decision_tree.joblib')
|
||||
dump(self.vec, 'dict_vectorizer.joblib')
|
||||
dump(self.clf, r'algorithms/learn/decision_tree/decision_tree.joblib')
|
||||
dump(self.vec, r'algorithms/learn/decision_tree/dict_vectorizer.joblib')
|
||||
|
||||
def load(self, clf_file, vec_file):
|
||||
self.clf = load(clf_file)
|
||||
@ -84,7 +85,7 @@ class DecisionTree:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# generate_data("training_set.txt", 12000)
|
||||
generate_data("training_set.txt", 12000)
|
||||
decision_tree = DecisionTree()
|
||||
decision_tree.build("training_set.txt", 15)
|
||||
decision_tree.test()
|
||||
|
Binary file not shown.
@ -16,14 +16,16 @@ class NNType(Enum):
|
||||
|
||||
|
||||
class NeuralNetwork:
|
||||
def __init__(self, network_type: NNType, saved_model_path=None, classes_path=None, img_height=180, img_width=180):
|
||||
def __init__(self, network_type: NNType, load_from_file: bool = False, img_height=180, img_width=180):
|
||||
self.type = network_type
|
||||
self.training_data_dir = pathlib.Path(fr"../../../resources/data/neural_network/{self.type.value[0]}/train")
|
||||
|
||||
self.img_height = img_height
|
||||
self.img_width = img_width
|
||||
|
||||
if saved_model_path is not None and classes_path is not None:
|
||||
if load_from_file:
|
||||
saved_model_path = fr"algorithms/learn/neural_network/{self.type.value[0]}/saved_model.h5"
|
||||
classes_path = fr"algorithms/learn/neural_network/{self.type.value[0]}/saved_model_classes.joblib"
|
||||
self.load(saved_model_path, classes_path)
|
||||
else:
|
||||
self.model = None
|
||||
@ -113,10 +115,8 @@ class NeuralNetwork:
|
||||
predictions = self.model.predict(img_array)
|
||||
score = tf.nn.softmax(predictions[0])
|
||||
|
||||
print(
|
||||
"This image most likely belongs to {} with a {:.2f} percent confidence."
|
||||
.format(self.class_names[np.argmax(score)], 100 * np.max(score))
|
||||
)
|
||||
# returns class, condifdence
|
||||
return self.class_names[np.argmax(score)], 100 * np.max(score)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -128,7 +128,7 @@ if __name__ == "__main__":
|
||||
# neural_network.save()
|
||||
|
||||
# Loading a model from file:
|
||||
neural_network = NeuralNetwork(NNType.SPECIFICITY, r"specificity/saved_model.h5", r"specificity/saved_model_classes.joblib")
|
||||
neural_network = NeuralNetwork(NNType.SPECIFICITY, load_from_file=True)
|
||||
|
||||
# Test
|
||||
image = r"../../../resources/data/neural_network/specificity/disarm/tanks/1-35-British-Tank-FV-214-Conqueror-MK-II-Amusing-Hobby-35A027-AH-35A027_b_0.JPG"
|
||||
|
@ -3,9 +3,9 @@ import pygame
|
||||
import project_constants as const
|
||||
from assets import asset_constants as asset
|
||||
|
||||
from objects.mines.mine_models.standard_mine import StandardMine
|
||||
from objects.mines.mine_models.chained_mine import ChainedMine
|
||||
from objects.mines.mine_models.time_mine import TimeMine
|
||||
from objects.mine_models.standard_mine import StandardMine
|
||||
from objects.mine_models.chained_mine import ChainedMine
|
||||
from objects.mine_models.time_mine import TimeMine
|
||||
|
||||
|
||||
# ================================= #
|
||||
|
85
disarming/disarming_handler.py
Normal file
85
disarming/disarming_handler.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
from disarming.parameters.hash_function import SeriesHash, SpecificityHash
|
||||
from objects.mine_models.mine import Mine
|
||||
from algorithms.learn.decision_tree.decision_tree import DecisionTree
|
||||
from algorithms.learn.neural_network.neural_network import NNType, NeuralNetwork
|
||||
|
||||
SERIES_IMAGES_PATH = r"resources/data/neural_network/series/disarm"
|
||||
SPECIFICITY_IMAGES_PATH = r"resources/data/neural_network/specificity/disarm"
|
||||
|
||||
class_to_series = \
|
||||
{SeriesHash[name].value[2]: SeriesHash[name].value[1] for name, _ in SeriesHash.__members__.items()}
|
||||
|
||||
class_to_specificity = \
|
||||
{SpecificityHash[name].value[2]: SpecificityHash[name].value[1] for name, _ in SpecificityHash.__members__.items()}
|
||||
|
||||
|
||||
class DisarmingHandler:
|
||||
def __init__(self, mine: Mine):
|
||||
self.mine = mine
|
||||
self.mine_params = dict()
|
||||
|
||||
self.series_img = None
|
||||
self.specificity_img = None
|
||||
|
||||
self.recognized_series = None
|
||||
self.recognized_specificity = None
|
||||
|
||||
self.correct_wire = None
|
||||
self.chosen_wire = None
|
||||
|
||||
self._set_mine_params()
|
||||
self._set_correct_wire()
|
||||
|
||||
def _set_mine_params(self):
|
||||
self.mine_params = self.mine.investigate()
|
||||
|
||||
def _set_correct_wire(self):
|
||||
self.correct_wire = self.mine.wire
|
||||
|
||||
def get_mine_params(self):
|
||||
return [self.mine_params["mine_type"], self.mine_params["weight"], self.mine_params["danger_cls"],
|
||||
self.mine_params["indicator"], self.mine_params["series"], self.mine_params["specificity"]]
|
||||
|
||||
def pick_series_image(self):
|
||||
series_class = SeriesHash[self.mine_params["series"].upper().replace(" ", "_")].value[2]
|
||||
imgs_dir = os.path.join(SERIES_IMAGES_PATH, series_class)
|
||||
|
||||
self.series_img = os.path.join(
|
||||
imgs_dir, random.choice([x for x in os.listdir(imgs_dir) if os.path.isfile(os.path.join(imgs_dir, x))]))
|
||||
return self.series_img
|
||||
|
||||
def pick_specificity_image(self):
|
||||
specificity_class = SpecificityHash[self.mine_params["specificity"].upper().replace(" ", "_")].value[2]
|
||||
|
||||
imgs_dir = os.path.join(SPECIFICITY_IMAGES_PATH, specificity_class)
|
||||
|
||||
self.specificity_img = os.path.join(
|
||||
imgs_dir, random.choice([x for x in os.listdir(imgs_dir) if os.path.isfile(os.path.join(imgs_dir, x))]))
|
||||
|
||||
return self.specificity_img
|
||||
|
||||
def recognize_series(self):
|
||||
nn = NeuralNetwork(NNType.SERIES, load_from_file=True)
|
||||
answer, confidence = nn.get_answer(self.series_img)
|
||||
self.recognized_series = class_to_series[answer]
|
||||
|
||||
return self.recognized_series, self.mine_params["series"] == self.recognized_series
|
||||
|
||||
def recognize_specificity(self):
|
||||
nn = NeuralNetwork(NNType.SPECIFICITY, load_from_file=True)
|
||||
answer, confidence = nn.get_answer(self.specificity_img)
|
||||
self.recognized_specificity = class_to_specificity[answer]
|
||||
|
||||
return self.recognized_specificity, self.mine_params["specificity"] == self.recognized_specificity
|
||||
|
||||
def choose_wire(self):
|
||||
dt = DecisionTree(load_from_file=True)
|
||||
self.chosen_wire = dt.get_answer(self.mine_params)[0]
|
||||
|
||||
return self.chosen_wire
|
||||
|
||||
def defuse(self):
|
||||
return self.mine.disarm(self.chosen_wire)
|
@ -23,14 +23,10 @@ class DangerClassHash(Enum):
|
||||
|
||||
|
||||
class SeriesHash(Enum):
|
||||
TCH_2990TONER = 128, "TCH 2990toner"
|
||||
TCH_2990INKJET = 110, "TCH 2990inkjet"
|
||||
TVY_2400H = 100, "TVY 2400h"
|
||||
SWX_5000 = 80, "SWX 5000"
|
||||
SWX_4000 = 50, "SWX 4000"
|
||||
WORKHORSE_3200 = 30, "WORKHORSE 3200"
|
||||
FX_500 = 15, "FX 500"
|
||||
TVY_2400 = 0, "TVY 2400"
|
||||
TCH_2990TONER = 128, "TCH 2990toner", "T"
|
||||
SWX_5000 = 80, "SWX 5000", "S"
|
||||
WORKHORSE_3200 = 30, "WORKHORSE 3200", "W"
|
||||
FX_500 = 15, "FX 500", "F"
|
||||
|
||||
|
||||
class IndicatorHash(Enum):
|
||||
@ -42,13 +38,9 @@ class IndicatorHash(Enum):
|
||||
|
||||
|
||||
class SpecificityHash(Enum):
|
||||
ANTI_AIRCRAFT = 55, "anti aircraft"
|
||||
ANTI_PERSONNEL = 43, "anti personnel"
|
||||
DEPTH_MINE = 37, "depth mine"
|
||||
ANTI_TANK = 26, "anti tank"
|
||||
PROXIMITY_MINE = 18, "proximity mine"
|
||||
PRESSURE_MINE = 9, "pressure mine"
|
||||
FRAGMENTATION_MINE = 0, "fragmentation mine"
|
||||
ANTI_AIRCRAFT = 55, "anti aircraft", "planes"
|
||||
DEPTH_MINE = 37, "depth mine", "ships"
|
||||
ANTI_TANK = 26, "anti tank", "tanks"
|
||||
|
||||
|
||||
class WeightHash(Enum):
|
@ -1,5 +1,5 @@
|
||||
import random
|
||||
from objects.mines.disarming import hash_function as hf
|
||||
from disarming.parameters import hash_function as hf
|
||||
|
||||
|
||||
class MineParameters:
|
@ -1,10 +1,10 @@
|
||||
import json
|
||||
import objects.mines.disarming.mine_parameters as param
|
||||
import disarming.parameters.mine_parameters as param
|
||||
import os
|
||||
import project_constants as const
|
||||
|
||||
# this module is self contained, used to generate a json file
|
||||
DIR_DATA = os.path.join(const.ROOT_DIR, "resources", "data")
|
||||
DIR_DATA = os.path.join(const.ROOT_DIR, "resources", "data", "decision_tree")
|
||||
|
||||
|
||||
# just to show, how mine parameters works
|
@ -1,8 +1,10 @@
|
||||
import pygame
|
||||
import pygame_gui
|
||||
|
||||
from project_constants import SCREEN_WIDTH, SCREEN_HEIGHT, V_NAME_OF_WINDOW
|
||||
from assets.asset_constants import ASSET_CONCRETE
|
||||
|
||||
from objects.mine_models.mine import Mine
|
||||
from disarming.disarming_handler import DisarmingHandler
|
||||
|
||||
# =========== #
|
||||
# == const == #
|
||||
@ -77,7 +79,7 @@ class SampleWindow:
|
||||
# main attributes
|
||||
self.running = True
|
||||
self.clock = pygame.time.Clock()
|
||||
self.manager = pygame_gui.UIManager(screen_size, 'theme.json') # TODO : change theme path
|
||||
self.manager = pygame_gui.UIManager(screen_size, 'disarming/theme.json') # TODO : change theme path
|
||||
# main attributes
|
||||
|
||||
def gui():
|
||||
@ -501,49 +503,76 @@ class SampleWindow:
|
||||
# == run == #
|
||||
# ========= #
|
||||
|
||||
def run(self):
|
||||
def run(self, mine: Mine):
|
||||
timed_event = pygame.USEREVENT + 1
|
||||
pygame.time.set_timer(timed_event, 5000)
|
||||
|
||||
step = 0
|
||||
handler = DisarmingHandler(mine)
|
||||
|
||||
while self.running:
|
||||
|
||||
time_delta = self.clock.tick(60) / 1000.0
|
||||
|
||||
keystate = pygame.key.get_pressed()
|
||||
|
||||
# all events except QUIT are for testing
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self.running = False
|
||||
if keystate[pygame.K_a]:
|
||||
self.show_params("hello", "yello", "gel", "jello", "yum", "hello")
|
||||
if keystate[pygame.K_s]:
|
||||
self.show_series("workhorse 3200wx", True)
|
||||
if keystate[pygame.K_d]:
|
||||
self.show_spec("anti aircraft", False)
|
||||
if keystate[pygame.K_f]:
|
||||
self.show_cable_calculated("red")
|
||||
if keystate[pygame.K_g]:
|
||||
self.show_cable_chosen("blue")
|
||||
if keystate[pygame.K_q]:
|
||||
self.show_pic_series(ASSET_CONCRETE)
|
||||
if keystate[pygame.K_w]:
|
||||
self.show_pic_spec(ASSET_CONCRETE)
|
||||
if event.type == timed_event:
|
||||
if step == 0:
|
||||
params = handler.get_mine_params()
|
||||
self.show_params(params[0], params[1], params[2], params[3], params[4], params[5])
|
||||
elif step == 1:
|
||||
self.show_cable_calculated(handler.correct_wire)
|
||||
elif step == 2:
|
||||
|
||||
# TODO:
|
||||
img = pygame.transform.scale(
|
||||
pygame.image.load(handler.pick_series_image()),
|
||||
(1080, 1080)
|
||||
)
|
||||
|
||||
self.show_pic_series(img)
|
||||
elif step == 3:
|
||||
|
||||
# TODO:
|
||||
img = pygame.transform.scale(
|
||||
pygame.image.load(handler.pick_specificity_image()),
|
||||
(60, 60)
|
||||
)
|
||||
|
||||
self.show_pic_spec(img)
|
||||
elif step == 4:
|
||||
answer, is_correct = handler.recognize_series()
|
||||
self.show_series(answer, is_correct)
|
||||
elif step == 5:
|
||||
answer, is_correct = handler.recognize_specificity()
|
||||
self.show_spec(answer, is_correct)
|
||||
elif step == 6:
|
||||
self.show_cable_chosen(handler.choose_wire())
|
||||
else:
|
||||
self.running = False
|
||||
|
||||
step += 1
|
||||
|
||||
self.manager.update(time_delta)
|
||||
self.window_surface.blit(self.background, (0, 0))
|
||||
self.manager.draw_ui(self.window_surface)
|
||||
|
||||
pygame.display.update()
|
||||
return handler.defuse()
|
||||
|
||||
|
||||
def disarming_popup():
|
||||
def disarming_popup(mine: Mine = None):
|
||||
# run the pop-up
|
||||
app = SampleWindow()
|
||||
app.run()
|
||||
result = app.run(mine)
|
||||
|
||||
# bring display back to normal
|
||||
pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
|
||||
pygame.display.set_caption(V_NAME_OF_WINDOW)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pygame.init()
|
2
game.py
2
game.py
@ -7,7 +7,7 @@ from algorithms.search import a_star
|
||||
|
||||
from minefield import Minefield
|
||||
|
||||
from objects.mines.mine_models.time_mine import TimeMine
|
||||
from objects.mine_models.time_mine import TimeMine
|
||||
|
||||
from ui.ui_components_manager import UiComponentsManager
|
||||
from ui.text_box import TextBox
|
||||
|
@ -7,9 +7,9 @@ import project_constants as const
|
||||
from objects.tile import Tile
|
||||
|
||||
# import mine models
|
||||
from objects.mines.mine_models.standard_mine import StandardMine
|
||||
from objects.mines.mine_models.time_mine import TimeMine
|
||||
from objects.mines.mine_models.chained_mine import ChainedMine
|
||||
from objects.mine_models.standard_mine import StandardMine
|
||||
from objects.mine_models.time_mine import TimeMine
|
||||
from objects.mine_models.chained_mine import ChainedMine
|
||||
|
||||
|
||||
class JsonGenerator:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import project_constants as const
|
||||
from objects import tile as tl, agent as ag
|
||||
from objects.mines.mine_models.time_mine import TimeMine
|
||||
from objects.mine_models.time_mine import TimeMine
|
||||
import json_generator as jg
|
||||
|
||||
|
||||
|
@ -3,6 +3,8 @@ from assets import asset_constants as asset
|
||||
import json
|
||||
from time import sleep
|
||||
from pygame import transform
|
||||
|
||||
import disarming.popup as popup
|
||||
from algorithms.learn.decision_tree.decision_tree import DecisionTree
|
||||
|
||||
|
||||
@ -21,21 +23,16 @@ class Agent:
|
||||
self.row, self.column = int(self.row), int(self.column)
|
||||
self.position = [self.row, self.column]
|
||||
self.on_screen_coordinates = const.get_tile_coordinates(tuple(self.position))
|
||||
self.decision_tree = DecisionTree(const.ROOT_DIR + "/algorithms/learn/decision_tree/decision_tree.joblib",
|
||||
const.ROOT_DIR + "/algorithms/learn/decision_tree/dict_vectorizer.joblib")
|
||||
self.direction = const.Direction(data["agents_initial_state"]["direction"])
|
||||
self.rotation_angle = -const.Direction(self.direction).value * 90
|
||||
self.going_forward = False
|
||||
self.rotating_left = False
|
||||
self.rotating_right = False
|
||||
|
||||
def defuse_a_mine(self, mine):
|
||||
mine_params = mine.investigate()
|
||||
chosen_wire = self.decision_tree.get_answer(mine_params)
|
||||
# TODO temporarily printing chosen wire
|
||||
print("agent's chosen wire: " + str(chosen_wire[0]))
|
||||
sleep(3)
|
||||
return mine.disarm(chosen_wire)
|
||||
@staticmethod
|
||||
def defuse_a_mine(mine):
|
||||
is_success = popup.disarming_popup(mine)
|
||||
return is_success
|
||||
|
||||
def update_and_draw(self, window, delta_time, minefield):
|
||||
self.update(delta_time, minefield)
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .mine import Mine
|
||||
from objects.mines.disarming.hash_function import TypeHash
|
||||
from disarming.parameters.hash_function import TypeHash
|
||||
|
||||
|
||||
class ChainedMine(Mine):
|
@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
# type hints
|
||||
from typing import Tuple
|
||||
|
||||
from objects.mines.disarming.mine_parameters import MineParameters
|
||||
from disarming.parameters.mine_parameters import MineParameters
|
||||
|
||||
# Mine cannot be instantiated
|
||||
# all abstract methods must be implemented in derived classes
|
||||
@ -34,7 +34,4 @@ class Mine(ABC):
|
||||
del mine_parameters["wire"]
|
||||
self.wire = wire
|
||||
|
||||
# TODO temporarily printing parameters and right wire
|
||||
print("parameters:", mine_parameters, "\nright wire: " + wire)
|
||||
|
||||
return mine_parameters
|
@ -1,5 +1,5 @@
|
||||
from .mine import Mine
|
||||
from objects.mines.disarming.hash_function import TypeHash
|
||||
from disarming.parameters.hash_function import TypeHash
|
||||
|
||||
|
||||
class StandardMine(Mine):
|
@ -1,5 +1,5 @@
|
||||
from .mine import Mine
|
||||
from objects.mines.disarming.hash_function import TypeHash
|
||||
from disarming.parameters.hash_function import TypeHash
|
||||
|
||||
|
||||
class TimeMine(Mine):
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user