neural_network #4

Merged
s452622 merged 34 commits from neural_network into master 2021-06-08 23:47:22 +02:00
4 changed files with 40 additions and 6 deletions
Showing only changes of commit 42315ab8d9 - Show all commits

View File

@ -2,8 +2,10 @@ import os
import pickle import pickle
import pygame as pg import pygame as pg
import torch
from image import Image from image import Image
from machine_learning.neural_network.net import Net
main_path = os.path.dirname(os.getcwd()) main_path = os.path.dirname(os.getcwd())
@ -91,5 +93,33 @@ for name in [
ROCK_INDEXES = (5, 6, 23, 24) ROCK_INDEXES = (5, 6, 23, 24)
with open(f'{main_path}/src/machine_learning/decision_tree.pkl', 'rb') as file: try:
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
TREE_ROOT = pickle.load(file) TREE_ROOT = pickle.load(file)
except FileNotFoundError:
print('Decision tree not detected\nInitializing...')
from machine_learning.decision_tree.decision_tree import main
main()
with open(f'{main_path}/src/machine_learning/decision_tree/decision_tree.pkl', 'rb') as file:
TREE_ROOT = pickle.load(file)
try:
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
model = Net(num_classes=2)
model.load_state_dict(checkpoint)
model.eval()
except FileNotFoundError:
print('Neural network not detected\nInitializing...')
from machine_learning.neural_network.learning import main
main()
checkpoint = torch.load(f'{main_path}/src/machine_learning/neural_network/mine_recognizer.model')
model = Net(num_classes=2)
model.load_state_dict(checkpoint)
model.eval()

View File

@ -1,10 +1,9 @@
import os
import json import json
from typing import List from typing import List
from itertools import product from itertools import product
from machine_learning.decision_tree.helpers import path
data_path = os.path.dirname(os.path.abspath(__file__))
visibility = ('bad', 'medium', 'good') visibility = ('bad', 'medium', 'good')
stability = ('unstable', 'stable') stability = ('unstable', 'stable')
@ -67,7 +66,7 @@ def main():
json.dump(data_set, outfile) json.dump(data_set, outfile)
with open(f'{data_path}/data.txt', 'r') as f: with open(f'{path}/data.txt', 'r') as f:
file_data = f.read() file_data = f.read()
json_data = json.loads(file_data) json_data = json.loads(file_data)

View File

@ -5,6 +5,7 @@ from collections import Counter
from anytree import Node, RenderTree from anytree import Node, RenderTree
from machine_learning.decision_tree.helpers import path
from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs from machine_learning.decision_tree.data_set import training_set, test_set, attributes as attribs
@ -84,7 +85,7 @@ def main():
score += 1 score += 1
print(f'Accuracy: {score/len(test_set)}') print(f'Accuracy: {score/len(test_set)}')
with open('decision_tree.pkl', 'wb') as file: with open(f'{path}/decision_tree.pkl', 'wb') as file:
pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL) pickle.dump(tree_root, file, pickle.HIGHEST_PROTOCOL)

View File

@ -1,3 +1,4 @@
import os
from tile import Tile from tile import Tile
@ -10,3 +11,6 @@ def get_dataset_from_tile(tile: Tile):
'mine_type': tile.mine.mine_type, 'mine_type': tile.mine.mine_type,
'pressure_gt_two': True if tile.mine.pressure > 2 else False 'pressure_gt_two': True if tile.mine.pressure > 2 else False
} }
path = os.path.dirname(os.path.abspath(__file__))