From 42315ab8d9e2f89fd96c034b2ee651dbd0c5d927 Mon Sep 17 00:00:00 2001 From: matixezor Date: Mon, 31 May 2021 23:15:27 +0200 Subject: [PATCH] add exception handling when importing decision tree or neural network --- src/const.py | 34 +++++++++++++++++-- .../decision_tree/data_set.py | 5 ++- .../decision_tree/decision_tree.py | 3 +- src/machine_learning/decision_tree/helpers.py | 4 +++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/const.py b/src/const.py index 9c4b00e..39d9f12 100644 --- a/src/const.py +++ b/src/const.py @@ -2,8 +2,10 @@ import os import pickle import pygame as pg +import torch from image import Image +from machine_learning.neural_network.net import Net main_path = os.path.dirname(os.getcwd()) @@ -91,5 +93,33 @@ for name in [ ROCK_INDEXES = (5, 6, 23, 24) -with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file: - TREE_ROOT = pickle.load(file) +try: + with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file: + TREE_ROOT = pickle.load(file) + +except FileNotFoundError: + print('Decision tree not detected\nInitializing...') + + from machine_learning.decision_tree.decision_tree import main + main() + + with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file: + TREE_ROOT = pickle.load(file) + +try: + checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model') + model = Net(num_classes=2) + model.load_state_dict(checkpoint) + model.eval() + +except FileNotFoundError: + print('Neural network not detected\nInitializing...') + + from machine_learning.neural_network.learning import main + main() + + checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model') + + model = Net(num_classes=2) + model.load_state_dict(checkpoint) + model.eval() diff --git a/src/machine_learning/decision_tree/data_set.py b/src/machine_learning/decision_tree/data_set.py index a4b7034..440562a 100644 --- a/src/machine_learning/decision_tree/data_set.py +++ b/src/machine_learning/decision_tree/data_set.py @@ -1,10 +1,9 @@ -import os import json from typing import List from itertools import product +from machine_learning.decision_tree.helpers import path -data_path = os.path.dirname(os.path.abspath(__file__)) visibility = ('bad', 'medium', 'good') stability = ('unstable', 'stable') @@ -67,7 +66,7 @@ def main(): json.dump(data_set, outfile) -with open(f'{data_path}/data.txt', 'r') as f: +with open(f'{path}/data.txt', 'r') as f: file_data = f.read() json_data = json.loads(file_data) diff --git a/src/machine_learning/decision_tree/decision_tree.py b/src/machine_learning/decision_tree/decision_tree.py index 255d3e7..5989137 100644 --- a/src/machine_learning/decision_tree/decision_tree.py +++ b/src/machine_learning/decision_tree/decision_tree.py @@ -5,6 +5,7 @@ from collections import Counter from anytree import Node, RenderTree +from machine_learning.decision_tree.helpers import path from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs @@ -84,7 +85,7 @@ def main(): score += 1 print(f'Accuracy: {score/len(test_set)}') - with open('decision_tree.pkl', 'wb') as file: + with open(f'{path}/decision_tree.pkl', 'wb') as file: pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL) diff --git a/src/machine_learning/decision_tree/helpers.py b/src/machine_learning/decision_tree/helpers.py index b7ab0df..a31c2e0 100644 --- a/src/machine_learning/decision_tree/helpers.py +++ b/src/machine_learning/decision_tree/helpers.py @@ -1,3 +1,4 @@ +import os from tile import Tile @@ -10,3 +11,6 @@ def get_dataset_from_tile(tile: Tile): 'mine_type': tile.mine.mine_type, 'pressure_gt_two': True if tile.mine.pressure > 2 else False } + + +path = os.path.dirname(os.path.abspath(__file__))