[hotfix/neural_network] added neural network recognize result viewer in game

This commit is contained in:
czorekk 2022-05-27 10:06:24 +02:00
parent 9d3e744218
commit 3182e09e9d
2 changed files with 29 additions and 14 deletions

View File

@ -1,18 +1,14 @@
import pygame as pg import pygame as pg
from enum import Enum
class trash(pg.sprite.Sprite): from random import randrange
from map.tile import Tile
def __init__(self,x,y,img, type): class Trash(Tile):
super().__init__() def __init__(self, img, x, y, width, height):
super().__init__(img, x, y, width, height)
self.width=16
self.height=16
self.type = type
self.x = x self.x = x
self.y = y self.y = y
self.image = pg.image.load(img) def get_coords(self):
self.image = pg.transform.scale(self.image, (self.x,self.y)) return (self.x, self.y)
self.rect = self.image.get_rect()

19
main.py
View File

@ -16,6 +16,7 @@ from path_search_algorthms import bfs
from path_search_algorthms import a_star, a_star_utils from path_search_algorthms import a_star, a_star_utils
from decision_tree import decisionTree from decision_tree import decisionTree
from NeuralNetwork import prediction from NeuralNetwork import prediction
from game_objects.trash import Trash
from game_objects import aiPlayer from game_objects import aiPlayer
import itertools import itertools
@ -40,6 +41,7 @@ class Game():
def __init__(self): def __init__(self):
pg.init() pg.init()
pg.font.init()
self.clock = pg.time.Clock() self.clock = pg.time.Clock()
self.dt = self.clock.tick(FPS) / 333.0 self.dt = self.clock.tick(FPS) / 333.0
self.screen = pg.display.set_mode((WIDTH, HEIGHT)) self.screen = pg.display.set_mode((WIDTH, HEIGHT))
@ -64,8 +66,11 @@ class Game():
def init_game(self): def init_game(self):
# initialize all variables and do all the setup for a new game # initialize all variables and do all the setup for a new game
self.text_display = ''
# sprite groups and map array for calculations # sprite groups and map array for calculations
(self.roadTiles, self.wallTiles, self.trashbinTiles), self.mapArray = map.get_tiles() (self.roadTiles, self.wallTiles, self.trashbinTiles), self.mapArray = map.get_tiles()
self.trashDisplay = pg.sprite.Group()
self.agentSprites = pg.sprite.Group() self.agentSprites = pg.sprite.Group()
# player obj # player obj
self.player = Player(self, 32, 32) self.player = Player(self, 32, 32)
@ -136,7 +141,16 @@ class Game():
random = randint(0, 48) random = randint(0, 48)
file = files[random] file = files[random]
result = prediction.getPrediction(dir + '/' +file, 'trained_nn_20.pth') result = prediction.getPrediction(dir + '/' +file, 'trained_nn_20.pth')
img = pg.image.load(dir + '/' +file).convert_alpha()
img = pg.transform.scale(img, (128, 128))
trash = Trash(img, 0, 0, 128, 128)
self.trashDisplay.add(trash)
self.text_display = result
self.draw()
print(result + ' ' + file) print(result + ' ' + file)
pg.time.wait(1000)
self.text_display = ''
self.draw()
# print(self.positive_actions[0]) # print(self.positive_actions[0])
@ -178,6 +192,11 @@ class Game():
map.render_tiles(self.roadTiles, self.screen, self.camera) map.render_tiles(self.roadTiles, self.screen, self.camera)
map.render_tiles(self.wallTiles, self.screen, self.camera, self.debug_mode) map.render_tiles(self.wallTiles, self.screen, self.camera, self.debug_mode)
map.render_tiles(self.trashbinTiles, self.screen, self.camera) map.render_tiles(self.trashbinTiles, self.screen, self.camera)
map.render_tiles(self.trashDisplay, self.screen, self.camera)
# draw text
text_surface = pg.font.SysFont('Comic Sans MS', 30).render(self.text_display, False, (0,0,0))
self.screen.blit(text_surface, (0,128))
# rerender additional sprites # rerender additional sprites
for sprite in self.agentSprites: for sprite in self.agentSprites: