code refactor

This commit is contained in:
matixezor 2021-05-18 00:01:19 +02:00
parent 363ea4b786
commit cfe68b8e7a
4 changed files with 38 additions and 15 deletions

View File

@ -1,9 +1,8 @@
import os import os
import pygame as pg import pygame as pg
from ap_mine import APMine
from image import Image from image import Image
from tile import Tile
main_path = os.path.dirname(os.getcwd()) main_path = os.path.dirname(os.getcwd())

View File

@ -2,15 +2,15 @@ import pygame as pg
from agent import Agent from agent import Agent
from game_ui import GameUi from game_ui import GameUi
from const import ICON, IMAGES
from environment import Environment from environment import Environment
from tilesFactory import TilesFactory from tilesFactory import TilesFactory
from const import ICON, IMAGES
from src.search_algoritms.a_star import a_star from src.search_algoritms.a_star import a_star
from search_algoritms.BFS import breadth_first_search_for_a_star from search_algoritms.BFS import breadth_first_search
from machine_learning.decision_tree import get_decision, tree_root from machine_learning.decision_tree import get_decision, tree_root
def handle_keys(env, agent, game_ui, factory): def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
for event in pg.fastevent.get(): for event in pg.fastevent.get():
if event.type == pg.KEYDOWN: if event.type == pg.KEYDOWN:
if event.key == pg.K_d or event.key == pg.K_RIGHT: if event.key == pg.K_d or event.key == pg.K_RIGHT:
@ -63,10 +63,10 @@ def main():
elif event.type == pg.KEYDOWN: elif event.type == pg.KEYDOWN:
if event.key == pg.K_t: if event.key == pg.K_t:
print('Starting to clear the sector') print('Starting to clear the sector')
#while env.mine_count: # while env.mine_count:
print('-' * 20) print('-' * 20)
# path, actions = breadth_first_search(env.field, agent.x, agent.y, agent.direction) # path, actions = breadth_first_search(env.field, agent.x, agent.y, agent.direction)
goal = breadth_first_search_for_a_star(env.field, agent.x, agent.y, agent.direction) goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True)
path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal) path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal)
if not path and not env.field[agent.y][agent.x].mine: if not path and not env.field[agent.y][agent.x].mine:
print('Unable to find path, rocks are in the way') print('Unable to find path, rocks are in the way')
@ -85,10 +85,10 @@ def main():
'pressure_gt_two': True if goal_tile.mine.pressure > 2 else False 'pressure_gt_two': True if goal_tile.mine.pressure > 2 else False
} }
decision = get_decision(data, tree_root) decision = get_decision(data, tree_root)
print(data)
print(f'Decision: {decision}')
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE})) pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE}))
handle_keys(env, agent, game_ui, factory) handle_keys(env, agent, game_ui, factory)
print(data)
print(f'Decision: {decision}')
print('Sector clear') print('Sector clear')
else: else:
pg.fastevent.post(event) pg.fastevent.post(event)

View File

@ -4,4 +4,4 @@ from typing import Union
class Mine: class Mine:
def __init__(self, pressure: Union[float, int], armed: bool): def __init__(self, pressure: Union[float, int], armed: bool):
self.armed = armed self.armed = armed
self.pressure = pressure self.pressure = pressure

View File

@ -1,20 +1,44 @@
from random import uniform,choice
from typing import List from typing import List
from random import uniform, choice
from importlib import import_module from importlib import import_module
from tile import Tile from tile import Tile
from const import IMAGES from const import IMAGES
from machine_learning.data_set import visibility,stability,ground,armed from machine_learning.data_set import visibility, stability, ground, armed
class TilesFactory: class TilesFactory:
def create_tile(self, number: int) -> Tile: def create_tile(self, number: int) -> Tile:
img = IMAGES[number] img = IMAGES[number]
if img.mine_type: if img.mine_type:
module = import_module(f'{img.mine_type.lower()}_mine') module = import_module(f'{img.mine_type.lower()}_mine')
return Tile(number, 0, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground), getattr(module, f'{img.mine_type}Mine')(uniform(1, 5 return Tile(
if img.mine_type == 'AP' or img.mine_type == 'AT' else 50), self.get_random_value(armed))) number,
0,
self.get_random_value(visibility),
self.get_random_value(stability),
self.get_random_value(ground),
getattr(module, f'{img.mine_type}Mine')(
uniform(
0.5,
5 if img.mine_type == 'AP' else 50),
self.get_random_value(armed)
)
)
else: else:
return Tile(number, 5, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground)) if 'grass' in img.name else Tile(number, 30, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground)) return Tile(
number,
5,
self.get_random_value(visibility),
self.get_random_value(stability),
self.get_random_value(ground)
) if 'grass' in img.name else Tile(
number,
30,
self.get_random_value(visibility),
self.get_random_value(stability),
self.get_random_value(ground)
)
def get_tiles_list(self) -> List[Tile]: def get_tiles_list(self) -> List[Tile]:
return [self.create_tile(i) for i in range(13)] return [self.create_tile(i) for i in range(13)]