use pickle to save decision_tree, adjust code to use saved obj
This commit is contained in:
parent
471e5ee6d0
commit
4a942dd29e
12
src/const.py
12
src/const.py
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import pygame as pg
|
||||
|
||||
@ -54,7 +55,7 @@ for name, val in {
|
||||
val,
|
||||
name,
|
||||
pg.image.load(
|
||||
main_path + '/images/Tiles/' + folder_name + ' Tiles/' + name + ".png"
|
||||
f'{main_path}/images/Tiles/{folder_name} Tiles/{name}.png'
|
||||
),
|
||||
None if 'mine' not in name else (
|
||||
'AP' if 'AP' in name else (
|
||||
@ -64,7 +65,7 @@ for name, val in {
|
||||
)
|
||||
)
|
||||
|
||||
ICON = main_path + '/images/mine_icon.png'
|
||||
ICON = f'{main_path}/images/mine_icon.png'
|
||||
|
||||
SAPPER_IDLE = []
|
||||
for name in [
|
||||
@ -74,7 +75,7 @@ for name in [
|
||||
'agent3',
|
||||
]:
|
||||
SAPPER_IDLE.append(pg.image.load(
|
||||
main_path + '/images/agent/agent_idle/' + name + '.png'
|
||||
f'{main_path}/images/agent/agent_idle/{name}.png'
|
||||
))
|
||||
|
||||
SAPPER_KABOOM = []
|
||||
@ -85,7 +86,10 @@ for name in [
|
||||
'agent_kaboom_3',
|
||||
]:
|
||||
SAPPER_KABOOM.append(pg.image.load(
|
||||
main_path + '/images/agent/agent_kaboom/' + name + '.png'
|
||||
f'{main_path}/images/agent/agent_kaboom/{name}.png'
|
||||
))
|
||||
|
||||
ROCK_INDEXES = (5, 6, 23, 24)
|
||||
|
||||
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file:
|
||||
TREE_ROOT = pickle.load(file)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import pickle
|
||||
from math import log
|
||||
from typing import List, Dict
|
||||
from collections import Counter
|
||||
@ -70,6 +71,7 @@ def get_decision(data: dict, root: Node) -> str:
|
||||
|
||||
|
||||
def main():
|
||||
tree_root = tree_learn(training_set, attribs, 'detonation')
|
||||
print(RenderTree(tree_root))
|
||||
print('-' * 150)
|
||||
|
||||
@ -82,9 +84,8 @@ def main():
|
||||
score += 1
|
||||
|
||||
print(f'Accuracy: {score/len(test_set)}')
|
||||
|
||||
|
||||
tree_root = tree_learn(training_set, attribs, 'detonation')
|
||||
with open('decision_tree.pkl', 'wb') as file:
|
||||
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,13 +2,13 @@ import pygame as pg
|
||||
|
||||
from agent import Agent
|
||||
from game_ui import GameUi
|
||||
from const import ICON, IMAGES
|
||||
from environment import Environment
|
||||
from tilesFactory import TilesFactory
|
||||
from const import ICON, IMAGES, TREE_ROOT
|
||||
from src.search_algoritms.a_star import a_star
|
||||
from search_algoritms.BFS import breadth_first_search
|
||||
from machine_learning.decision_tree import get_decision
|
||||
from machine_learning.helpers import get_dataset_from_tile
|
||||
from machine_learning.decision_tree import get_decision, tree_root
|
||||
|
||||
|
||||
def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesFactory):
|
||||
@ -33,7 +33,7 @@ def handle_keys(env: Environment, agent: Agent, game_ui: GameUi, factory: TilesF
|
||||
tile = env.field[agent.y][agent.x]
|
||||
|
||||
dataset = get_dataset_from_tile(tile)
|
||||
decision = get_decision(dataset, tree_root)
|
||||
decision = get_decision(dataset, TREE_ROOT)
|
||||
|
||||
print(f'Data: {dataset}')
|
||||
print(f'Decision: {decision}')
|
||||
|
Loading…
Reference in New Issue
Block a user