From 4a942dd29e0ed91ba5641ebf1e17bc2665bb6e40 Mon Sep 17 00:00:00 2001 From: matixezor Date: Wed, 19 May 2021 21:19:25 +0200 Subject: [PATCH 1/2] use pickle to save decision_tree, adjust code to use saved obj --- src/const.py | 12 ++++++++---- src/machine_learning/decision_tree.py | 7 ++++--- src/main.py | 6 +++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/const.py b/src/const.py index e60915b..9c4b00e 100644 --- a/src/const.py +++ b/src/const.py @@ -1,4 +1,5 @@ import os +import pickle import pygame as pg @@ -54,7 +55,7 @@ for name, val in { val, name, pg.image.load( - main_path + '/images/Tiles/' + folder_name + ' Tiles/' + name + ".png" + f'{main_path}/images/Tiles/{folder_name} Tiles/{name}.png' ), None if 'mine' not in name else ( 'AP' if 'AP' in name else ( @@ -64,7 +65,7 @@ for name, val in { ) ) -ICON = main_path + '/images/mine_icon.png' +ICON = f'{main_path}/images/mine_icon.png' SAPPER_IDLE = [] for name in [ @@ -74,7 +75,7 @@ for name in [ 'agent3', ]: SAPPER_IDLE.append(pg.image.load( - main_path + '/images/agent/agent_idle/' + name + '.png' + f'{main_path}/images/agent/agent_idle/{name}.png' )) SAPPER_KABOOM = [] @@ -85,7 +86,10 @@ for name in [ 'agent_kaboom_3', ]: SAPPER_KABOOM.append(pg.image.load( - main_path + '/images/agent/agent_kaboom/' + name + '.png' + f'{main_path}/images/agent/agent_kaboom/{name}.png' )) ROCK_INDEXES = (5, 6, 23, 24) + +with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file: + TREE_ROOT = pickle.load(file) diff --git a/src/machine_learning/decision_tree.py b/src/machine_learning/decision_tree.py index 5d93991..2da317d 100644 --- a/src/machine_learning/decision_tree.py +++ b/src/machine_learning/decision_tree.py @@ -1,3 +1,4 @@ +import pickle from math import log from typing import List, Dict from collections import Counter @@ -70,6 +71,7 @@ def get_decision(data: dict, root: Node) -> str: def main(): + tree_root = tree_learn(training_set, attribs, 'detonation') print(RenderTree(tree_root)) print('-' * 150) @@ -82,9 +84,8 @@ def main(): score += 1 print(f'Accuracy: {score/len(test_set)}') - - -tree_root = tree_learn(training_set, attribs, 'detonation') + with open('decision_tree.pkl', 'wb') as file: + pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL) if __name__ == "__main__": diff --git a/src/main.py b/src/main.py index ef3d03d..cce13d1 100644 --- a/src/main.py +++ b/src/main.py @@ -2,13 +2,13 @@ import pygame as pg from agent import Agent from game_ui import GameUi -from const import ICON, IMAGES from environment import Environment from tilesFactory import TilesFactory +from const import ICON, IMAGES, TREE_ROOT from src.search_algoritms.a_star import a_star from search_algoritms.BFS import breadth_first_search +from machine_learning.decision_tree import get_decision from machine_learning.helpers import get_dataset_from_tile -from machine_learning.decision_tree import get_decision, tree_root def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory): @@ -33,7 +33,7 @@ def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesF tile = env.field[agent.y][agent.x] dataset = get_dataset_from_tile(tile) - decision = get_decision(dataset, tree_root) + decision = get_decision(dataset, TREE_ROOT) print(f'Data: {dataset}') print(f'Decision: {decision}') -- 2.20.1 From f82e326e3d075d6c19b0dac44e81839813264d6b Mon Sep 17 00:00:00 2001 From: matixezor Date: Wed, 19 May 2021 21:19:56 +0200 Subject: [PATCH 2/2] add decision_tree pickle obj --- src/machine_learning/decision_tree.pkl | Bin 0 -> 1589 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/machine_learning/decision_tree.pkl diff --git a/src/machine_learning/decision_tree.pkl b/src/machine_learning/decision_tree.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3604db75e888a6084902ac981ea375c677a70e74 GIT binary patch literal 1589 zcmbVMu};G<6hv(r+Ag%HD}u3r=!DoSAu$ygSW%SJCN^rE%5i`WNFXtQ3X6jZANy7|~mh2?L59o}Fg(BOlr7prEX1<^?Q*zhFn)?7G z6zNC@qYxO|DbqMJyaY)Di=g{J6Iy|R^vHnpWi#;6I*1$cIAx|_Z*Q*%0DBrBN(%rA z8e1XvlVqK%6H$gev6ve%Hl2&`U7EoA9ytc8#`tD^d}q#alB67aq430GyzD3eIVJsf z?fR-=mhC_UJs+kA|I=QtPUO~Vj?X9AUV22E(nWl6ka=j6k(2zPTUZw*63N+yp{oj# zpvRJAXv!t&dV!?&G-gwd1qLs)OdTE