use pickle to save decision_tree, adjust code to use saved obj

This commit is contained in:
matixezor 2021-05-19 21:19:25 +02:00
parent 471e5ee6d0
commit 4a942dd29e
3 changed files with 15 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import os import os
import pickle
import pygame as pg import pygame as pg
@ -54,7 +55,7 @@ for name, val in {
val, val,
name, name,
pg.image.load( pg.image.load(
main_path + '/images/Tiles/' + folder_name + ' Tiles/' + name + ".png" f'{main_path}/images/Tiles/{folder_name} Tiles/{name}.png'
), ),
None if 'mine' not in name else ( None if 'mine' not in name else (
'AP' if 'AP' in name else ( 'AP' if 'AP' in name else (
@ -64,7 +65,7 @@ for name, val in {
) )
) )
ICON = main_path + '/images/mine_icon.png' ICON = f'{main_path}/images/mine_icon.png'
SAPPER_IDLE = [] SAPPER_IDLE = []
for name in [ for name in [
@ -74,7 +75,7 @@ for name in [
'agent3', 'agent3',
]: ]:
SAPPER_IDLE.append(pg.image.load( SAPPER_IDLE.append(pg.image.load(
main_path + '/images/agent/agent_idle/' + name + '.png' f'{main_path}/images/agent/agent_idle/{name}.png'
)) ))
SAPPER_KABOOM = [] SAPPER_KABOOM = []
@ -85,7 +86,10 @@ for name in [
'agent_kaboom_3', 'agent_kaboom_3',
]: ]:
SAPPER_KABOOM.append(pg.image.load( SAPPER_KABOOM.append(pg.image.load(
main_path + '/images/agent/agent_kaboom/' + name + '.png' f'{main_path}/images/agent/agent_kaboom/{name}.png'
)) ))
ROCK_INDEXES = (5, 6, 23, 24) ROCK_INDEXES = (5, 6, 23, 24)
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file:
TREE_ROOT = pickle.load(file)

View File

@ -1,3 +1,4 @@
import pickle
from math import log from math import log
from typing import List, Dict from typing import List, Dict
from collections import Counter from collections import Counter
@ -70,6 +71,7 @@ def get_decision(data: dict, root: Node) -> str:
def main(): def main():
tree_root = tree_learn(training_set, attribs, 'detonation')
print(RenderTree(tree_root)) print(RenderTree(tree_root))
print('-' * 150) print('-' * 150)
@ -82,9 +84,8 @@ def main():
score += 1 score += 1
print(f'Accuracy: {score/len(test_set)}') print(f'Accuracy: {score/len(test_set)}')
with open('decision_tree.pkl', 'wb') as file:
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
tree_root = tree_learn(training_set, attribs, 'detonation')
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,13 +2,13 @@ import pygame as pg
from agent import Agent from agent import Agent
from game_ui import GameUi from game_ui import GameUi
from const import ICON, IMAGES
from environment import Environment from environment import Environment
from tilesFactory import TilesFactory from tilesFactory import TilesFactory
from const import ICON, IMAGES, TREE_ROOT
from src.search_algoritms.a_star import a_star from src.search_algoritms.a_star import a_star
from search_algoritms.BFS import breadth_first_search from search_algoritms.BFS import breadth_first_search
from machine_learning.decision_tree import get_decision
from machine_learning.helpers import get_dataset_from_tile from machine_learning.helpers import get_dataset_from_tile
from machine_learning.decision_tree import get_decision, tree_root
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory): def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
@ -33,7 +33,7 @@ def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesF
tile = env.field[agent.y][agent.x] tile = env.field[agent.y][agent.x]
dataset = get_dataset_from_tile(tile) dataset = get_dataset_from_tile(tile)
decision = get_decision(dataset, tree_root) decision = get_decision(dataset, TREE_ROOT)
print(f'Data: {dataset}') print(f'Data: {dataset}')
print(f'Decision: {decision}') print(f'Decision: {decision}')