neural_network #4

Merged
s452622 merged 34 commits from neural_network into master 2021-06-08 23:47:22 +02:00
3 changed files with 33 additions and 26 deletions
Showing only changes of commit 2d33ddba5c - Show all commits

View File

@ -63,12 +63,6 @@ def main():
running = True running = True
# ruchy agenta:
# 0 - up
# 1 - right
# 2 - down
# 3 - left
while running: while running:
for event in pg.fastevent.get(): for event in pg.fastevent.get():
if event.type == pg.QUIT: if event.type == pg.QUIT:
@ -76,23 +70,23 @@ def main():
elif event.type == pg.KEYDOWN: elif event.type == pg.KEYDOWN:
if event.key == pg.K_t: if event.key == pg.K_t:
print('Starting to clear the sector') print('Starting to clear the sector')
while env.mine_count: # while env.mine_count:
print('-' * 20) print('-' * 20)
# path, actions = breadth_first_search(env.field, agent.x, agent.y, agent.direction)
goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True)
path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal)
if not path and not env.field[agent.y][agent.x].mine: goal = breadth_first_search(env.field, agent.x, agent.y, agent.direction, True)
print('Unable to find path, rocks are in the way') path, actions = a_star(env.field, agent.x, agent.y, agent.direction, goal)
break
print(f'Path{path}')
print(f'Actions:{actions}') if not path and not env.field[agent.y][agent.x].mine:
for action in actions: print('Unable to find path, rocks are in the way')
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': action})) break
print(f'Path{path}')
handle_keys(env, agent, game_ui, factory) print(f'Actions:{actions}')
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE})) for action in actions:
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': action}))
handle_keys(env, agent, game_ui, factory)
pg.fastevent.post(pg.event.Event(pg.KEYDOWN, {'key': pg.K_SPACE}))
print('Sector clear') print('Sector clear')
else: else:

View File

@ -2,8 +2,10 @@ import queue
from typing import List from typing import List
from src.tile import Tile from src.tile import Tile
from const import model, IMAGES
from search_algoritms.node import Node from search_algoritms.node import Node
from .helpers import successor, get_path_actions from .helpers import successor, get_path_actions
from machine_learning.neural_network.learning import prediction
def breadth_first_search( def breadth_first_search(
@ -20,7 +22,8 @@ def breadth_first_search(
while not node_queue.empty(): while not node_queue.empty():
node = node_queue.get() node = node_queue.get()
if field[node.y][node.x].mine: img_path = IMAGES[field[node.y][node.x].number].path
if field[node.y][node.x].occupied and prediction(img_path, model) == 'mine':
return get_path_actions(node) if not a_star else node.x, node.y return get_path_actions(node) if not a_star else node.x, node.y
explored.append(node) explored.append(node)

View File

@ -3,8 +3,9 @@ from typing import List, Tuple
import pygame as pg import pygame as pg
from src.tile import Tile from src.tile import Tile
from src.const import ROCK_INDEXES from src.const import IMAGES, model
from src.search_algoritms.node import Node from src.search_algoritms.node import Node
from machine_learning.neural_network.learning import prediction
def get_path_actions(node: Node): def get_path_actions(node: Node):
@ -31,10 +32,19 @@ def successor(field: List[List[Tile]], x: int, y: int, direction: int):
neighbours.append((x, y, (direction - 1) % 4, pg.K_a)) neighbours.append((x, y, (direction - 1) % 4, pg.K_a))
neighbours.append((x, y, (direction + 1) % 4, pg.K_d)) neighbours.append((x, y, (direction + 1) % 4, pg.K_d))
if coord == 'x' and 0 <= x + shift <= 9 and field[y][x + shift].number not in ROCK_INDEXES: if coord == 'x' and 0 <= x + shift <= 9:
neighbours.append((x + shift, y, direction, pg.K_w)) img_path = IMAGES[field[y][x + shift].number].path
elif coord == 'y' and 0 <= y + shift <= 9 and field[y + shift][x].number not in ROCK_INDEXES: impassable = True if field[y][x + shift].occupied and prediction(img_path, model) == 'rock' else False
neighbours.append((x, y + shift, direction, pg.K_w))
if not impassable:
neighbours.append((x + shift, y, direction, pg.K_w))
elif coord == 'y' and 0 <= y + shift <= 9:
img_path = IMAGES[field[y + shift][x].number].path
impassable = True if field[y + shift][x].occupied and prediction(img_path, model) == 'rock' else False
if not impassable:
neighbours.append((x, y + shift, direction, pg.K_w))
return neighbours return neighbours