code refactor
This commit is contained in:
parent
363ea4b786
commit
cfe68b8e7a
@ -1,9 +1,8 @@
|
||||
import os
|
||||
import pygame as pg
|
||||
|
||||
from ap_mine import APMine
|
||||
from image import Image
|
||||
from tile import Tile
|
||||
|
||||
|
||||
main_path = os.path.dirname(os.getcwd())
|
||||
|
||||
|
14
src/main.py
14
src/main.py
@ -2,15 +2,15 @@ import pygame as pg
|
||||
|
||||
from agent import Agent
|
||||
from game_ui import GameUi
|
||||
from const import ICON, IMAGES
|
||||
from environment import Environment
|
||||
from tilesFactory import TilesFactory
|
||||
from const import ICON, IMAGES
|
||||
from src.search_algoritms.a_star import a_star
|
||||
from search_algoritms.BFS import breadth_first_search_for_a_star
|
||||
from search_algoritms.BFS import breadth_first_search
|
||||
from machine_learning.decision_tree import get_decision, tree_root
|
||||
|
||||
|
||||
def handle_keys(env, agent, game_ui, factory):
|
||||
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
|
||||
for event in pg.fastevent.get():
|
||||
if event.type == pg.KEYDOWN:
|
||||
if event.key == pg.K_d or event.key == pg.K_RIGHT:
|
||||
@ -63,10 +63,10 @@ def main():
|
||||
elif event.type == pg.KEYDOWN:
|
||||
if event.key == pg.K_t:
|
||||
print('Starting to clear the sector')
|
||||
#while env.mine_count:
|
||||
# while env.mine_count:
|
||||
print('-' * 20)
|
||||
# path, actions = breadth_first_search(env.field, agent.x, agent.y, agent.direction)
|
||||
goal = breadth_first_search_for_a_star(env.field, agent.x, agent.y, agent.direction)
|
||||
goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True)
|
||||
path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal)
|
||||
if not path and not env.field[agent.y][agent.x].mine:
|
||||
print('Unable to find path, rocks are in the way')
|
||||
@ -85,10 +85,10 @@ def main():
|
||||
'pressure_gt_two': True if goal_tile.mine.pressure > 2 else False
|
||||
}
|
||||
decision = get_decision(data, tree_root)
|
||||
print(data)
|
||||
print(f'Decision: {decision}')
|
||||
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE}))
|
||||
handle_keys(env, agent, game_ui, factory)
|
||||
print(data)
|
||||
print(f'Decision: {decision}')
|
||||
print('Sector clear')
|
||||
else:
|
||||
pg.fastevent.post(event)
|
||||
|
@ -4,4 +4,4 @@ from typing import Union
|
||||
class Mine:
|
||||
def __init__(self, pressure: Union[float, int], armed: bool):
|
||||
self.armed = armed
|
||||
self.pressure = pressure
|
||||
self.pressure = pressure
|
||||
|
@ -1,20 +1,44 @@
|
||||
from random import uniform,choice
|
||||
from typing import List
|
||||
from random import uniform, choice
|
||||
from importlib import import_module
|
||||
|
||||
from tile import Tile
|
||||
from const import IMAGES
|
||||
from machine_learning.data_set import visibility,stability,ground,armed
|
||||
from machine_learning.data_set import visibility, stability, ground, armed
|
||||
|
||||
|
||||
class TilesFactory:
|
||||
def create_tile(self, number: int) -> Tile:
|
||||
img = IMAGES[number]
|
||||
if img.mine_type:
|
||||
module = import_module(f'{img.mine_type.lower()}_mine')
|
||||
return Tile(number, 0, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground), getattr(module, f'{img.mine_type}Mine')(uniform(1, 5
|
||||
if img.mine_type == 'AP' or img.mine_type == 'AT' else 50), self.get_random_value(armed)))
|
||||
return Tile(
|
||||
number,
|
||||
0,
|
||||
self.get_random_value(visibility),
|
||||
self.get_random_value(stability),
|
||||
self.get_random_value(ground),
|
||||
getattr(module, f'{img.mine_type}Mine')(
|
||||
uniform(
|
||||
0.5,
|
||||
5 if img.mine_type == 'AP' else 50),
|
||||
self.get_random_value(armed)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return Tile(number, 5, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground)) if 'grass' in img.name else Tile(number, 30, self.get_random_value(visibility), self.get_random_value(stability), self.get_random_value(ground))
|
||||
return Tile(
|
||||
number,
|
||||
5,
|
||||
self.get_random_value(visibility),
|
||||
self.get_random_value(stability),
|
||||
self.get_random_value(ground)
|
||||
) if 'grass' in img.name else Tile(
|
||||
number,
|
||||
30,
|
||||
self.get_random_value(visibility),
|
||||
self.get_random_value(stability),
|
||||
self.get_random_value(ground)
|
||||
)
|
||||
|
||||
def get_tiles_list(self) -> List[Tile]:
|
||||
return [self.create_tile(i) for i in range(13)]
|
||||
|
Loading…
Reference in New Issue
Block a user