148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
from functools import reduce
|
|
from operator import add
|
|
import random
|
|
from typing import List
|
|
|
|
from survival.ai.model import LinearQNetwork
|
|
from survival.settings import NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE
|
|
|
|
|
|
class Optimizer:
    """Genetic-algorithm optimizer over the ``network_params`` of LinearQNetwork models.

    ``params`` maps every tunable parameter name to the sequence of values it may
    take. ``breed`` and ``mutate`` draw from those choices. Each new network is
    constructed with ``params`` as well; presumably LinearQNetwork picks its
    initial ``network_params`` from them (TODO confirm in survival.ai.model).
    """

    def __init__(self, params, retain=0.4,
                 random_select=0.1, mutation_chance=0.2):
        """
        :param params: mapping of parameter name -> sequence of allowed values.
        :param retain: fraction of the best-ranked networks kept as parents each generation.
        :param random_select: probability that each non-retained network is kept as a parent too.
        :param mutation_chance: probability that each kept parent gets one parameter re-drawn.
        """
        self.mutation_chance = mutation_chance
        self.random_select = random_select
        self.retain = retain
        self.nn_params = params

    def _new_network(self):
        """Construct a new LinearQNetwork with this optimizer's parameter choices."""
        return LinearQNetwork(self.nn_params, NEURAL_INPUT_SIZE, NEURAL_OUTPUT_SIZE)

    def create_population(self, count: int) -> List["LinearQNetwork"]:
        """
        Creates 'count' freshly constructed networks.

        :param count: number of networks to create; 0 or less yields an empty list.
        :return: the new population.
        """
        return [self._new_network() for _ in range(count)]

    @staticmethod
    def fitness(network: "LinearQNetwork") -> float:
        """
        Returns the mean of ``network.scores``.

        :raises ZeroDivisionError: if the network has no recorded scores.
        """
        return sum(network.scores) / len(network.scores)

    def grade(self, pop: List["LinearQNetwork"]) -> float:
        """
        Finds average fitness for given population.

        :param pop: networks to grade; each must have recorded scores.
        :return: mean of ``fitness`` over ``pop``.
        :raises ValueError: if ``pop`` is empty.
        """
        if not pop:
            raise ValueError("cannot grade an empty population")
        return sum(self.fitness(network) for network in pop) / len(pop)

    def breed(self, parent_one, parent_two):
        """
        Creates two new networks from given parents by uniform crossover.

        Each child independently takes every parameter listed in ``nn_params``
        from one of the two parents, chosen with equal probability.

        :param parent_one: first parent; must expose ``network_params``.
        :param parent_two: second parent; must expose ``network_params``.
        :return: list of two new networks.
        """
        children = []
        for _ in range(2):
            child_params = {
                param: random.choice(
                    [parent_one.network_params[param], parent_two.network_params[param]]
                )
                for param in self.nn_params
            }
            network = self._new_network()
            network.network_params = child_params
            children.append(network)
        return children

    def mutate(self, network: "LinearQNetwork"):
        """
        Re-draws one randomly chosen parameter of the given network, in place.

        :param network: network whose ``network_params`` is modified.
        :return: the same ``network`` object.
        """
        mutation = random.choice(list(self.nn_params))
        network.network_params[mutation] = random.choice(self.nn_params[mutation])
        return network

    def _select_parents(self, ranked):
        """Return the networks (from best-first ``ranked``) that survive as parents."""
        retain_length = int(len(ranked) * self.retain)
        # Keep the best networks...
        parents = ranked[:retain_length]
        # ...plus a random sprinkling of the rest, to preserve diversity.
        for network in ranked[retain_length:]:
            if self.random_select > random.random():
                parents.append(network)
        # Breeding needs two distinct parents. Without this top-up, a selection of
        # 0 or 1 networks made evolve() crash (empty randint range) or loop forever.
        if len(parents) < 2:
            kept = {id(network) for network in parents}
            spare = [network for network in ranked if id(network) not in kept]
            parents.extend(spare[:2 - len(parents)])
        return parents

    def _fresh_copy(self, network):
        """Return a new network carrying a copy of ``network``'s parameters."""
        clone = self._new_network()
        # Copy the dict so that mutating the clone cannot alter the previous generation.
        clone.network_params = dict(network.network_params)
        return clone

    def evolve(self, pop):
        """
        Evolves a population of networks into the next generation.

        Networks are ranked by fitness (best first). Survivors picked by
        ``_select_parents`` are re-created as new networks carrying copies of
        their parameters. Each survivor may be mutated. The remaining slots are
        filled with children bred from random pairs of distinct survivors.

        :param pop: current generation; every network must have recorded scores.
        :return: a population of ``len(pop)`` networks: survivors first, then children.
        """
        ranked = sorted(pop, key=self.fitness, reverse=True)
        parents = [self._fresh_copy(network) for network in self._select_parents(ranked)]

        # Randomly mutate some of the survivors (in place).
        for parent in parents:
            if self.mutation_chance > random.random():
                self.mutate(parent)

        # Fill the freed spots with children; _select_parents guarantees at least
        # two parents whenever there is a spot to fill.
        desired_length = len(pop) - len(parents)
        children = []
        while len(children) < desired_length:
            father, mother = random.sample(parents, 2)
            babies = self.breed(father, mother)
            # Don't grow larger than the desired length.
            children.extend(babies[:desired_length - len(children)])

        return parents + children
|