From 78e9b3745d71fc933922c9ce01dcc0137c25d553 Mon Sep 17 00:00:00 2001 From: Kanewersa <30356293+Kanewersa@users.noreply.github.com> Date: Mon, 21 Jun 2021 12:20:25 +0200 Subject: [PATCH] Add genetic algorithm --- survival/__init__.py | 13 ++- survival/ai/genetic_algorithm.py | 84 ++++++++++++++ survival/ai/learning_utils.py | 13 ++- survival/ai/model.py | 58 +++++++--- survival/ai/optimizer.py | 147 +++++++++++++++++++++++++ survival/ai/test.py | 93 ++++++++++++++++ survival/game/user_interface.py | 8 +- survival/generators/tile_generator.py | 3 +- survival/generators/world_generator.py | 18 ++- survival/settings.py | 15 ++- survival/systems/collision_system.py | 1 + survival/systems/neural_system.py | 30 +++-- 12 files changed, 445 insertions(+), 38 deletions(-) create mode 100644 survival/ai/genetic_algorithm.py create mode 100644 survival/ai/optimizer.py create mode 100644 survival/ai/test.py diff --git a/survival/__init__.py b/survival/__init__.py index 06f14ce..4b17895 100644 --- a/survival/__init__.py +++ b/survival/__init__.py @@ -1,13 +1,15 @@ import pygame +from survival.ai.genetic_algorithm import GeneticAlgorithm from survival.components.inventory_component import InventoryComponent from survival.game.game_map import GameMap from survival.generators.building_generator import BuildingGenerator from survival.generators.player_generator import PlayerGenerator from survival.generators.resource_generator import ResourceGenerator from survival.generators.world_generator import WorldGenerator -from survival.settings import SCREEN_WIDTH, SCREEN_HEIGHT +from survival.settings import SCREEN_WIDTH, SCREEN_HEIGHT, MUTATE_NETWORKS, LEARN from survival.systems.draw_system import DrawSystem +from survival.systems.neural_system import NeuralSystem class Game: @@ -15,10 +17,17 @@ class Game: self.world_generator = WorldGenerator(win, self.reset) self.game_map, self.world, self.camera = self.world_generator.create_world() self.run = True + if LEARN and MUTATE_NETWORKS: + self.genetic_algorithm = GeneticAlgorithm(self.world.get_processor(NeuralSystem), self.finish_training) def reset(self): + if LEARN and MUTATE_NETWORKS: + self.genetic_algorithm.train() self.world_generator.reset_world() + def finish_training(self): + self.run = False + def update(self, ms): events = pygame.event.get() @@ -46,4 +55,4 @@ if __name__ == '__main__': game = Game() while game.run: - game.update(clock.tick(60)) + game.update(clock.tick(500)) diff --git a/survival/ai/genetic_algorithm.py b/survival/ai/genetic_algorithm.py new file mode 100644 index 0000000..3944642 --- /dev/null +++ b/survival/ai/genetic_algorithm.py @@ -0,0 +1,84 @@ +import sys + +from survival.ai.model import LinearQNetwork +from survival.ai.optimizer import Optimizer + + +class GeneticAlgorithm: + GAMES_PER_NETWORK = 40 + PLOTS_COUNTER = 0 + CURRENT_GENERATION = 1 + + def __init__(self, neural_system, callback): + self.callback = callback + self.logs_file = open('genetic_logs.txt', 'w') + self.original_stdout = sys.stdout + sys.stdout = self.logs_file + self.neural_system = neural_system + self.generations = 20 + self.population = 10 # Minimum 5 needed to allow breeding + self.nn_params = { + 'neurons': [128, 192, 256, 384, 512], + 'layers': [0, 1, 2, 3], + 'activation': ['relu', 'elu', 'tanh'], + 'ratio': [0.0007, 0.0009, 0.0011, 0.0013, 0.0015], + 'optimizer': ['RMSprop', 'Adam', 'SGD', 'Adagrad', 'Adadelta'], + } + self.optimizer = Optimizer(self.nn_params) + self.networks: list[LinearQNetwork] = self.optimizer.create_population(self.population) + self.finished = False + self.trained_counter = 0 + self.iterations = 0 + self.trained_generations = 0 + print('Started generation 1...') + self.change_network(self.networks[0]) + + def train(self): + if self.iterations < GeneticAlgorithm.GAMES_PER_NETWORK - 1: + self.iterations += 1 + return + self.iterations = 0 + print(f'Network score: {self.optimizer.fitness(self.networks[self.trained_counter])}') + self.trained_counter += 1 + + # If all networks in current population were trained + if self.trained_counter >= self.population: + # Get average score in current population + avg_score = self.calculate_average_score(self.networks) + print(f'Average population score: {avg_score}.') + + results_file = open('genetic_results.txt', 'w') + for network in self.networks: + results_file.write( + f'Network {network.id} params {network.network_params}. Avg score = {sum(network.scores) / len(network.scores)}\n') + results_file.close() + if self.trained_generations >= self.generations - 1: + # Sort the final population + self.networks = sorted(self.networks, key=lambda x: sum(x.scores) / len(x.scores), reverse=True) + self.finished = True + self.logs_file.close() + sys.stdout = self.original_stdout + self.callback() + return + + self.trained_generations += 1 + GeneticAlgorithm.CURRENT_GENERATION = self.trained_generations + 1 + print(f'Started generation {GeneticAlgorithm.CURRENT_GENERATION}...') + self.networks = self.optimizer.evolve(self.networks) + self.trained_counter = 0 + self.change_network(self.networks[self.trained_counter]) + + def calculate_average_score(self, networks): + sums = 0 + lengths = 0 + for network in networks: + sums += self.optimizer.fitness(network) + lengths += len(network.scores) + return sums / lengths + + def change_network(self, net): + GeneticAlgorithm.PLOTS_COUNTER += 1 + print(f"Changed network to {GeneticAlgorithm.PLOTS_COUNTER} {net.network_params}") + self.logs_file.flush() + net.id = GeneticAlgorithm.PLOTS_COUNTER + self.neural_system.load_model(net) diff --git a/survival/ai/learning_utils.py b/survival/ai/learning_utils.py index f08687f..3ebecbf 100644 --- a/survival/ai/learning_utils.py +++ b/survival/ai/learning_utils.py @@ -2,6 +2,8 @@ import numpy as np from IPython import display from matplotlib import pyplot as plt +from survival.settings import MUTATE_NETWORKS +from survival.ai.genetic_algorithm import GeneticAlgorithm from survival.components.learning_component import LearningComponent from survival.components.position_component import PositionComponent from survival.game.enums import Direction @@ -23,8 +25,8 @@ class LearningUtils: self.plot_mean_scores.append(mean_score) def plot(self): - display.clear_output(wait=True) - display.display(plt.gcf()) + # display.clear_output(wait=True) + # display.display(plt.gcf()) plt.clf() plt.title('Results') plt.xlabel('Number of Games') @@ -35,9 +37,12 @@ class LearningUtils: plt.text(len(self.plot_scores) - 1, self.plot_scores[-1], str(self.plot_scores[-1])) plt.text(len(self.plot_mean_scores) - 1, self.plot_mean_scores[-1], str(self.plot_mean_scores[-1])) self.plots += 1 - plt.savefig(f'model/plots/{self.plots}.png') + if MUTATE_NETWORKS: + plt.savefig(f'model/plots/{GeneticAlgorithm.PLOTS_COUNTER}_{self.plots}.png') + else: + plt.savefig(f'model/plots/{self.plots}.png') plt.show(block=False) - plt.pause(.1) + # plt.pause(.1) def append_action(self, action: Action, pos: PositionComponent): self.last_actions.append([action, pos.grid_position]) diff --git a/survival/ai/model.py b/survival/ai/model.py index cde9715..6a86ed4 100644 --- a/survival/ai/model.py +++ b/survival/ai/model.py @@ -1,4 +1,5 @@ import os +import random import torch from torch import nn, optim @@ -6,16 +7,46 @@ import torch.nn.functional as functional class LinearQNetwork(nn.Module): - def __init__(self, input_size, hidden_size, output_size, pretrained=False): + TORCH_ACTiVATIONS = 'tanh' + + def __init__(self, nn_params, input_size, output_size, randomize=True, params=None): super().__init__() - self.linear_one = nn.Linear(input_size, hidden_size) - self.linear_two = nn.Linear(hidden_size, output_size) - self.pretrained = pretrained + self.id = 0 + if params is None: + params = {} + self.params_choice = nn_params + self.scores = [] + self.network_params = params + if randomize: + self.randomize() + self.layers = nn.ModuleList() + if self.network_params['layers'] == 0: + self.layers.append(nn.Linear(input_size, output_size)) + else: + self.layers.append(nn.Linear(input_size, self.network_params['neurons'])) + + for i in range(self.network_params['layers'] - 1): + self.layers.append(nn.Linear(self.network_params['neurons'], self.network_params['neurons'])) + if self.network_params['layers'] > 0: + self.ending_linear = nn.Linear(self.network_params['neurons'], output_size) + self.layers.append(self.ending_linear) + + if self.network_params['activation'] in self.TORCH_ACTiVATIONS: + self.forward_func = getattr(torch, self.network_params['activation']) + else: + self.forward_func = getattr(functional, self.network_params['activation']) + + def randomize(self): + """ + Sets random parameters for network. + """ + for key in self.params_choice: + self.network_params[key] = random.choice(self.params_choice[key]) def forward(self, x): - x = functional.relu(self.linear_one(x)) - x = self.linear_two(x) - + for i in range(len(self.layers) - 1): + x = self.forward_func(self.layers[i](x)) + x = self.layers[-1](x) return x def save(self, file_name='model.pth'): @@ -27,24 +58,25 @@ class LinearQNetwork(nn.Module): torch.save(self.state_dict(), file_path) @staticmethod - def load(input_size, hidden_size, output_size, file_name='model.pth'): + def load(params, input_size, output_size, file_name='model.pth'): model_directory = 'model' file_path = os.path.join(model_directory, file_name) if os.path.isfile(file_path): - model = LinearQNetwork(input_size, hidden_size, output_size, True) + model = LinearQNetwork(params, input_size, output_size, True) model.load_state_dict(torch.load(file_path)) model.eval() return model - return LinearQNetwork(11, 256, 3) + raise Exception(f'Could not find file {file_path}.') class QTrainer: - def __init__(self, model, lr, gamma): + def __init__(self, model, lr, gamma, optimizer): self.model = model self.lr = lr self.gamma = gamma - self.optimizer = optim.Adam(model.parameters(), lr=self.lr) - self.criterion = nn.MSELoss() # Mean squared error + self.optimizer = getattr(optim, optimizer)(model.parameters(), lr=self.lr) + # self.optimizer = optim.Adam(model.parameters(), lr=self.lr) + self.criterion = nn.MSELoss() # Mean squared error def train_step(self, state, action, reward, next_state, done): state = torch.tensor(state, dtype=torch.float) diff --git a/survival/ai/optimizer.py b/survival/ai/optimizer.py new file mode 100644 index 0000000..10a52fe --- /dev/null +++ b/survival/ai/optimizer.py @@ -0,0 +1,147 @@ +from functools import reduce +from operator import add +import random +from typing import List + +from survival.ai.model import LinearQNetwork +from survival.settings import NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE + + +class Optimizer: + def __init__(self, params, retain=0.4, + random_select=0.1, mutation_chance=0.2): + self.mutation_chance = mutation_chance + self.random_select = random_select + self.retain = retain + self.nn_params = params + + def create_population(self, count: int): + """ + Creates 'count' networks from random parameters. + :param count: + :return: + """ + pop = [] + for _ in range(0, count): + # Create a random network. + network = LinearQNetwork(self.nn_params, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE) + # Add network to the population. + pop.append(network) + + return pop + + @staticmethod + def fitness(network: LinearQNetwork): + return sum(network.scores) / len(network.scores) + + def grade(self, pop: List[LinearQNetwork]) -> float: + """ + Finds average fitness for given population. + """ + summed = reduce(add, (self.fitness(network) for network in pop)) + return summed / float((len(pop))) + + def breed(self, parent_one, parent_two): + """ + Creates a new network from given parents. + :param parent_one: + :param parent_two: + :return: + """ + children = [] + for _ in range(2): + + child = {} + + # Loop through the parameters and pick params for the kid. + for param in self.nn_params: + child[param] = random.choice( + [parent_one.network_params[param], parent_two.network_params[param]] + ) + + # Create new network object. + network = LinearQNetwork(self.nn_params, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE) + network.network_params = child + + children.append(network) + + return children + + def mutate(self, network: LinearQNetwork): + """ + Randomly mutates one parameter of the given network. + :param network: + :return: + """ + mutation = random.choice(list(self.nn_params.keys())) + + # Mutate one of the params. + network.network_params[mutation] = random.choice(self.nn_params[mutation]) + + return network + + def evolve(self, pop): + """ + Evolves a population of networks. + """ + # Get scores for each network. + scores = [(self.fitness(network), network) for network in pop] + + # Sort the scores. + scores = [x[1] for x in sorted(scores, key=lambda x: x[0], reverse=True)] + + # Get the number we want to keep for the next gen. + retain_length = int(len(scores) * self.retain) + + # Keep the best networks as parents for next generation. + parents = scores[:retain_length] + + # Keep some other networks + for network in scores[retain_length:]: + if self.random_select > random.random(): + parents.append(network) + + # Reset kept networks + reseted_networks = [] + for network in parents: + net = LinearQNetwork(self.nn_params, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE) + net.network_params = network.network_params + reseted_networks.append(net) + + parents = reseted_networks + + # Randomly mutate some of the networks. + for parent in parents: + if self.mutation_chance > random.random(): + parent = self.mutate(parent) + + # Determine the number of freed spots for the next generation. + parents_length = len(parents) + desired_length = len(pop) - parents_length + children = [] + + # Fill missing spots with new children. + while len(children) < desired_length: + # Get random parents. + p1 = random.randint(0, parents_length - 1) + p2 = random.randint(0, parents_length - 1) + + # Ensure they are not the same network. + if p1 != p2: + p1 = parents[p1] + p2 = parents[p2] + + # Breed networks. + babies = self.breed(p1, p2) + + # Add children one at a time. + for baby in babies: + # Don't grow larger than the desired length. + if len(children) < desired_length: + children.append(baby) + + # parents_params = [n.network_params for n in parents] + # children_params = [n.network_params for n in children] + # parents_params.extend(children_params) + parents.extend(children) + return parents diff --git a/survival/ai/test.py b/survival/ai/test.py new file mode 100644 index 0000000..60978b0 --- /dev/null +++ b/survival/ai/test.py @@ -0,0 +1,93 @@ +import torch +import pygad +from pygad.torchga import torchga + + +def fitness_func(solution, sol_idx): + global data_inputs, data_outputs, torch_ga, model, loss_function + + model_weights_dict = torchga.model_weights_as_dict(model=model, + weights_vector=solution) + + # Use the current solution as the model parameters. + model.load_state_dict(model_weights_dict) + + predictions = model(data_inputs) + abs_error = loss_function(predictions, data_outputs).detach().numpy() + 0.00000001 + + solution_fitness = 1.0 / abs_error + + return solution_fitness + +def callback_generation(ga_instance): + print("Generation = {generation}".format(generation=ga_instance.generations_completed)) + print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) + +# Create the PyTorch model. +input_layer = torch.nn.Linear(3, 2) +relu_layer = torch.nn.ReLU() +output_layer = torch.nn.Linear(2, 1) + +model = torch.nn.Sequential(input_layer, + relu_layer, + output_layer) +# print(model) + +# Create an instance of the pygad.torchga.TorchGA class to build the initial population. +torch_ga = torchga.TorchGA(model=model, + num_solutions=10) + +loss_function = torch.nn.L1Loss() + +# Data inputs +data_inputs = torch.tensor([[0.02, 0.1, 0.15], + [0.7, 0.6, 0.8], + [1.5, 1.2, 1.7], + [3.2, 2.9, 3.1]]) + +# Data outputs +data_outputs = torch.tensor([[0.1], + [0.6], + [1.3], + [2.5]]) + +# Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class +num_generations = 250 # Number of generations. +num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. +initial_population = torch_ga.population_weights # Initial population of network weights +parent_selection_type = "sss" # Type of parent selection. +crossover_type = "single_point" # Type of the crossover operator. +mutation_type = "random" # Type of the mutation operator. +mutation_percent_genes = 10 # Percentage of genes to mutate. This parameter has no action if the parameter mutation_num_genes exists. +keep_parents = -1 # Number of parents to keep in the next population. -1 means keep all parents and 0 means keep nothing. + +ga_instance = pygad.GA(num_generations=num_generations, + num_parents_mating=num_parents_mating, + initial_population=initial_population, + fitness_func=fitness_func, + parent_selection_type=parent_selection_type, + crossover_type=crossover_type, + mutation_type=mutation_type, + mutation_percent_genes=mutation_percent_genes, + keep_parents=keep_parents, + on_generation=callback_generation) + +ga_instance.run() + +# After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. +ga_instance.plot_result(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) + +# Returning the details of the best solution. +solution, solution_fitness, solution_idx = ga_instance.best_solution() +print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) +print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) + +# Fetch the parameters of the best solution. +best_solution_weights = torchga.model_weights_as_dict(model=model, + weights_vector=solution) +model.load_state_dict(best_solution_weights) +predictions = model(data_inputs) +print("Predictions : \n", predictions.detach().numpy()) + +abs_error = loss_function(predictions, data_outputs) +print("Absolute Error : ", abs_error.detach().numpy()) \ No newline at end of file diff --git a/survival/game/user_interface.py b/survival/game/user_interface.py index 9887f9f..937598d 100644 --- a/survival/game/user_interface.py +++ b/survival/game/user_interface.py @@ -1,6 +1,7 @@ import pygame.font -from survival import settings +from survival.ai.genetic_algorithm import GeneticAlgorithm +from survival.settings import MUTATE_NETWORKS, SCREEN_HEIGHT, SCREEN_WIDTH from survival.components.inventory_component import InventoryComponent from survival.generators.resource_type import ResourceType from survival.game.image import Image @@ -8,8 +9,8 @@ from survival.game.image import Image class UserInterface: def __init__(self, window): - self.width = settings.SCREEN_WIDTH - self.height = settings.SCREEN_HEIGHT + self.width = SCREEN_WIDTH + self.height = SCREEN_HEIGHT self.window = window self.pos = (self.width - 240, 50) self.scale = 2 @@ -44,4 +45,3 @@ class UserInterface: textsurface = self.font.render(str(items_count), False, (255, 255, 255)) self.window.blit(textsurface, (image.pos[0] + 48, image.pos[1] + 36)) - diff --git a/survival/generators/tile_generator.py b/survival/generators/tile_generator.py index b505e55..a8b2a31 100644 --- a/survival/generators/tile_generator.py +++ b/survival/generators/tile_generator.py @@ -46,7 +46,8 @@ class TileGenerator: @staticmethod def generate_biome_tiles(width: int, height: int): - seed = random.randint(1, 10) + # Use static seed to allow smooth learning of genetic algorithm + seed = 1 octaves = 10 file_name = f'seeds/{seed}.bin' biomes_file = Path(file_name) diff --git a/survival/generators/world_generator.py b/survival/generators/world_generator.py index 736039b..098544a 100644 --- a/survival/generators/world_generator.py +++ b/survival/generators/world_generator.py @@ -1,4 +1,7 @@ +from pathlib import Path + from survival import esper, ResourceGenerator, PlayerGenerator +from survival.ai.model import LinearQNetwork from survival.components.consumption_component import ConsumptionComponent from survival.components.direction_component import DirectionChangeComponent from survival.components.inventory_component import InventoryComponent @@ -12,7 +15,8 @@ from survival.esper import World from survival.game.camera import Camera from survival.game.game_map import GameMap from survival.generators.resource_type import ResourceType -from survival.settings import PLAYER_START_POSITION, STARTING_RESOURCES_AMOUNT, SCREEN_WIDTH, SCREEN_HEIGHT +from survival.settings import PLAYER_START_POSITION, STARTING_RESOURCES_AMOUNT, SCREEN_WIDTH, SCREEN_HEIGHT, \ + MUTATE_NETWORKS, NETWORK_PARAMS, NEURAL_OUTPUT_SIZE, NEURAL_INPUT_SIZE from survival.systems.automation_system import AutomationSystem from survival.systems.camera_system import CameraSystem from survival.systems.collision_system import CollisionSystem @@ -42,6 +46,14 @@ class WorldGenerator: self.world.add_processor(MovementSystem(self.game_map), priority=20) self.world.add_processor(CollisionSystem(self.game_map), priority=30) self.world.add_processor(NeuralSystem(self.game_map, self.callback), priority=50) + if not MUTATE_NETWORKS: + model_path = Path("/model/model.pth") + if model_path.is_file(): + self.world.get_processor(NeuralSystem).load_model( + LinearQNetwork.load(NETWORK_PARAMS, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE)) + else: + self.world.get_processor(NeuralSystem).load_model( + LinearQNetwork(NETWORK_PARAMS, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE, False, NETWORK_PARAMS)) self.world.add_processor(DrawSystem(self.camera)) self.world.add_processor(TimeSystem()) self.world.add_processor(AutomationSystem(self.game_map)) @@ -51,8 +63,8 @@ class WorldGenerator: self.world.add_processor(VisionSystem(self.camera)) self.player = PlayerGenerator().create_player(self.world, self.game_map) - # self.world.get_processor(DrawSystem).initialize_interface( - # self.world.component_for_entity(self.player, InventoryComponent)) + self.world.get_processor(DrawSystem).initialize_interface( + self.world.component_for_entity(self.player, InventoryComponent)) # BuildingGenerator().create_home(self.world, self.game_map) self.resource_generator.generate_resources(self.player) diff --git a/survival/settings.py b/survival/settings.py index 74dc94a..15c15c8 100644 --- a/survival/settings.py +++ b/survival/settings.py @@ -1,7 +1,18 @@ SCREEN_WIDTH = 1000 SCREEN_HEIGHT = 600 -RESOURCES_AMOUNT = 100 +RESOURCES_AMOUNT = 175 DIRECTION_CHANGE_DELAY = 5 PLAYER_START_POSITION = [20, 10] -STARTING_RESOURCES_AMOUNT = 10 +STARTING_RESOURCES_AMOUNT = 5 AGENT_VISION_RANGE = 5 +NEURAL_INPUT_SIZE = 11 +NEURAL_OUTPUT_SIZE = 3 +LEARN = True +MUTATE_NETWORKS = True +NETWORK_PARAMS = { + "neurons": 256, + "layers": 1, + "activation": 'relu', + "ratio": 0.001, + "optimizer": 'Adam' +} diff --git a/survival/systems/collision_system.py b/survival/systems/collision_system.py index 5ae7ec0..baead60 100644 --- a/survival/systems/collision_system.py +++ b/survival/systems/collision_system.py @@ -23,6 +23,7 @@ class CollisionSystem(esper.Processor): moving.target = tuple(map(operator.add, vector, pos.grid_position)) moving.direction_vector = vector if self.check_collision(moving.target): + self.world.add_component(ent, ConsumeComponent(0.05)) self.world.remove_component(ent, MovingComponent) onCol.call_all() colliding_object: int = self.map.get_entity(moving.target) diff --git a/survival/systems/neural_system.py b/survival/systems/neural_system.py index a7d68b0..3a767aa 100644 --- a/survival/systems/neural_system.py +++ b/survival/systems/neural_system.py @@ -4,6 +4,7 @@ from collections import deque import torch from survival import esper, GameMap +from survival.ai.genetic_algorithm import GeneticAlgorithm from survival.components.direction_component import DirectionChangeComponent from survival.components.inventory_component import InventoryComponent from survival.components.moving_component import MovingComponent @@ -13,11 +14,10 @@ from survival.components.time_component import TimeComponent from survival.ai.graph_search import Action from survival.ai.learning_utils import get_state, LearningUtils from survival.ai.model import LinearQNetwork, QTrainer +from survival.settings import LEARN, MUTATE_NETWORKS MAX_MEMORY = 100_000 BATCH_SIZE = 1000 -LR = 0.001 -LEARN = False class NeuralSystem(esper.Processor): @@ -25,17 +25,27 @@ class NeuralSystem(esper.Processor): self.game_map = game_map self.reset_game = callback self.n_games = 0 # number of games played - self.starting_epsilon = 100 + if MUTATE_NETWORKS: + self.starting_epsilon = GeneticAlgorithm.GAMES_PER_NETWORK / 2 + else: + self.starting_epsilon = 100 self.epsilon = 0 # controlls the randomness self.gamma = 0.9 # discount rate self.memory = deque(maxlen=MAX_MEMORY) # exceeding memory removes the left elements to make more space - self.model = LinearQNetwork.load(11, 256, 3) - if self.model.pretrained: - self.starting_epsilon = -1 - self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma) + self.model = None # self.model = LinearQNetwork.load(11, 256, 3) + self.trainer = None # QTrainer(self.model, lr=LR, gamma=self.gamma) self.utils = LearningUtils() self.best_action = None + def load_model(self, model: LinearQNetwork): + self.model = model + self.trainer = QTrainer(self.model, self.model.network_params['ratio'], self.gamma, + self.model.network_params['optimizer']) + self.utils = LearningUtils() + self.memory = deque(maxlen=MAX_MEMORY) + self.starting_epsilon = GeneticAlgorithm.GAMES_PER_NETWORK / 2 + self.n_games = 0 + def remember(self, state, action, reward, next_state, done): self.memory.append((state, action, reward, next_state, done)) @@ -119,11 +129,13 @@ class NeuralSystem(esper.Processor): self.train_long_memory() if learning.score > learning.record: learning.record = learning.score - if LEARN: + if LEARN and not MUTATE_NETWORKS: self.model.save() - print('Game', self.n_games, 'Score', learning.score, 'Record', learning.record) + # print('Game', self.n_games, 'Score', learning.score, 'Record', learning.record) self.utils.add_scores(learning, self.n_games) + + self.model.scores.append(learning.score) learning.score = 0 self.utils.plot()