From 2d33ddba5c2c3d0d687682b35bf7cd4b1c550000 Mon Sep 17 00:00:00 2001 From: matixezor Date: Tue, 1 Jun 2021 22:50:51 +0200 Subject: [PATCH] add neural network usage --- src/main.py | 34 ++++++++++++++------------------- src/search_algoritms/BFS.py | 5 ++++- src/search_algoritms/helpers.py | 20 ++++++++++++++----- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/main.py b/src/main.py index 434ad60..99d0ce4 100644 --- a/src/main.py +++ b/src/main.py @@ -63,12 +63,6 @@ def main(): running = True - # ruchy agenta: - # 0 - up - # 1 - right - # 2 - down - # 3 - left - while running: for event in pg.fastevent.get(): if event.type == pg.QUIT: @@ -76,23 +70,23 @@ def main(): elif event.type == pg.KEYDOWN: if event.key == pg.K_t: print('Starting to clear the sector') - while env.mine_count: - print('-' * 20) - # path, actions = breadth_first_search(env.field, agent.x, agent.y, agent.direction) - goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True) - path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal) + # while env.mine_count: + print('-' * 20) - if not path and not env.field[agent.y][agent.x].mine: - print('Unable to find path, rocks are in the way') - break - print(f'Path{path}') + goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True) + path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal) - print(f'Actions:{actions}') - for action in actions: - pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': action})) + if not path and not env.field[agent.y][agent.x].mine: + print('Unable to find path, rocks are in the way') + break + print(f'Path{path}') - handle_keys(env, agent, game_ui, factory) - pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE})) + print(f'Actions:{actions}') + for action in actions: + pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': action})) + + handle_keys(env, agent, game_ui, factory) + pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE})) print('Sector clear') else: diff --git a/src/search_algoritms/BFS.py b/src/search_algoritms/BFS.py index fed2c0a..e387bce 100644 --- a/src/search_algoritms/BFS.py +++ b/src/search_algoritms/BFS.py @@ -2,8 +2,10 @@ import queue from typing import List from src.tile import Tile +from const import model, IMAGES from search_algoritms.node import Node from .helpers import successor, get_path_actions +from machine_learning.neural_network.learning import prediction def breadth_first_search( @@ -20,7 +22,8 @@ def breadth_first_search( while not node_queue.empty(): node = node_queue.get() - if field[node.y][node.x].mine: + img_path = IMAGES[field[node.y][node.x].number].path + if field[node.y][node.x].occupied and prediction(img_path, model) == 'mine': return get_path_actions(node) if not a_star else node.x, node.y explored.append(node) diff --git a/src/search_algoritms/helpers.py b/src/search_algoritms/helpers.py index d3ca987..117c2d1 100644 --- a/src/search_algoritms/helpers.py +++ b/src/search_algoritms/helpers.py @@ -3,8 +3,9 @@ from typing import List, Tuple import pygame as pg from src.tile import Tile -from src.const import ROCK_INDEXES +from src.const import IMAGES, model from src.search_algoritms.node import Node +from machine_learning.neural_network.learning import prediction def get_path_actions(node: Node): @@ -31,10 +32,19 @@ def successor(field: List[List[Tile]], x: int, y: int, direction: int): neighbours.append((x, y, (direction - 1) % 4, pg.K_a)) neighbours.append((x, y, (direction + 1) % 4, pg.K_d)) - if coord == 'x' and 0 <= x + shift <= 9 and field[y][x + shift].number not in ROCK_INDEXES: - neighbours.append((x + shift, y, direction, pg.K_w)) - elif coord == 'y' and 0 <= y + shift <= 9 and field[y + shift][x].number not in ROCK_INDEXES: - neighbours.append((x, y + shift, direction, pg.K_w)) + if coord == 'x' and 0 <= x + shift <= 9: + img_path = IMAGES[field[y][x + shift].number].path + impassable = True if field[y][x + shift].occupied and prediction(img_path, model) == 'rock' else False + + if not impassable: + neighbours.append((x + shift, y, direction, pg.K_w)) + + elif coord == 'y' and 0 <= y + shift <= 9: + img_path = IMAGES[field[y + shift][x].number].path + impassable = True if field[y + shift][x].occupied and prediction(img_path, model) == 'rock' else False + + if not impassable: + neighbours.append((x, y + shift, direction, pg.K_w)) return neighbours