add exception handling when importing decision tree or neural network
This commit is contained in:
parent
048219b96f
commit
42315ab8d9
34
src/const.py
34
src/const.py
@ -2,8 +2,10 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import pygame as pg
|
import pygame as pg
|
||||||
|
import torch
|
||||||
|
|
||||||
from image import Image
|
from image import Image
|
||||||
|
from machine_learning.neural_network.net import Net
|
||||||
|
|
||||||
|
|
||||||
main_path = os.path.dirname(os.getcwd())
|
main_path = os.path.dirname(os.getcwd())
|
||||||
@ -91,5 +93,33 @@ for name in [
|
|||||||
|
|
||||||
ROCK_INDEXES = (5, 6, 23, 24)
|
ROCK_INDEXES = (5, 6, 23, 24)
|
||||||
|
|
||||||
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file:
|
try:
|
||||||
TREE_ROOT = pickle.load(file)
|
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
|
||||||
|
TREE_ROOT = pickle.load(file)
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print('Decision tree not detected\nInitializing...')
|
||||||
|
|
||||||
|
from machine_learning.decision_tree.decision_tree import main
|
||||||
|
main()
|
||||||
|
|
||||||
|
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
|
||||||
|
TREE_ROOT = pickle.load(file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
|
||||||
|
model = Net(num_classes=2)
|
||||||
|
model.load_state_dict(checkpoint)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print('Neural network not detected\nInitializing...')
|
||||||
|
|
||||||
|
from machine_learning.neural_network.learning import main
|
||||||
|
main()
|
||||||
|
|
||||||
|
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
|
||||||
|
|
||||||
|
model = Net(num_classes=2)
|
||||||
|
model.load_state_dict(checkpoint)
|
||||||
|
model.eval()
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
|
from machine_learning.decision_tree.helpers import path
|
||||||
|
|
||||||
data_path = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
visibility = ('bad', 'medium', 'good')
|
visibility = ('bad', 'medium', 'good')
|
||||||
stability = ('unstable', 'stable')
|
stability = ('unstable', 'stable')
|
||||||
@ -67,7 +66,7 @@ def main():
|
|||||||
json.dump(data_set, outfile)
|
json.dump(data_set, outfile)
|
||||||
|
|
||||||
|
|
||||||
with open(f'{data_path}/data.txt', 'r') as f:
|
with open(f'{path}/data.txt', 'r') as f:
|
||||||
file_data = f.read()
|
file_data = f.read()
|
||||||
json_data = json.loads(file_data)
|
json_data = json.loads(file_data)
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from collections import Counter
|
|||||||
|
|
||||||
from anytree import Node, RenderTree
|
from anytree import Node, RenderTree
|
||||||
|
|
||||||
|
from machine_learning.decision_tree.helpers import path
|
||||||
from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs
|
from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs
|
||||||
|
|
||||||
|
|
||||||
@ -84,7 +85,7 @@ def main():
|
|||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
print(f'Accuracy: {score/len(test_set)}')
|
print(f'Accuracy: {score/len(test_set)}')
|
||||||
with open('decision_tree.pkl', 'wb') as file:
|
with open(f'{path}/decision_tree.pkl', 'wb') as file:
|
||||||
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from tile import Tile
|
from tile import Tile
|
||||||
|
|
||||||
|
|
||||||
@ -10,3 +11,6 @@ def get_dataset_from_tile(tile: Tile):
|
|||||||
'mine_type': tile.mine.mine_type,
|
'mine_type': tile.mine.mine_type,
|
||||||
'pressure_gt_two': True if tile.mine.pressure > 2 else False
|
'pressure_gt_two': True if tile.mine.pressure > 2 else False
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
Loading…
Reference in New Issue
Block a user