add exception handling when importing decision tree or neural network
This commit is contained in:
parent
048219b96f
commit
42315ab8d9
32
src/const.py
32
src/const.py
@ -2,8 +2,10 @@ import os
|
||||
import pickle
|
||||
|
||||
import pygame as pg
|
||||
import torch
|
||||
|
||||
from image import Image
|
||||
from machine_learning.neural_network.net import Net
|
||||
|
||||
|
||||
main_path = os.path.dirname(os.getcwd())
|
||||
@ -91,5 +93,33 @@ for name in [
|
||||
|
||||
ROCK_INDEXES = (5, 6, 23, 24)
|
||||
|
||||
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file:
|
||||
try:
|
||||
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
|
||||
TREE_ROOT = pickle.load(file)
|
||||
|
||||
except FileNotFoundError:
|
||||
print('Decision tree not detected\nInitializing...')
|
||||
|
||||
from machine_learning.decision_tree.decision_tree import main
|
||||
main()
|
||||
|
||||
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
|
||||
TREE_ROOT = pickle.load(file)
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
|
||||
model = Net(num_classes=2)
|
||||
model.load_state_dict(checkpoint)
|
||||
model.eval()
|
||||
|
||||
except FileNotFoundError:
|
||||
print('Neural network not detected\nInitializing...')
|
||||
|
||||
from machine_learning.neural_network.learning import main
|
||||
main()
|
||||
|
||||
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
|
||||
|
||||
model = Net(num_classes=2)
|
||||
model.load_state_dict(checkpoint)
|
||||
model.eval()
|
||||
|
@ -1,10 +1,9 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List
|
||||
from itertools import product
|
||||
|
||||
from machine_learning.decision_tree.helpers import path
|
||||
|
||||
data_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
visibility = ('bad', 'medium', 'good')
|
||||
stability = ('unstable', 'stable')
|
||||
@ -67,7 +66,7 @@ def main():
|
||||
json.dump(data_set, outfile)
|
||||
|
||||
|
||||
with open(f'{data_path}/data.txt', 'r') as f:
|
||||
with open(f'{path}/data.txt', 'r') as f:
|
||||
file_data = f.read()
|
||||
json_data = json.loads(file_data)
|
||||
|
||||
|
@ -5,6 +5,7 @@ from collections import Counter
|
||||
|
||||
from anytree import Node, RenderTree
|
||||
|
||||
from machine_learning.decision_tree.helpers import path
|
||||
from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs
|
||||
|
||||
|
||||
@ -84,7 +85,7 @@ def main():
|
||||
score += 1
|
||||
|
||||
print(f'Accuracy: {score/len(test_set)}')
|
||||
with open('decision_tree.pkl', 'wb') as file:
|
||||
with open(f'{path}/decision_tree.pkl', 'wb') as file:
|
||||
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from tile import Tile
|
||||
|
||||
|
||||
@ -10,3 +11,6 @@ def get_dataset_from_tile(tile: Tile):
|
||||
'mine_type': tile.mine.mine_type,
|
||||
'pressure_gt_two': True if tile.mine.pressure > 2 else False
|
||||
}
|
||||
|
||||
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
Loading…
Reference in New Issue
Block a user