# AI-Project/survival/systems/neural_system.py

import random
from collections import deque
import torch
from survival import esper, GameMap
from survival.ai.genetic_algorithm import GeneticAlgorithm
from survival.components.direction_component import DirectionChangeComponent
from survival.components.inventory_component import InventoryComponent
from survival.components.moving_component import MovingComponent
from survival.components.position_component import PositionComponent
from survival.components.learning_component import LearningComponent
from survival.components.time_component import TimeComponent
from survival.ai.graph_search import Action
from survival.ai.learning_utils import get_state, LearningUtils
from survival.ai.model import LinearQNetwork, QTrainer
from survival.settings import LEARN, MUTATE_NETWORKS

MAX_MEMORY = 100_000
BATCH_SIZE = 1000


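# Esper processor that drives the learning agent: it picks actions with an
# epsilon-greedy policy over a LinearQNetwork and trains the network from an
# experience-replay memory (a deep Q-learning setup).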
class NeuralSystem(esper.Processor):
    def __init__(self, game_map: GameMap, callback):
        self.game_map = game_map
        self.reset_game = callback
        self.n_games = 0  # number of games played

        if MUTATE_NETWORKS:
            self.starting_epsilon = GeneticAlgorithm.GAMES_PER_NETWORK / 2
        else:
            self.starting_epsilon = 100

        self.epsilon = 0  # controls the randomness of exploration
        self.gamma = 0.9  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # when full, the oldest elements are dropped from the left

        self.model = None  # self.model = LinearQNetwork.load(11, 256, 3)
        self.trainer = None  # QTrainer(self.model, lr=LR, gamma=self.gamma)

        self.utils = LearningUtils()
        self.best_action = None

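    # Swap in a new network (e.g. one supplied by the genetic algorithm when MUTATE_NETWORKS is on)
    # and reset the per-network learning state: trainer, replay memory, epsilon schedule and game counter.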
    def load_model(self, model: LinearQNetwork):
        self.model = model
        self.trainer = QTrainer(self.model, self.model.network_params['ratio'], self.gamma,
                                self.model.network_params['optimizer'])
        self.utils = LearningUtils()
        self.memory = deque(maxlen=MAX_MEMORY)
        self.starting_epsilon = GeneticAlgorithm.GAMES_PER_NETWORK / 2
        self.n_games = 0

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)
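
    # Experience replay: train on a random batch of remembered transitions rather than only the latest step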
    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)
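
    # Epsilon-greedy selection: while epsilon is still high the agent explores with random moves,
    # otherwise it exploits the network's prediction; returns a one-hot vector over the three moves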
    def get_action(self, state):
        self.epsilon = self.starting_epsilon - self.n_games
        final_move = [0, 0, 0]

        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state_zero = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state_zero)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move
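
    # esper entry point, called on every world update for each entity with a LearningComponent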
    def process(self, dt):
        for ent, (pos, inventory, time, learning) in self.world.get_components(PositionComponent, InventoryComponent,
                                                                                TimeComponent, LearningComponent):
            if not learning.made_step:
                learning.reset()
                self.best_action = None

                # Get the closest resource | [entity, path, cost]
                resource: [int, list, int] = self.game_map.find_nearest_resource(self.world, ent, pos)

                if resource is not None:
                    # If a resource was found, take the best move chosen by A*
                    self.best_action = resource[1][0]

                # Get the current entity state
                old_state = get_state(self, ent, resource)
                # Predict the action
                action = self.get_action(old_state)
                # Save the action
                learning.load_step(old_state, action, resource)
                # Perform the action
                act = Action.perform(self.world, ent, Action.from_array(action))
                self.utils.append_action(act, pos)

                # Add a reward if the chosen action was the best one
                if act == self.best_action:
                    learning.reward += 1

                continue
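
            # learning.made_step is True here: an action has already been chosen and issued for this entity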
            # Wait for the action to complete
            if self.world.has_component(ent, DirectionChangeComponent) or self.world.has_component(ent,
                                                                                                    MovingComponent):
                continue

            self.utils.check_last_actions(learning)
            resource = learning.resource

            if resource is None or not self.world.entity_exists(resource[0]):
                # Find a new resource if no resource was found or the last one was consumed
                resource = self.game_map.find_nearest_resource(self.world, ent, pos)

            # Get new state
            new_state = get_state(self, ent, resource)

            # Train agent's memory
            self.train_short_memory(learning.old_state, learning.action, learning.reward, new_state, learning.done)
            self.remember(learning.old_state, learning.action, learning.reward, new_state, learning.done)

            learning.made_step = False

            if learning.done:
                self.n_games += 1

                if LEARN:
                    self.train_long_memory()

                if learning.score > learning.record:
                    learning.record = learning.score
                    if LEARN and not MUTATE_NETWORKS:
                        self.model.save()

                # print('Game', self.n_games, 'Score', learning.score, 'Record', learning.record)
                self.utils.add_scores(learning, self.n_games)
                self.model.scores.append(learning.score)
                learning.score = 0
                self.utils.plot()
                self.reset_game()