AI-Project/survival/ai/genetic_algorithm.py
2021-06-21 12:20:25 +02:00

85 lines
3.4 KiB
Python

import sys
from survival.ai.model import LinearQNetwork
from survival.ai.optimizer import Optimizer
class GeneticAlgorithm:
    """Evaluate and evolve a population of networks, one network at a time.

    The object is driven from outside through :meth:`train`: every call counts
    as one finished evaluation round for the network that is currently loaded
    into ``neural_system``. After ``GAMES_PER_NETWORK`` calls the next network
    of the population is loaded; after the whole population has been evaluated
    the population is evolved through the ``Optimizer``. When the last
    generation is done, ``callback`` is invoked once.

    While a run is in progress ``sys.stdout`` is redirected to
    ``genetic_logs.txt``; it is restored when the run finishes or when
    construction fails.
    """

    # Number of train() calls each network receives before it is scored
    # (presumably one call per finished game -- confirm against the caller).
    GAMES_PER_NETWORK = 40
    # Running count of networks loaded so far; also assigned as the network id.
    PLOTS_COUNTER = 0
    # 1-based number of the generation currently being evaluated.
    CURRENT_GENERATION = 1

    def __init__(self, neural_system, callback):
        """Create the initial population and load its first network.

        Args:
            neural_system: Object exposing ``load_model(net)``; receives every
                network that is about to be evaluated.
            callback: Zero-argument callable invoked once, after the final
                generation has been evaluated.
        """
        self.callback = callback
        self.logs_file = open('genetic_logs.txt', 'w')
        self.original_stdout = sys.stdout
        # From here on every print() of the run ends up in the log file.
        sys.stdout = self.logs_file
        try:
            self.neural_system = neural_system
            self.generations = 20
            self.population = 10  # Minimum 5 needed to allow breeding
            # Candidate values handed to the Optimizer for building networks.
            self.nn_params = {
                'neurons': [128, 192, 256, 384, 512],
                'layers': [0, 1, 2, 3],
                'activation': ['relu', 'elu', 'tanh'],
                'ratio': [0.0007, 0.0009, 0.0011, 0.0013, 0.0015],
                'optimizer': ['RMSprop', 'Adam', 'SGD', 'Adagrad', 'Adadelta'],
            }
            self.optimizer = Optimizer(self.nn_params)
            self.networks: list[LinearQNetwork] = self.optimizer.create_population(self.population)
            self.finished = False
            self.trained_counter = 0       # index of the network being evaluated
            self.iterations = 0            # train() calls seen by that network
            self.trained_generations = 0   # generations fully evaluated so far
            print('Started generation 1...')
            self.change_network(self.networks[0])
        except BaseException:
            # Do not leave stdout pointing at the log file (or the file open)
            # when the population could not be set up.
            self._restore_stdout()
            raise

    def train(self):
        """Register one finished evaluation round for the current network.

        Switches to the next network once the current one has had
        ``GAMES_PER_NETWORK`` rounds, evolves the population once every
        network has been evaluated, and finishes the run (restoring stdout and
        invoking ``callback``) after the last generation. Calls made after the
        run has finished are ignored.
        """
        if self.finished:
            # The run is over; without this guard the network index below
            # would run past the end of the population.
            return

        if self.iterations < GeneticAlgorithm.GAMES_PER_NETWORK - 1:
            self.iterations += 1
            return

        self.iterations = 0
        print(f'Network score: {self.optimizer.fitness(self.networks[self.trained_counter])}')
        self.trained_counter += 1

        # If all networks in current population were trained
        if self.trained_counter >= self.population:
            # Get average score in current population
            avg_score = self.calculate_average_score(self.networks)
            print(f'Average population score: {avg_score}.')

            # Opened in 'w' mode, so the file always holds the results of the
            # most recently completed generation only.
            with open('genetic_results.txt', 'w') as results_file:
                for network in self.networks:
                    results_file.write(
                        f'Network {network.id} params {network.network_params}. Avg score = {self._mean_score(network)}\n')

            if self.trained_generations >= self.generations - 1:
                # Sort the final population
                self.networks = sorted(self.networks, key=self._mean_score, reverse=True)
                self.finished = True
                self._restore_stdout()
                self.callback()
                return

            self.trained_generations += 1
            GeneticAlgorithm.CURRENT_GENERATION = self.trained_generations + 1
            print(f'Started generation {GeneticAlgorithm.CURRENT_GENERATION}...')
            self.networks = self.optimizer.evolve(self.networks)
            self.trained_counter = 0

        self.change_network(self.networks[self.trained_counter])

    def calculate_average_score(self, networks):
        """Return the summed fitness of ``networks`` divided by their score count.

        Returns ``0.0`` when no scores have been recorded (including an empty
        ``networks`` sequence) instead of dividing by zero.

        NOTE(review): the denominator is the total number of recorded scores,
        so this is a per-score mean only if ``Optimizer.fitness`` returns a
        *sum* of scores. If it returns a per-network mean, the divisor should
        be ``len(networks)`` -- confirm against ``Optimizer.fitness``.
        """
        total_fitness = 0
        total_scores = 0
        for network in networks:
            total_fitness += self.optimizer.fitness(network)
            total_scores += len(network.scores)
        if total_scores == 0:
            return 0.0
        return total_fitness / total_scores

    def change_network(self, net):
        """Give ``net`` the next sequential id and load it into the neural system."""
        GeneticAlgorithm.PLOTS_COUNTER += 1
        print(f"Changed network to {GeneticAlgorithm.PLOTS_COUNTER} {net.network_params}")
        # Push the log line to disk right away so progress is visible while
        # the (long) run is still going.
        self.logs_file.flush()
        net.id = GeneticAlgorithm.PLOTS_COUNTER
        self.neural_system.load_model(net)

    @staticmethod
    def _mean_score(network):
        """Return the arithmetic mean of ``network.scores``."""
        return sum(network.scores) / len(network.scores)

    def _restore_stdout(self):
        """Close the log file and point ``sys.stdout`` back at the original stream."""
        self.logs_file.close()
        sys.stdout = self.original_stdout