Add reinforcement learning
This commit is contained in:
parent
869dcbc124
commit
deea62212c
@ -1,5 +1,3 @@
|
||||
import random
|
||||
|
||||
import pygame
|
||||
|
||||
from settings import SCREEN_WIDTH, SCREEN_HEIGHT
|
||||
@ -12,6 +10,31 @@ from survival.generators.resource_generator import ResourceGenerator
|
||||
from survival.generators.world_generator import WorldGenerator
|
||||
from survival.systems.draw_system import DrawSystem
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
self.world_generator = WorldGenerator(win, self.reset)
|
||||
self.game_map, self.world, self.camera = self.world_generator.create_world()
|
||||
self.run = True
|
||||
|
||||
def reset(self):
|
||||
self.world_generator.reset_world()
|
||||
|
||||
def update(self, ms):
|
||||
events = pygame.event.get()
|
||||
|
||||
for event in events:
|
||||
if event.type == pygame.QUIT:
|
||||
self.run = False
|
||||
if pygame.key.get_pressed()[pygame.K_DELETE]:
|
||||
self.reset()
|
||||
|
||||
win.fill((0, 0, 0))
|
||||
self.game_map.draw(self.camera)
|
||||
self.world.process(ms)
|
||||
pygame.display.update()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pygame.init()
|
||||
|
||||
@ -21,32 +44,7 @@ if __name__ == '__main__':
|
||||
pygame.display.set_caption("AI Project")
|
||||
|
||||
clock = pygame.time.Clock()
|
||||
game = Game()
|
||||
|
||||
game_map = GameMap(int(SCREEN_WIDTH / 32) * 2, 2 * int(SCREEN_HEIGHT / 32) + 1)
|
||||
camera = Camera(game_map.width * 32, game_map.height * 32, win)
|
||||
|
||||
world = WorldGenerator().create_world(camera, game_map)
|
||||
player = PlayerGenerator().create_player(world, game_map)
|
||||
world.get_processor(DrawSystem).initialize_interface(world.component_for_entity(player, InventoryComponent))
|
||||
building = BuildingGenerator().create_home(world, game_map)
|
||||
|
||||
ResourceGenerator(world, game_map).generate_resources(player)
|
||||
|
||||
run = True
|
||||
|
||||
while run:
|
||||
# Set the framerate
|
||||
ms = clock.tick(60)
|
||||
|
||||
events = pygame.event.get()
|
||||
|
||||
for event in events:
|
||||
if event.type == pygame.QUIT:
|
||||
run = False
|
||||
|
||||
keys = pygame.key.get_pressed()
|
||||
|
||||
win.fill((0, 0, 0))
|
||||
game_map.draw(camera)
|
||||
world.process(ms)
|
||||
pygame.display.update()
|
||||
while game.run:
|
||||
game.update(clock.tick(60))
|
||||
|
5
survival/components/consumption_component.py
Normal file
5
survival/components/consumption_component.py
Normal file
@ -0,0 +1,5 @@
|
||||
class ConsumptionComponent:
|
||||
def __init__(self, inventory_state=0):
|
||||
self.timer_value: float = 2000
|
||||
self.timer: float = self.timer_value
|
||||
self.last_inventory_state = inventory_state
|
@ -19,3 +19,12 @@ class InventoryComponent:
|
||||
|
||||
def has_item(self, item):
|
||||
return item in self.items and self.items[item] != 0
|
||||
|
||||
def total_items_count(self):
|
||||
total = 0
|
||||
for item, value in self.items.items():
|
||||
total += value
|
||||
return total
|
||||
|
||||
def clear(self):
|
||||
self.items = {}
|
||||
|
32
survival/components/learning_component.py
Normal file
32
survival/components/learning_component.py
Normal file
@ -0,0 +1,32 @@
|
||||
from survival.components.time_component import TimeComponent
|
||||
|
||||
|
||||
class LearningComponent:
|
||||
def __init__(self):
|
||||
self.made_step = False
|
||||
self.old_state = None
|
||||
self.action = None
|
||||
self.resource = None
|
||||
|
||||
self.reward = 0
|
||||
self.done = False
|
||||
self.score = 0
|
||||
self.record = 0
|
||||
|
||||
def load_step(self, old_state, action, resource):
|
||||
self.old_state = old_state
|
||||
self.action = action
|
||||
if resource is None:
|
||||
self.resource = None
|
||||
else:
|
||||
self.resource = resource
|
||||
self.made_step = True
|
||||
|
||||
def reset(self):
|
||||
self.made_step = False
|
||||
self.old_state = None
|
||||
self.action = None
|
||||
self.resource = None
|
||||
|
||||
self.reward = 0
|
||||
self.done = False
|
@ -1,6 +1,5 @@
|
||||
class PathfindingComponent:
|
||||
def __init__(self, target_pos, searching_for_resource=False):
|
||||
def __init__(self, target_pos):
|
||||
self.target_grid_pos = (int(target_pos[0] / 32), int(target_pos[1] / 32))
|
||||
self.searching_for_resource = False
|
||||
self.current_target = None
|
||||
self.path = None
|
||||
|
@ -1,5 +1,5 @@
|
||||
class TimeComponent:
|
||||
def __init__(self, minute, hour, day, timer):
|
||||
def __init__(self, minute=0, hour=0, day=0, timer=0):
|
||||
self.minute = minute
|
||||
self.hour = hour
|
||||
self.day = day
|
||||
@ -16,5 +16,17 @@ class TimeComponent:
|
||||
self.hour = temp2
|
||||
self.minute = temp
|
||||
|
||||
def total_minutes(self):
|
||||
return self.minute + self.hour * 60 + self.day * 1440
|
||||
|
||||
def __str__(self):
|
||||
return f'Day {self.day}, {self.hour}:{self.minute}'
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.total_minutes() == other.total_minutes()
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.total_minutes() > other.total_minutes()
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.total_minutes() < other.total_minutes()
|
||||
|
34
survival/components/vision_component.py
Normal file
34
survival/components/vision_component.py
Normal file
@ -0,0 +1,34 @@
|
||||
from pygame import Surface
|
||||
|
||||
from survival.settings import AGENT_VISION_RANGE, SCREEN_WIDTH, SCREEN_HEIGHT
|
||||
|
||||
|
||||
class VisionComponent:
|
||||
def __init__(self):
|
||||
self.agent_vision = AGENT_VISION_RANGE * 32 * 2
|
||||
self.width = SCREEN_WIDTH * 2
|
||||
self.height = SCREEN_HEIGHT * 2
|
||||
self.surface_l = Surface(((self.width - self.agent_vision) / 2, self.height))
|
||||
self.surface_r = Surface(((self.width - self.agent_vision) / 2, self.height))
|
||||
self.surface_t = Surface((self.agent_vision, (self.height - self.agent_vision) / 2))
|
||||
self.surface_b = Surface((self.agent_vision, (self.height - self.agent_vision) / 2))
|
||||
self.surface_l.fill((0, 0, 0))
|
||||
self.surface_l.set_alpha(200)
|
||||
self.surface_r.fill((0, 0, 0))
|
||||
self.surface_r.set_alpha(200)
|
||||
self.surface_t.fill((0, 0, 0))
|
||||
self.surface_t.set_alpha(200)
|
||||
self.surface_b.fill((0, 0, 0))
|
||||
self.surface_b.set_alpha(200)
|
||||
self.l_pos = (0, 0)
|
||||
self.r_pos = (0, 0)
|
||||
self.t_pos = (0, 0)
|
||||
self.b_pos = (0, 0)
|
||||
|
||||
def update_positions(self, position: [int, int]):
|
||||
new_position = (position[0] - self.width / 2 + 16, position[1] - self.height / 2 + 16)
|
||||
self.l_pos = new_position
|
||||
self.r_pos = (new_position[0] + (self.width + self.agent_vision) / 2, new_position[1])
|
||||
self.t_pos = (new_position[0] + (self.width - self.agent_vision) / 2, new_position[1])
|
||||
self.b_pos = (new_position[0] + (self.width - self.agent_vision) / 2,
|
||||
new_position[1] + (self.height + self.agent_vision) / 2)
|
@ -18,7 +18,7 @@ class EntityLayer:
|
||||
def remove_entity(self, pos):
|
||||
self.tiles[pos[1]][pos[0]] = None
|
||||
|
||||
def get_entity(self, pos) -> int:
|
||||
def get_entity(self, pos):
|
||||
return self.tiles[pos[1]][pos[0]]
|
||||
|
||||
def is_colliding(self, pos):
|
||||
|
@ -28,6 +28,9 @@ class Processor:
|
||||
def process(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class World:
|
||||
"""A World object keeps track of all Entities, Components, and Processors.
|
||||
@ -46,6 +49,14 @@ class World:
|
||||
self.process_times = {}
|
||||
self._process = self._timed_process
|
||||
|
||||
@property
|
||||
def processors(self):
|
||||
return self._processors
|
||||
|
||||
@property
|
||||
def entities(self):
|
||||
return self._entities
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self.get_component.cache_clear()
|
||||
self.get_components.cache_clear()
|
||||
|
@ -1,4 +1,9 @@
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.resource_component import ResourceComponent
|
||||
from survival.entity_layer import EntityLayer
|
||||
from survival.esper import World
|
||||
from survival.graph_search import graph_search
|
||||
from survival.settings import AGENT_VISION_RANGE
|
||||
from survival.tile_layer import TileLayer
|
||||
|
||||
|
||||
@ -23,10 +28,55 @@ class GameMap:
|
||||
self.entity_layer.remove_entity(pos)
|
||||
|
||||
def get_entity(self, pos) -> int:
|
||||
if not self.in_bounds(pos):
|
||||
return None
|
||||
return self.entity_layer.get_entity(pos)
|
||||
|
||||
def is_colliding(self, pos):
|
||||
return pos[0] < 0 or pos[0] >= self.width or pos[1] < 0 or pos[1] >= self.height or self.entity_layer.is_colliding(pos)
|
||||
return not self.in_bounds(pos) or self.entity_layer.is_colliding(pos)
|
||||
|
||||
def in_bounds(self, pos):
|
||||
return 0 <= pos[0] < self.width and 0 <= pos[1] < self.height
|
||||
|
||||
def get_cost(self, pos):
|
||||
return self.tile_layer.get_cost(pos)
|
||||
|
||||
def find_nearby_resources(self, world: World, player: int, position: PositionComponent, search_range: int = 5):
|
||||
entity_position = position.grid_position
|
||||
|
||||
x_range = [entity_position[0] - search_range, entity_position[0] + search_range]
|
||||
y_range = [entity_position[1] - search_range, entity_position[1] + search_range]
|
||||
|
||||
# Check if range is not out of map bounds
|
||||
if x_range[0] < 0:
|
||||
x_range[0] = 0
|
||||
if x_range[1] >= self.width:
|
||||
x_range[1] = self.width - 1
|
||||
if y_range[0] < 0:
|
||||
y_range[0] = 0
|
||||
if y_range[1] >= self.height:
|
||||
y_range[1] = self.height - 1
|
||||
|
||||
found_resources = []
|
||||
|
||||
for y in range(y_range[0], y_range[1]):
|
||||
for x in range(x_range[0], x_range[1]):
|
||||
ent = self.get_entity([x, y])
|
||||
if ent == player:
|
||||
continue
|
||||
if ent is not None and world.has_component(ent, ResourceComponent):
|
||||
res_position = world.component_for_entity(ent, PositionComponent).grid_position
|
||||
path, cost = graph_search(self, position, tuple(res_position), world)
|
||||
found_resources.append([ent, path, cost])
|
||||
|
||||
return found_resources
|
||||
|
||||
def find_nearest_resource(self, world: World, player: int, position: PositionComponent):
|
||||
resources = self.find_nearby_resources(world, player, position, AGENT_VISION_RANGE)
|
||||
|
||||
nearest = None
|
||||
for resource in resources:
|
||||
if nearest is None or resource[2] < nearest[2]:
|
||||
nearest = resource
|
||||
|
||||
return nearest
|
||||
|
@ -1,31 +1,41 @@
|
||||
from survival.components.OnCollisionComponent import OnCollisionComponent
|
||||
from survival.components.camera_target_component import CameraTargetComponent
|
||||
from survival.components.consumption_component import ConsumptionComponent
|
||||
from survival.components.input_component import InputComponent
|
||||
from survival.components.inventory_component import InventoryComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.components.movement_component import MovementComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.sprite_component import SpriteComponent
|
||||
from survival.components.time_component import TimeComponent
|
||||
from survival.systems.automation_system import AutomationComponent
|
||||
from survival.components.vision_component import VisionComponent
|
||||
from survival.generators.resource_type import ResourceType
|
||||
from survival.settings import PLAYER_START_POSITION, STARTING_RESOURCES_AMOUNT
|
||||
|
||||
|
||||
class PlayerGenerator:
|
||||
|
||||
def create_player(self, world, game_map):
|
||||
player = world.create_entity()
|
||||
pos = PositionComponent([0, 0], [0, 0])
|
||||
pos = PositionComponent([PLAYER_START_POSITION[0] * 32, PLAYER_START_POSITION[1] * 32],
|
||||
PLAYER_START_POSITION)
|
||||
world.add_component(player, pos)
|
||||
world.add_component(player, MovementComponent())
|
||||
world.add_component(player, InputComponent())
|
||||
world.add_component(player, OnCollisionComponent())
|
||||
world.add_component(player, InventoryComponent())
|
||||
inv = InventoryComponent()
|
||||
for resource in ResourceType:
|
||||
inv.add_item(resource, STARTING_RESOURCES_AMOUNT)
|
||||
world.add_component(player, ConsumptionComponent(inv.total_items_count()))
|
||||
world.add_component(player, inv)
|
||||
camera_target = CameraTargetComponent(pos)
|
||||
world.add_component(player, camera_target)
|
||||
world.add_component(player, AutomationComponent())
|
||||
# world.add_component(player, AutomationComponent())
|
||||
game_map.add_entity(player, pos)
|
||||
sprite = SpriteComponent('stevenson.png')
|
||||
sprite.set_scale(1)
|
||||
world.add_component(player, sprite)
|
||||
world.add_component(player, TimeComponent(0, 0, 0, 0))
|
||||
world.add_component(player, TimeComponent())
|
||||
world.add_component(player, VisionComponent())
|
||||
world.add_component(player, LearningComponent())
|
||||
|
||||
return player
|
||||
|
@ -3,23 +3,24 @@ import random
|
||||
from survival import GameMap
|
||||
from survival.components.OnCollisionComponent import OnCollisionComponent
|
||||
from survival.components.inventory_component import InventoryComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.resource_component import ResourceComponent
|
||||
from survival.components.sprite_component import SpriteComponent
|
||||
from survival.decision_tree import DecisionTree
|
||||
from survival.esper import World
|
||||
from survival.generators.resource_type import ResourceType
|
||||
from survival.settings import RESOURCES_AMOUNT
|
||||
from survival.settings import RESOURCES_AMOUNT, PLAYER_START_POSITION
|
||||
|
||||
|
||||
class ResourceGenerator:
|
||||
resources_amount = 0
|
||||
|
||||
def __init__(self, world, game_map):
|
||||
self.world = world
|
||||
self.map = game_map
|
||||
self.decision_tree = DecisionTree()
|
||||
self.built_tree = self.decision_tree.build(10)
|
||||
|
||||
def generate_resources(self, player: int):
|
||||
ResourceGenerator.resources_amount = RESOURCES_AMOUNT
|
||||
for x in range(RESOURCES_AMOUNT):
|
||||
obj = self.world.create_entity()
|
||||
sprites = {
|
||||
@ -34,7 +35,7 @@ class ResourceGenerator:
|
||||
resource_type = random.choice(list(ResourceType))
|
||||
sprite = SpriteComponent(sprites[resource_type])
|
||||
col = OnCollisionComponent()
|
||||
col.addCallback(self.remove_resource, world=self.world, game_map=self.map, resource_ent=obj, player=player, decision_tree=self.decision_tree)
|
||||
col.addCallback(self.remove_resource, world=self.world, game_map=self.map, resource_ent=obj, player=player)
|
||||
self.world.add_component(obj, pos)
|
||||
self.world.add_component(obj, sprite)
|
||||
self.world.add_component(obj, col)
|
||||
@ -43,17 +44,25 @@ class ResourceGenerator:
|
||||
|
||||
def get_empty_grid_position(self):
|
||||
free_pos = [random.randrange(self.map.width), random.randrange(self.map.height)]
|
||||
while self.map.is_colliding(free_pos):
|
||||
while self.map.is_colliding(free_pos) or (
|
||||
free_pos[0] == PLAYER_START_POSITION[0] and free_pos[1] == PLAYER_START_POSITION[1]):
|
||||
free_pos = [random.randrange(self.map.width), random.randrange(self.map.height)]
|
||||
return free_pos
|
||||
|
||||
@staticmethod
|
||||
def remove_resource(world: World, game_map: GameMap, resource_ent: int, player: int, decision_tree: DecisionTree):
|
||||
def remove_resource(world: World, game_map: GameMap, resource_ent: int, player: int):
|
||||
pos = world.component_for_entity(resource_ent, PositionComponent)
|
||||
resource = world.component_for_entity(resource_ent, ResourceComponent)
|
||||
inventory = world.component_for_entity(player, InventoryComponent)
|
||||
answer = decision_tree.predict_answer(resource)
|
||||
# print(answer)
|
||||
inventory.add_item(ResourceType.get_from_string(answer), 1)
|
||||
inventory.add_item(resource.resource_type, 1)
|
||||
game_map.remove_entity(pos.grid_position)
|
||||
world.delete_entity(resource_ent, immediate=True)
|
||||
if world.has_component(player, LearningComponent):
|
||||
learning = world.component_for_entity(player, LearningComponent)
|
||||
learning.reward = 10
|
||||
learning.score += 1
|
||||
ResourceGenerator.resources_amount -= 1
|
||||
if ResourceGenerator.resources_amount == 0:
|
||||
learning.reward += 50
|
||||
learning.done = True
|
||||
|
||||
|
@ -1,29 +1,106 @@
|
||||
from survival import esper
|
||||
from survival import esper, PlayerGenerator, ResourceGenerator, SCREEN_WIDTH, SCREEN_HEIGHT, GameMap, \
|
||||
Camera
|
||||
from survival.components.consumption_component import ConsumptionComponent
|
||||
from survival.components.direction_component import DirectionChangeComponent
|
||||
from survival.components.inventory_component import InventoryComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.components.moving_component import MovingComponent
|
||||
from survival.components.pathfinding_component import PathfindingComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.resource_component import ResourceComponent
|
||||
from survival.components.time_component import TimeComponent
|
||||
from survival.esper import World
|
||||
from survival.generators.resource_type import ResourceType
|
||||
from survival.settings import PLAYER_START_POSITION, STARTING_RESOURCES_AMOUNT
|
||||
from survival.systems.automation_system import AutomationSystem
|
||||
from survival.systems.camera_system import CameraSystem
|
||||
from survival.systems.collection_system import ResourceCollectionSystem
|
||||
from survival.systems.collision_system import CollisionSystem
|
||||
from survival.systems.consumption_system import ConsumptionSystem
|
||||
from survival.systems.direction_system import DirectionSystem
|
||||
from survival.systems.draw_system import DrawSystem
|
||||
from survival.systems.input_system import InputSystem
|
||||
from survival.systems.movement_system import MovementSystem
|
||||
from survival.systems.pathfinding_movement_system import PathfindingMovementSystem
|
||||
from survival.systems.neural_system import NeuralSystem
|
||||
from survival.systems.time_system import TimeSystem
|
||||
from survival.systems.vision_system import VisionSystem
|
||||
|
||||
|
||||
class WorldGenerator:
|
||||
def __init__(self, win, callback):
|
||||
self.win = win
|
||||
self.callback = callback
|
||||
self.world: World = esper.World()
|
||||
self.game_map: GameMap = GameMap(int(SCREEN_WIDTH / 32) * 2, 2 * int(SCREEN_HEIGHT / 32) + 1)
|
||||
self.camera = Camera(self.game_map.width * 32, self.game_map.height * 32, self.win)
|
||||
self.resource_generator: ResourceGenerator = ResourceGenerator(self.world, self.game_map)
|
||||
self.player: int = -1
|
||||
|
||||
def create_world(self, camera, game_map):
|
||||
world = esper.World()
|
||||
world.add_processor(InputSystem(camera, game_map))
|
||||
world.add_processor(CameraSystem(camera))
|
||||
world.add_processor(MovementSystem(game_map), priority=2)
|
||||
world.add_processor(CollisionSystem(game_map), priority=3)
|
||||
world.add_processor(DrawSystem(camera))
|
||||
world.add_processor(ResourceCollectionSystem(), priority=1)
|
||||
world.add_processor(TimeSystem())
|
||||
world.add_processor(AutomationSystem(game_map))
|
||||
world.add_processor(PathfindingMovementSystem(game_map), priority=4)
|
||||
world.add_processor(DirectionSystem())
|
||||
def create_world(self):
|
||||
self.world.add_processor(InputSystem(self.camera, self.game_map))
|
||||
self.world.add_processor(CameraSystem(self.camera))
|
||||
self.world.add_processor(MovementSystem(self.game_map), priority=20)
|
||||
self.world.add_processor(CollisionSystem(self.game_map), priority=30)
|
||||
self.world.add_processor(NeuralSystem(self.game_map, self.callback), priority=50)
|
||||
self.world.add_processor(DrawSystem(self.camera))
|
||||
self.world.add_processor(TimeSystem())
|
||||
self.world.add_processor(AutomationSystem(self.game_map))
|
||||
# self.world.add_processor(PathfindingMovementSystem(self.game_map), priority=40)
|
||||
self.world.add_processor(DirectionSystem())
|
||||
self.world.add_processor(ConsumptionSystem(self.callback))
|
||||
self.world.add_processor(VisionSystem(self.camera))
|
||||
|
||||
return world
|
||||
self.player = PlayerGenerator().create_player(self.world, self.game_map)
|
||||
self.world.get_processor(DrawSystem).initialize_interface(
|
||||
self.world.component_for_entity(self.player, InventoryComponent))
|
||||
|
||||
# BuildingGenerator().create_home(self.world, self.game_map)
|
||||
self.resource_generator.generate_resources(self.player)
|
||||
return self.game_map, self.world, self.camera
|
||||
|
||||
def reset_world(self):
|
||||
for processor in self.world.processors:
|
||||
processor.reset()
|
||||
|
||||
self.reset_player()
|
||||
self.reset_resources()
|
||||
|
||||
def reset_resources(self):
|
||||
for entity in self.world.entities:
|
||||
if self.world.has_component(entity, ResourceComponent):
|
||||
self.game_map.remove_entity(self.world.component_for_entity(entity, PositionComponent).grid_position)
|
||||
self.world.delete_entity(entity)
|
||||
continue
|
||||
self.resource_generator.generate_resources(self.player)
|
||||
|
||||
def reset_player(self):
|
||||
self.world.remove_component(self.player, TimeComponent)
|
||||
self.world.add_component(self.player, TimeComponent())
|
||||
|
||||
inv = self.world.component_for_entity(self.player, InventoryComponent)
|
||||
inv.clear()
|
||||
for resource in ResourceType:
|
||||
inv.add_item(resource, STARTING_RESOURCES_AMOUNT)
|
||||
|
||||
if self.world.has_component(self.player, ConsumptionComponent):
|
||||
self.world.remove_component(self.player, ConsumptionComponent)
|
||||
self.world.add_component(self.player, ConsumptionComponent(inv.total_items_count()))
|
||||
|
||||
pos = self.world.component_for_entity(self.player, PositionComponent)
|
||||
old_pos = pos.grid_position
|
||||
|
||||
self.world.remove_component(self.player, PositionComponent)
|
||||
self.world.add_component(self.player,
|
||||
PositionComponent([PLAYER_START_POSITION[0] * 32, PLAYER_START_POSITION[1] * 32],
|
||||
PLAYER_START_POSITION))
|
||||
|
||||
self.game_map.move_entity(old_pos, pos.grid_position)
|
||||
|
||||
if self.world.has_component(self.player, MovingComponent):
|
||||
self.world.remove_component(self.player, MovingComponent)
|
||||
if self.world.has_component(self.player, DirectionChangeComponent):
|
||||
self.world.remove_component(self.player, DirectionChangeComponent)
|
||||
if self.world.has_component(self.player, PathfindingComponent):
|
||||
self.world.remove_component(self.player, PathfindingComponent)
|
||||
if self.world.has_component(self.player, LearningComponent):
|
||||
learning = self.world.component_for_entity(self.player, LearningComponent)
|
||||
learning.reset()
|
||||
|
@ -2,7 +2,8 @@ from enum import Enum
|
||||
from queue import PriorityQueue
|
||||
from typing import Tuple, List
|
||||
|
||||
from survival import GameMap
|
||||
from survival.components.direction_component import DirectionChangeComponent
|
||||
from survival.components.moving_component import MovingComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.resource_component import ResourceComponent
|
||||
from survival.enums import Direction
|
||||
@ -14,6 +15,33 @@ class Action(Enum):
|
||||
ROTATE_RIGHT = 1
|
||||
MOVE = 2
|
||||
|
||||
@staticmethod
|
||||
def from_array(action):
|
||||
if action[0] == 1:
|
||||
return Action.MOVE
|
||||
if action[1] == 1:
|
||||
return Action.ROTATE_LEFT
|
||||
if action[2] == 1:
|
||||
return Action.ROTATE_RIGHT
|
||||
raise Exception("Unknown action.")
|
||||
|
||||
@staticmethod
|
||||
def perform(world, entity, action):
|
||||
if world.has_component(entity, MovingComponent):
|
||||
raise Exception(f"Entity was already moving. Could not perform action: {action}")
|
||||
if world.has_component(entity, DirectionChangeComponent):
|
||||
raise Exception(f"Entity was already rotating. Could not perform action: {action}")
|
||||
|
||||
if action == Action.ROTATE_LEFT:
|
||||
world.add_component(entity, DirectionChangeComponent(
|
||||
Direction.rotate_left(world.component_for_entity(entity, PositionComponent).direction)))
|
||||
elif action == Action.ROTATE_RIGHT:
|
||||
world.add_component(entity, DirectionChangeComponent(
|
||||
Direction.rotate_right(world.component_for_entity(entity, PositionComponent).direction)))
|
||||
else:
|
||||
world.add_component(entity, MovingComponent())
|
||||
return action
|
||||
|
||||
|
||||
class State:
|
||||
def __init__(self, position: Tuple[int, int], direction: Direction):
|
||||
@ -40,7 +68,7 @@ def get_moved_position(position: Tuple[int, int], direction: Direction):
|
||||
return position[0] + vector[0], position[1] + vector[1]
|
||||
|
||||
|
||||
def get_states(state: State, game_map: GameMap, world: World) -> List[Tuple[Action, State, int]]:
|
||||
def get_states(state: State, game_map, world: World) -> List[Tuple[Action, State, int]]:
|
||||
states = list()
|
||||
|
||||
states.append((Action.ROTATE_LEFT, State(state.position, state.direction.rotate_left(state.direction)), 1))
|
||||
@ -58,23 +86,25 @@ def get_states(state: State, game_map: GameMap, world: World) -> List[Tuple[Acti
|
||||
|
||||
|
||||
def build_path(node: Node):
|
||||
cost = 0
|
||||
actions = [node.action]
|
||||
parent = node.parent
|
||||
|
||||
while parent is not None:
|
||||
if parent.action is not None:
|
||||
actions.append(parent.action)
|
||||
cost += parent.cost
|
||||
parent = parent.parent
|
||||
|
||||
actions.reverse()
|
||||
return actions
|
||||
return actions, cost
|
||||
|
||||
|
||||
def heuristic(new_node: Node, goal: Tuple[int, int]):
|
||||
return abs(new_node.state.position[0] - goal[0]) + abs(new_node.state.position[1] - goal[1])
|
||||
|
||||
|
||||
def graph_search(game_map: GameMap, start: PositionComponent, goal: tuple, world: World):
|
||||
def graph_search(game_map, start: PositionComponent, goal: tuple, world: World):
|
||||
fringe = PriorityQueue()
|
||||
explored = list()
|
||||
|
||||
@ -88,7 +118,7 @@ def graph_search(game_map: GameMap, start: PositionComponent, goal: tuple, world
|
||||
while True:
|
||||
# No solutions found
|
||||
if fringe.empty():
|
||||
return []
|
||||
return [], 0
|
||||
|
||||
node = fringe.get()
|
||||
node_priority = node[0]
|
||||
|
@ -4,8 +4,11 @@ import pygame
|
||||
|
||||
|
||||
class Image:
|
||||
def __init__(self, filename, pos=(0, 0), scale=1):
|
||||
def __init__(self, filename='', pos=(0, 0), scale=1, surface=None):
|
||||
if surface is None:
|
||||
self.texture = pygame.image.load(os.path.join('..', 'assets', filename)).convert_alpha()
|
||||
else:
|
||||
self.texture = surface
|
||||
self.image = self.texture
|
||||
self.origin = (0, 0)
|
||||
self.pos = pos
|
||||
|
112
survival/learning_utils.py
Normal file
112
survival/learning_utils.py
Normal file
@ -0,0 +1,112 @@
|
||||
import numpy as np
|
||||
from IPython import display
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.enums import Direction
|
||||
from survival.graph_search import Action
|
||||
|
||||
|
||||
class LearningUtils:
|
||||
def __init__(self):
|
||||
self.plot_scores = []
|
||||
self.plot_mean_scores = []
|
||||
self.total_score = 0
|
||||
self.last_actions: [Action, [int, int]] = []
|
||||
|
||||
def add_scores(self, learning: LearningComponent, games_count: int):
|
||||
self.plot_scores.append(learning.score)
|
||||
self.total_score += learning.score
|
||||
mean_score = self.total_score / games_count
|
||||
self.plot_mean_scores.append(mean_score)
|
||||
|
||||
def plot(self):
|
||||
display.clear_output(wait=True)
|
||||
display.display(plt.gcf())
|
||||
plt.clf()
|
||||
plt.title('Training...')
|
||||
plt.xlabel('Number of Games')
|
||||
plt.ylabel('Score')
|
||||
plt.plot(self.plot_scores)
|
||||
# plt.plot(self.plot_mean_scores)
|
||||
plt.ylim(ymin=0)
|
||||
plt.text(len(self.plot_scores) - 1, self.plot_scores[-1], str(self.plot_scores[-1]))
|
||||
# plt.text(len(self.plot_mean_scores) - 1, self.plot_mean_scores[-1], str(self.plot_mean_scores[-1]))
|
||||
plt.show(block=False)
|
||||
plt.pause(.1)
|
||||
|
||||
def append_action(self, action: Action, pos: PositionComponent):
|
||||
self.last_actions.append([action, pos.grid_position])
|
||||
|
||||
def check_last_actions(self, learning):
|
||||
"""
|
||||
Checks if all the last five actions were repeated and imposes the potential penalty.
|
||||
:param learning:
|
||||
"""
|
||||
if len(self.last_actions) > 5:
|
||||
self.last_actions.pop(0)
|
||||
|
||||
last_action: [Action, [int, int]] = self.last_actions[0]
|
||||
last_grid_pos: [int, int] = last_action[1]
|
||||
|
||||
rotations = 0
|
||||
collisions = 0
|
||||
for action in self.last_actions:
|
||||
if action != Action.MOVE:
|
||||
rotations += 1
|
||||
else:
|
||||
current_grid_pos = action[1]
|
||||
if current_grid_pos[0] == last_grid_pos[0] and current_grid_pos[1] == last_grid_pos[1]:
|
||||
collisions += 1
|
||||
|
||||
if rotations > 4 or collisions > 4:
|
||||
learning.reward -= 2
|
||||
|
||||
|
||||
def get_state(system, player, resource):
|
||||
pos: PositionComponent = system.world.component_for_entity(player, PositionComponent)
|
||||
if resource is None or resource[0] is None:
|
||||
res_l = False
|
||||
res_r = False
|
||||
res_u = False
|
||||
res_d = False
|
||||
else:
|
||||
resource_pos: PositionComponent = system.world.component_for_entity(resource[0], PositionComponent)
|
||||
res_l = resource_pos.grid_position[0] < pos.grid_position[0]
|
||||
res_r = resource_pos.grid_position[0] > pos.grid_position[0]
|
||||
res_u = resource_pos.grid_position[1] < pos.grid_position[1]
|
||||
res_d = resource_pos.grid_position[1] > pos.grid_position[1]
|
||||
|
||||
dir_l = pos.direction == Direction.LEFT
|
||||
dir_r = pos.direction == Direction.RIGHT
|
||||
dir_u = pos.direction == Direction.UP
|
||||
dir_d = pos.direction == Direction.DOWN
|
||||
|
||||
pos_l = [pos.grid_position[0] - 1, pos.grid_position[1]]
|
||||
pos_r = [pos.grid_position[0] + 1, pos.grid_position[1]]
|
||||
pos_u = [pos.grid_position[0], pos.grid_position[1] - 1]
|
||||
pos_d = [pos.grid_position[0], pos.grid_position[1] + 1]
|
||||
col_l = system.game_map.in_bounds(
|
||||
pos_l) # self.game_map.is_colliding(pos_l) and self.game_map.get_entity(pos_l) is None
|
||||
col_r = system.game_map.in_bounds(
|
||||
pos_r) # self.game_map.is_colliding(pos_r) and self.game_map.get_entity(pos_r) is None
|
||||
col_u = system.game_map.in_bounds(
|
||||
pos_u) # self.game_map.is_colliding(pos_u) and self.game_map.get_entity(pos_u) is None
|
||||
col_d = system.game_map.in_bounds(
|
||||
pos_d) # self.game_map.is_colliding(pos_d) and self.game_map.get_entity(pos_d) is None
|
||||
|
||||
state = [
|
||||
# Collision ahead
|
||||
(dir_r and col_r) or (dir_l and col_l) or (dir_u and col_u) or (dir_d and col_d),
|
||||
# Collision on the right
|
||||
(dir_u and col_r) or (dir_r and col_d) or (dir_d and col_l) or (dir_l and col_u),
|
||||
# Collision on the left
|
||||
(dir_u and col_l) or (dir_l and col_d) or (dir_d and col_r) or (dir_r and col_u),
|
||||
# Movement direction
|
||||
dir_l, dir_r, dir_u, dir_d,
|
||||
# Resource location
|
||||
res_l, res_r, res_u, res_d
|
||||
]
|
||||
|
||||
return np.array(state, dtype=int)
|
78
survival/model.py
Normal file
78
survival/model.py
Normal file
@ -0,0 +1,78 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
import torch.nn.functional as functional
|
||||
|
||||
|
||||
class LinearQNetwork(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, output_size, pretrained=False):
|
||||
super().__init__()
|
||||
self.linear_one = nn.Linear(input_size, hidden_size)
|
||||
self.linear_two = nn.Linear(hidden_size, output_size)
|
||||
self.pretrained = pretrained
|
||||
|
||||
def forward(self, x):
|
||||
x = functional.relu(self.linear_one(x))
|
||||
x = self.linear_two(x)
|
||||
|
||||
return x
|
||||
|
||||
def save(self, file_name='model.pth'):
|
||||
model_directory = 'model'
|
||||
if not os.path.exists(model_directory):
|
||||
os.makedirs(model_directory)
|
||||
|
||||
file_path = os.path.join(model_directory, file_name)
|
||||
torch.save(self.state_dict(), file_path)
|
||||
|
||||
@staticmethod
|
||||
def load(input_size, hidden_size, output_size, file_name='model.pth'):
|
||||
model_directory = 'model'
|
||||
file_path = os.path.join(model_directory, file_name)
|
||||
if os.path.isfile(file_path):
|
||||
model = LinearQNetwork(input_size, hidden_size, output_size, True)
|
||||
model.load_state_dict(torch.load(file_path))
|
||||
model.eval()
|
||||
return model
|
||||
return LinearQNetwork(11, 256, 3)
|
||||
|
||||
|
||||
class QTrainer:
|
||||
def __init__(self, model, lr, gamma):
|
||||
self.model = model
|
||||
self.lr = lr
|
||||
self.gamma = gamma
|
||||
self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
|
||||
self.criterion = nn.MSELoss() # Mean squared error
|
||||
|
||||
def train_step(self, state, action, reward, next_state, done):
|
||||
state = torch.tensor(state, dtype=torch.float)
|
||||
next_state = torch.tensor(next_state, dtype=torch.float)
|
||||
action = torch.tensor(action, dtype=torch.long)
|
||||
reward = torch.tensor(reward, dtype=torch.float)
|
||||
|
||||
if len(state.shape) == 1:
|
||||
# reshape the state to make its values an (n, x) tuple
|
||||
state = torch.unsqueeze(state, 0)
|
||||
next_state = torch.unsqueeze(next_state, 0)
|
||||
action = torch.unsqueeze(action, 0)
|
||||
reward = torch.unsqueeze(reward, 0)
|
||||
done = (done,)
|
||||
|
||||
# Prediction based on simplified Bellman's equation
|
||||
# Predict Q values for current state
|
||||
prediction = self.model(state)
|
||||
target = prediction.clone()
|
||||
for idx in range(len(done)):
|
||||
Q = reward[idx]
|
||||
if not done[idx]:
|
||||
Q = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
|
||||
# set the target of the maximum value of the action to Q
|
||||
target[idx][torch.argmax(action).item()] = Q
|
||||
# Apply the loss function
|
||||
self.optimizer.zero_grad()
|
||||
loss = self.criterion(target, prediction)
|
||||
loss.backward()
|
||||
|
||||
self.optimizer.step()
|
BIN
survival/model/model.pth
Normal file
BIN
survival/model/model.pth
Normal file
Binary file not shown.
BIN
survival/model/model204games.pth
Normal file
BIN
survival/model/model204games.pth
Normal file
Binary file not shown.
BIN
survival/model/modeltrained.pth
Normal file
BIN
survival/model/modeltrained.pth
Normal file
Binary file not shown.
BIN
survival/model/modeltrained2.pth
Normal file
BIN
survival/model/modeltrained2.pth
Normal file
Binary file not shown.
BIN
survival/model/new_model120games.pth
Normal file
BIN
survival/model/new_model120games.pth
Normal file
Binary file not shown.
BIN
survival/model/newer_model120games.pth
Normal file
BIN
survival/model/newer_model120games.pth
Normal file
Binary file not shown.
@ -1,4 +1,7 @@
|
||||
SCREEN_WIDTH = 1000
|
||||
SCREEN_HEIGHT = 600
|
||||
RESOURCES_AMOUNT = 300
|
||||
DIRECTION_CHANGE_DELAY = 200
|
||||
RESOURCES_AMOUNT = 100
|
||||
DIRECTION_CHANGE_DELAY = 5
|
||||
PLAYER_START_POSITION = [20, 10]
|
||||
STARTING_RESOURCES_AMOUNT = 10
|
||||
AGENT_VISION_RANGE = 5
|
||||
|
@ -7,6 +7,8 @@ from survival.components.resource_component import ResourceComponent
|
||||
|
||||
class AutomationComponent:
|
||||
pass
|
||||
# def __init__(self):
|
||||
# self.resources = []
|
||||
|
||||
|
||||
class AutomationSystem(esper.Processor):
|
||||
|
@ -1,20 +0,0 @@
|
||||
from survival import esper
|
||||
from survival.components.direction_component import DirectionChangeComponent
|
||||
from survival.components.moving_component import MovingComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.graph_search import Action
|
||||
from survival.systems.pathfinding_movement_system import CollectingResourceComponent
|
||||
|
||||
|
||||
class ResourceCollectionSystem(esper.Processor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def process(self, dt):
|
||||
for ent, (collect, pos) in self.world.get_components(CollectingResourceComponent, PositionComponent):
|
||||
if self.world.has_component(ent, MovingComponent) or self.world.has_component(ent, DirectionChangeComponent):
|
||||
continue
|
||||
|
||||
if collect.action == Action.MOVE:
|
||||
self.world.remove_component(ent, CollectingResourceComponent)
|
||||
self.world.add_component(ent, MovingComponent())
|
@ -4,6 +4,7 @@ from survival import esper
|
||||
from survival.components.OnCollisionComponent import OnCollisionComponent
|
||||
from survival.components.moving_component import MovingComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.enums import Direction
|
||||
|
||||
|
||||
@ -18,7 +19,6 @@ class CollisionSystem(esper.Processor):
|
||||
continue
|
||||
|
||||
moving.checked_collision = True
|
||||
|
||||
vector = Direction.get_vector(pos.direction)
|
||||
moving.target = tuple(map(operator.add, vector, pos.grid_position))
|
||||
moving.direction_vector = vector
|
||||
|
30
survival/systems/consumption_system.py
Normal file
30
survival/systems/consumption_system.py
Normal file
@ -0,0 +1,30 @@
|
||||
from survival import esper
|
||||
from survival.components.consumption_component import ConsumptionComponent
|
||||
from survival.components.inventory_component import InventoryComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.generators.resource_type import ResourceType
|
||||
|
||||
|
||||
class ConsumptionSystem(esper.Processor):
|
||||
def __init__(self, callback):
|
||||
self.callback = callback
|
||||
|
||||
def process(self, dt):
|
||||
for ent, (cons, inventory) in self.world.get_components(ConsumptionComponent, InventoryComponent):
|
||||
cons.timer -= dt
|
||||
if cons.timer > 0:
|
||||
continue
|
||||
cons.timer = cons.timer_value
|
||||
|
||||
if self.world.has_component(ent, LearningComponent):
|
||||
# If no item was picked up
|
||||
if cons.last_inventory_state == inventory.total_items_count():
|
||||
learning: LearningComponent = self.world.component_for_entity(ent, LearningComponent)
|
||||
learning.reward = -10
|
||||
learning.done = True
|
||||
cons.last_inventory_state = inventory.total_items_count()
|
||||
else:
|
||||
if inventory.has_item(ResourceType.FOOD):
|
||||
inventory.remove_item(ResourceType.FOOD, 1)
|
||||
else:
|
||||
self.callback()
|
@ -7,10 +7,10 @@ from survival.user_interface import UserInterface
|
||||
class DrawSystem(esper.Processor):
|
||||
def __init__(self, camera):
|
||||
self.camera = camera
|
||||
self.ui = None
|
||||
self.ui = UserInterface(self.camera.window)
|
||||
|
||||
def initialize_interface(self, inventory):
|
||||
self.ui = UserInterface(self.camera.window, inventory)
|
||||
self.ui.load_inventory(inventory)
|
||||
|
||||
def process(self, dt):
|
||||
for ent, (sprite, pos) in self.world.get_components(SpriteComponent, PositionComponent):
|
||||
|
@ -24,7 +24,7 @@ class InputSystem(esper.Processor):
|
||||
if not self.world.has_component(ent, PathfindingComponent):
|
||||
target_ent = self.game_map.get_entity([int(pos[0] / 32), int(pos[1]/ 32)])
|
||||
if target_ent is not None and self.world.has_component(target_ent, ResourceComponent):
|
||||
self.world.add_component(ent, PathfindingComponent(pos, True))
|
||||
self.world.add_component(ent, PathfindingComponent(pos))
|
||||
else:
|
||||
self.world.add_component(ent, PathfindingComponent(pos))
|
||||
|
||||
|
@ -13,11 +13,13 @@ class MovementSystem(esper.Processor):
|
||||
for ent, (mov, pos, moving, sprite) in self.world.get_components(MovementComponent, PositionComponent,
|
||||
MovingComponent,
|
||||
SpriteComponent):
|
||||
cost = self.map.get_cost(moving.target)
|
||||
pos.position[0] += moving.direction_vector[0] * mov.speed * dt / 100 / cost
|
||||
pos.position[1] += moving.direction_vector[1] * mov.speed * dt / 100 / cost
|
||||
|
||||
if abs(moving.target[0] * 32 - pos.position[0]) < 0.1 * mov.speed and abs(
|
||||
pos.position[1] - moving.target[1] * 32) < 0.1 * mov.speed:
|
||||
# cost = self.map.get_cost(moving.target)
|
||||
# pos.position[0] += moving.direction_vector[0] * mov.speed * dt / 100 / cost
|
||||
# pos.position[1] += moving.direction_vector[1] * mov.speed * dt / 100 / cost
|
||||
#
|
||||
# if abs(moving.target[0] * 32 - pos.position[0]) < 1 * mov.speed and abs(
|
||||
# pos.position[1] - moving.target[1] * 32) < 1 * mov.speed:
|
||||
# pos.position = [moving.target[0] * 32, moving.target[1] * 32]
|
||||
# self.world.remove_component(ent, MovingComponent)
|
||||
pos.position = [moving.target[0] * 32, moving.target[1] * 32]
|
||||
self.world.remove_component(ent, MovingComponent)
|
||||
|
120
survival/systems/neural_system.py
Normal file
120
survival/systems/neural_system.py
Normal file
@ -0,0 +1,120 @@
|
||||
import random
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
|
||||
from survival import esper, GameMap
|
||||
from survival.components.direction_component import DirectionChangeComponent
|
||||
from survival.components.inventory_component import InventoryComponent
|
||||
from survival.components.moving_component import MovingComponent
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.learning_component import LearningComponent
|
||||
from survival.components.time_component import TimeComponent
|
||||
from survival.graph_search import Action
|
||||
from survival.learning_utils import get_state, LearningUtils
|
||||
from survival.model import LinearQNetwork, QTrainer
|
||||
|
||||
MAX_MEMORY = 100_000
|
||||
BATCH_SIZE = 1000
|
||||
LR = 0.001
|
||||
LEARN = True
|
||||
|
||||
|
||||
class NeuralSystem(esper.Processor):
|
||||
def __init__(self, game_map: GameMap, callback):
|
||||
self.game_map = game_map
|
||||
self.reset_game = callback
|
||||
self.n_games = 0 # number of games played
|
||||
self.starting_epsilon = 100
|
||||
self.epsilon = 0 # controlls the randomness
|
||||
self.gamma = 0.9 # discount rate
|
||||
self.memory = deque(maxlen=MAX_MEMORY) # exceeding memory removes the left elements to make more space
|
||||
self.model = LinearQNetwork.load(11, 256, 3)
|
||||
if self.model.pretrained:
|
||||
self.starting_epsilon = -1
|
||||
self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
|
||||
self.utils = LearningUtils()
|
||||
|
||||
def remember(self, state, action, reward, next_state, done):
|
||||
self.memory.append((state, action, reward, next_state, done))
|
||||
|
||||
def train_short_memory(self, state, action, reward, next_state, done):
|
||||
self.trainer.train_step(state, action, reward, next_state, done)
|
||||
|
||||
def train_long_memory(self):
|
||||
if len(self.memory) > BATCH_SIZE:
|
||||
mini_sample = random.sample(self.memory, BATCH_SIZE)
|
||||
else:
|
||||
mini_sample = self.memory
|
||||
states, actions, rewards, next_states, dones = zip(*mini_sample)
|
||||
self.trainer.train_step(states, actions, rewards, next_states, dones)
|
||||
|
||||
def get_action(self, state):
|
||||
self.epsilon = self.starting_epsilon - self.n_games
|
||||
final_move = [0, 0, 0]
|
||||
if random.randint(0, 200) < self.epsilon:
|
||||
move = random.randint(0, 2)
|
||||
final_move[move] = 1
|
||||
else:
|
||||
state_zero = torch.tensor(state, dtype=torch.float)
|
||||
prediction = self.model(state_zero)
|
||||
move = torch.argmax(prediction).item()
|
||||
final_move[move] = 1
|
||||
|
||||
return final_move
|
||||
|
||||
def process(self, dt):
|
||||
for ent, (pos, inventory, time, learning) in self.world.get_components(PositionComponent, InventoryComponent,
|
||||
TimeComponent, LearningComponent):
|
||||
if not learning.made_step:
|
||||
learning.reset()
|
||||
|
||||
# Get the closest resource | [entity, path, cost]
|
||||
resource: [int, list, int] = self.game_map.find_nearest_resource(self.world, ent, pos)
|
||||
|
||||
# Get current entity state
|
||||
old_state = get_state(self, ent, resource)
|
||||
# Predict the action
|
||||
action = self.get_action(old_state)
|
||||
# Save the action
|
||||
learning.load_step(old_state, action, resource)
|
||||
# Perform the action
|
||||
act = Action.perform(self.world, ent, Action.from_array(action))
|
||||
self.utils.append_action(act, pos)
|
||||
continue
|
||||
|
||||
# Wait for the action to complete
|
||||
if self.world.has_component(ent, DirectionChangeComponent) or self.world.has_component(ent,
|
||||
MovingComponent):
|
||||
continue
|
||||
|
||||
self.utils.check_last_actions(learning)
|
||||
|
||||
resource = learning.resource
|
||||
if resource is None or not self.world.entity_exists(resource[0]):
|
||||
# Find a new resource if no resource was found or the last one was consumed
|
||||
resource = self.game_map.find_nearest_resource(self.world, ent, pos)
|
||||
|
||||
# Get new state
|
||||
new_state = get_state(self, ent, resource)
|
||||
# Train agent's memory
|
||||
self.train_short_memory(learning.old_state, learning.action, learning.reward, new_state, learning.done)
|
||||
self.remember(learning.old_state, learning.action, learning.reward, new_state, learning.done)
|
||||
|
||||
learning.made_step = False
|
||||
|
||||
if learning.done:
|
||||
self.n_games += 1
|
||||
if LEARN:
|
||||
self.train_long_memory()
|
||||
if learning.score > learning.record:
|
||||
learning.record = learning.score
|
||||
if LEARN:
|
||||
self.model.save()
|
||||
|
||||
print('Game', self.n_games, 'Score', learning.score, 'Record', learning.record)
|
||||
self.utils.add_scores(learning, self.n_games)
|
||||
learning.score = 0
|
||||
self.utils.plot()
|
||||
|
||||
self.reset_game()
|
@ -8,11 +8,6 @@ from survival.graph_search import graph_search, Action
|
||||
from survival.systems.input_system import PathfindingComponent
|
||||
|
||||
|
||||
class CollectingResourceComponent:
|
||||
def __init__(self, action):
|
||||
self.action = action
|
||||
|
||||
|
||||
class PathfindingMovementSystem(esper.Processor):
|
||||
def __init__(self, game_map):
|
||||
self.game_map = game_map
|
||||
@ -21,17 +16,12 @@ class PathfindingMovementSystem(esper.Processor):
|
||||
for ent, (pos, pathfinding, movement) in self.world.get_components(PositionComponent, PathfindingComponent,
|
||||
MovementComponent):
|
||||
if pathfinding.path is None:
|
||||
pathfinding.path = graph_search(self.game_map, pos, pathfinding.target_grid_pos, self.world)
|
||||
pathfinding.path, cost = graph_search(self.game_map, pos, pathfinding.target_grid_pos, self.world)
|
||||
|
||||
if len(pathfinding.path) < 1:
|
||||
self.world.remove_component(ent, PathfindingComponent)
|
||||
continue
|
||||
|
||||
if pathfinding.searching_for_resource and len(pathfinding.path) == 1:
|
||||
self.world.add_component(ent, CollectingResourceComponent(pathfinding.path.pop(0)))
|
||||
self.world.remove_component(ent, PathfindingComponent)
|
||||
continue
|
||||
|
||||
if self.world.has_component(ent, MovingComponent) or self.world.has_component(ent, DirectionChangeComponent):
|
||||
continue
|
||||
|
||||
|
18
survival/systems/vision_system.py
Normal file
18
survival/systems/vision_system.py
Normal file
@ -0,0 +1,18 @@
|
||||
from survival import esper
|
||||
from survival.components.position_component import PositionComponent
|
||||
from survival.components.vision_component import VisionComponent
|
||||
|
||||
|
||||
class VisionSystem(esper.Processor):
|
||||
def __init__(self, camera):
|
||||
self.camera = camera
|
||||
|
||||
def process(self, dt):
|
||||
pos: PositionComponent
|
||||
vision: VisionComponent
|
||||
for ent, (pos, vision) in self.world.get_components(PositionComponent, VisionComponent):
|
||||
vision.update_positions(pos.position)
|
||||
self.camera.window.blit(vision.surface_l, self.camera.apply(vision.l_pos))
|
||||
self.camera.window.blit(vision.surface_r, self.camera.apply(vision.r_pos))
|
||||
self.camera.window.blit(vision.surface_t, self.camera.apply(vision.t_pos))
|
||||
self.camera.window.blit(vision.surface_b, self.camera.apply(vision.b_pos))
|
@ -7,13 +7,13 @@ from survival.image import Image
|
||||
|
||||
|
||||
class UserInterface:
|
||||
def __init__(self, window, inventory: InventoryComponent):
|
||||
def __init__(self, window):
|
||||
self.width = settings.SCREEN_WIDTH
|
||||
self.height = settings.SCREEN_HEIGHT
|
||||
self.window = window
|
||||
self.pos = (self.width - 240, 50)
|
||||
self.scale = 2
|
||||
self.inventory = inventory
|
||||
self.inventory: InventoryComponent = None
|
||||
self.images = {
|
||||
ResourceType.FOOD: Image('apple.png', self.pos, self.scale),
|
||||
ResourceType.WATER: Image('water.png', self.pos, self.scale),
|
||||
@ -26,6 +26,9 @@ class UserInterface:
|
||||
self.slot_image = Image('ui.png', self.pos, scale=2)
|
||||
self.font = pygame.font.SysFont('Comic Sans MS', 20)
|
||||
|
||||
def load_inventory(self, inventory: InventoryComponent):
|
||||
self.inventory = inventory
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user