use pickle to save decision_tree, adjust code to use saved obj
This commit is contained in:
parent
471e5ee6d0
commit
4a942dd29e
12
src/const.py
12
src/const.py
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import pygame as pg
|
import pygame as pg
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ for name, val in {
|
|||||||
val,
|
val,
|
||||||
name,
|
name,
|
||||||
pg.image.load(
|
pg.image.load(
|
||||||
main_path + '/images/Tiles/' + folder_name + ' Tiles/' + name + ".png"
|
f'{main_path}/images/Tiles/{folder_name} Tiles/{name}.png'
|
||||||
),
|
),
|
||||||
None if 'mine' not in name else (
|
None if 'mine' not in name else (
|
||||||
'AP' if 'AP' in name else (
|
'AP' if 'AP' in name else (
|
||||||
@ -64,7 +65,7 @@ for name, val in {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
ICON = main_path + '/images/mine_icon.png'
|
ICON = f'{main_path}/images/mine_icon.png'
|
||||||
|
|
||||||
SAPPER_IDLE = []
|
SAPPER_IDLE = []
|
||||||
for name in [
|
for name in [
|
||||||
@ -74,7 +75,7 @@ for name in [
|
|||||||
'agent3',
|
'agent3',
|
||||||
]:
|
]:
|
||||||
SAPPER_IDLE.append(pg.image.load(
|
SAPPER_IDLE.append(pg.image.load(
|
||||||
main_path + '/images/agent/agent_idle/' + name + '.png'
|
f'{main_path}/images/agent/agent_idle/{name}.png'
|
||||||
))
|
))
|
||||||
|
|
||||||
SAPPER_KABOOM = []
|
SAPPER_KABOOM = []
|
||||||
@ -85,7 +86,10 @@ for name in [
|
|||||||
'agent_kaboom_3',
|
'agent_kaboom_3',
|
||||||
]:
|
]:
|
||||||
SAPPER_KABOOM.append(pg.image.load(
|
SAPPER_KABOOM.append(pg.image.load(
|
||||||
main_path + '/images/agent/agent_kaboom/' + name + '.png'
|
f'{main_path}/images/agent/agent_kaboom/{name}.png'
|
||||||
))
|
))
|
||||||
|
|
||||||
ROCK_INDEXES = (5, 6, 23, 24)
|
ROCK_INDEXES = (5, 6, 23, 24)
|
||||||
|
|
||||||
|
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file:
|
||||||
|
TREE_ROOT = pickle.load(file)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import pickle
|
||||||
from math import log
|
from math import log
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
@ -70,6 +71,7 @@ def get_decision(data: dict, root: Node) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
tree_root = tree_learn(training_set, attribs, 'detonation')
|
||||||
print(RenderTree(tree_root))
|
print(RenderTree(tree_root))
|
||||||
print('-' * 150)
|
print('-' * 150)
|
||||||
|
|
||||||
@ -82,9 +84,8 @@ def main():
|
|||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
print(f'Accuracy: {score/len(test_set)}')
|
print(f'Accuracy: {score/len(test_set)}')
|
||||||
|
with open('decision_tree.pkl', 'wb') as file:
|
||||||
|
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
|
||||||
tree_root = tree_learn(training_set, attribs, 'detonation')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -2,13 +2,13 @@ import pygame as pg
|
|||||||
|
|
||||||
from agent import Agent
|
from agent import Agent
|
||||||
from game_ui import GameUi
|
from game_ui import GameUi
|
||||||
from const import ICON, IMAGES
|
|
||||||
from environment import Environment
|
from environment import Environment
|
||||||
from tilesFactory import TilesFactory
|
from tilesFactory import TilesFactory
|
||||||
|
from const import ICON, IMAGES, TREE_ROOT
|
||||||
from src.search_algoritms.a_star import a_star
|
from src.search_algoritms.a_star import a_star
|
||||||
from search_algoritms.BFS import breadth_first_search
|
from search_algoritms.BFS import breadth_first_search
|
||||||
|
from machine_learning.decision_tree import get_decision
|
||||||
from machine_learning.helpers import get_dataset_from_tile
|
from machine_learning.helpers import get_dataset_from_tile
|
||||||
from machine_learning.decision_tree import get_decision, tree_root
|
|
||||||
|
|
||||||
|
|
||||||
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
|
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
|
||||||
@ -33,7 +33,7 @@ def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesF
|
|||||||
tile = env.field[agent.y][agent.x]
|
tile = env.field[agent.y][agent.x]
|
||||||
|
|
||||||
dataset = get_dataset_from_tile(tile)
|
dataset = get_dataset_from_tile(tile)
|
||||||
decision = get_decision(dataset, tree_root)
|
decision = get_decision(dataset, TREE_ROOT)
|
||||||
|
|
||||||
print(f'Data: {dataset}')
|
print(f'Data: {dataset}')
|
||||||
print(f'Decision: {decision}')
|
print(f'Decision: {decision}')
|
||||||
|
Loading…
Reference in New Issue
Block a user