neural network config

This commit is contained in:
Mateusz Dokowicz 2023-06-14 19:59:23 +02:00
parent e2c2ea8f0e
commit af7027a90f
4 changed files with 51 additions and 36 deletions

3
.gitignore vendored
View File

@ -4,4 +4,5 @@
__pycache__ __pycache__
#PyCharm #PyCharm
.idea/ .idea/
AI_brain/model.h5

View File

@ -5,18 +5,19 @@ from tensorflow import keras
import cv2 import cv2
import random import random
#You can download model from https://uam-my.sharepoint.com/:f:/g/personal/pavbia_st_amu_edu_pl/EmBHjnETuk5LiCZS6xk7AnIBNsnffR3Sygf8EX2bhR1w4A # You can download model from https://uam-my.sharepoint.com/:f:/g/personal/pavbia_st_amu_edu_pl/EmBHjnETuk5LiCZS6xk7AnIBNsnffR3Sygf8EX2bhR1w4A
#Change the path to model + to datasets (string 12 + strings 35,41,47,53) # Change the path to model + to datasets (string 12 + strings 35,41,47,53)
class VacuumRecognizer: class VacuumRecognizer:
model = keras.models.load_model('AI_brain\model.h5') #Neuron model path model = keras.models.load_model("AI_brain\model.h5") # Neuron model path
def recognize(self, image_path) -> str: def recognize(self, image_path) -> str:
class_names = ['Banana', 'Cat', 'Earings', 'Plant'] class_names = ["Banana", "Cat", "Earings", "Plant"]
img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE) img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
cv2.waitKey(0) cv2.waitKey(0)
img = (np.expand_dims(img, 0)) img = np.expand_dims(img, 0)
predictions = self.model.predict(img)[0].tolist() predictions = self.model.predict(img)[0].tolist()
@ -31,31 +32,43 @@ class VacuumRecognizer:
return class_names[predictions.index(max(predictions))] return class_names[predictions.index(max(predictions))]
def get_random_dir(self, type) -> str: def get_random_dir(self, type) -> str:
if type == 'Plant': if type == "Plant":
plant_image_paths = 'C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Plant' #Plant dataset path plant_image_paths = "C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Plant" # Plant dataset path
plant_dirs = os.listdir(plant_image_paths) plant_dirs = os.listdir(plant_image_paths)
full_path = plant_image_paths + '\\' + plant_dirs[random.randint(0, len(plant_dirs)-1)] full_path = (
plant_image_paths
+ "\\"
+ plant_dirs[random.randint(0, len(plant_dirs) - 1)]
)
print(full_path) print(full_path)
return full_path return full_path
elif type == 'Earings': elif type == "Earings":
earnings_image_paths = 'C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Earings' #Earings dataset path earnings_image_paths = "C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Earings" # Earings dataset path
earning_dirs = os.listdir(earnings_image_paths) earning_dirs = os.listdir(earnings_image_paths)
full_path = earnings_image_paths + '\\' + earning_dirs[random.randint(0, len(earning_dirs)-1)] full_path = (
earnings_image_paths
+ "\\"
+ earning_dirs[random.randint(0, len(earning_dirs) - 1)]
)
print(full_path) print(full_path)
return full_path return full_path
elif type == 'Banana': elif type == "Banana":
banana_image_paths = 'C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Banana' #Banana dataset path banana_image_paths = "C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Banana" # Banana dataset path
banana_dirs = os.listdir(banana_image_paths) banana_dirs = os.listdir(banana_image_paths)
full_path = banana_image_paths + '\\' + banana_dirs[random.randint(0, len(banana_dirs)-1)] full_path = (
banana_image_paths
+ "\\"
+ banana_dirs[random.randint(0, len(banana_dirs) - 1)]
)
print(full_path) print(full_path)
return full_path return full_path
elif type == 'Cat': elif type == "Cat":
cat_image_paths = 'C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Cat' #Cat dataset path cat_image_paths = "C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Cat" # Cat dataset path
cat_dir = os.listdir(cat_image_paths) cat_dir = os.listdir(cat_image_paths)
#For testing the neuron model # For testing the neuron model
'''image_paths = [] """image_paths = []
image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Banana') image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Banana')
image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Cat') image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Cat')
image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Earings') image_paths.append('C:\\Users\\Pavel\\Desktop\\AI\\Machine_learning_2023\\AI_brain\\Image_datasetJPGnewBnW\\check\\Earings')
@ -65,4 +78,4 @@ uio = VacuumRecognizer()
for image_path in image_paths: for image_path in image_paths:
dirs = os.listdir(image_path) dirs = os.listdir(image_path)
for i in range(3): for i in range(3):
print(uio.recognize(image_path + '\\' + dirs[random.randint(0, len(dirs)-1)]))''' print(uio.recognize(image_path + '\\' + dirs[random.randint(0, len(dirs)-1)]))"""

View File

@ -9,4 +9,4 @@ NumberOfEarrings = 3
NumberOfPlants = 5 NumberOfPlants = 5
[NEURAL_NETWORK] [NEURAL_NETWORK]
is_nural_network_off = True is_neural_network_off = True

31
main.py
View File

@ -3,6 +3,9 @@ from random import randint
import pygame import pygame
import configparser import configparser
config = configparser.ConfigParser()
config.read("config.ini")
from domain.commands.random_cat_move_command import RandomCatMoveCommand from domain.commands.random_cat_move_command import RandomCatMoveCommand
from domain.commands.vacuum_move_command import VacuumMoveCommand from domain.commands.vacuum_move_command import VacuumMoveCommand
from domain.entities.cat import Cat from domain.entities.cat import Cat
@ -13,17 +16,15 @@ from domain.entities.earring import Earring
from domain.entities.docking_station import Doc_Station from domain.entities.docking_station import Doc_Station
from domain.world import World from domain.world import World
from view.renderer import Renderer from view.renderer import Renderer
from AI_brain.image_recognition import VacuumRecognizer
if not config.getboolean("NEURAL_NETWORK", "is_neural_network_off"):
from AI_brain.image_recognition import VacuumRecognizer
# from AI_brain.movement import GoAnyDirectionBFS, State # from AI_brain.movement import GoAnyDirectionBFS, State
# from AI_brain.rotate_and_go_bfs import RotateAndGoBFS, State # from AI_brain.rotate_and_go_bfs import RotateAndGoBFS, State
from AI_brain.rotate_and_go_aStar import RotateAndGoAStar, State from AI_brain.rotate_and_go_aStar import RotateAndGoAStar, State
config = configparser.ConfigParser()
config.read("config.ini")
class Main: class Main:
def __init__(self): def __init__(self):
tiles_x = 10 tiles_x = 10
@ -145,7 +146,7 @@ class Main:
def generate_world(tiles_x: int, tiles_y: int) -> World: def generate_world(tiles_x: int, tiles_y: int) -> World:
if config.getboolean("NEURAL_NETWORK", "is_nural_network_off"): if config.getboolean("NEURAL_NETWORK", "is_neural_network_off"):
world = World(tiles_x, tiles_y) world = World(tiles_x, tiles_y)
for _ in range(config.getint("CONSTANT", "NumberOfBananas")): for _ in range(config.getint("CONSTANT", "NumberOfBananas")):
temp_x = randint(0, tiles_x - 1) temp_x = randint(0, tiles_x - 1)
@ -165,15 +166,16 @@ def generate_world(tiles_x: int, tiles_y: int) -> World:
world.add_entity(Earring(5, 5)) world.add_entity(Earring(5, 5))
world.add_entity(Earring(4, 6)) world.add_entity(Earring(4, 6))
else: else:
def world_adder(x,y,object,style=None):
def world_adder(x, y, object, style=None):
print(object) print(object)
if object == 'Plant': if object == "Plant":
world.add_entity(Entity(x, y, f"PLANT{randint(1, 3)}")) world.add_entity(Entity(x, y, f"PLANT{randint(1, 3)}"))
if object == 'Earings': if object == "Earings":
world.add_entity(Earring(x, y)) world.add_entity(Earring(x, y))
if object == 'Banana': if object == "Banana":
world.add_entity(Garbage(temp_x, temp_y)) world.add_entity(Garbage(temp_x, temp_y))
if object == 'Cat' and config.getboolean("APP", "cat"): if object == "Cat" and config.getboolean("APP", "cat"):
world.add_entity(Cat(x, y)) world.add_entity(Cat(x, y))
neural_network = VacuumRecognizer() neural_network = VacuumRecognizer()
@ -188,22 +190,21 @@ def generate_world(tiles_x: int, tiles_y: int) -> World:
for _ in range(config.getint("CONSTANT", "NumberOfPlants")): for _ in range(config.getint("CONSTANT", "NumberOfPlants")):
temp_x = randint(0, tiles_x - 1) temp_x = randint(0, tiles_x - 1)
temp_y = randint(0, tiles_y - 1) temp_y = randint(0, tiles_y - 1)
path = VacuumRecognizer.get_random_dir(neural_network,'Plant') path = VacuumRecognizer.get_random_dir(neural_network, "Plant")
world_adder(temp_x, temp_y, neural_network.recognize(path)) world_adder(temp_x, temp_y, neural_network.recognize(path))
for _ in range(config.getint("CONSTANT", "NumberOfEarrings")): for _ in range(config.getint("CONSTANT", "NumberOfEarrings")):
temp_x = randint(0, tiles_x - 1) temp_x = randint(0, tiles_x - 1)
temp_y = randint(0, tiles_y - 1) temp_y = randint(0, tiles_y - 1)
path = VacuumRecognizer.get_random_dir(neural_network,'Earings') path = VacuumRecognizer.get_random_dir(neural_network, "Earings")
world_adder(temp_x, temp_y, neural_network.recognize(path)) world_adder(temp_x, temp_y, neural_network.recognize(path))
for _ in range(config.getint("CONSTANT", "NumberOfBananas")): for _ in range(config.getint("CONSTANT", "NumberOfBananas")):
temp_x = randint(0, tiles_x - 1) temp_x = randint(0, tiles_x - 1)
temp_y = randint(0, tiles_y - 1) temp_y = randint(0, tiles_y - 1)
path = VacuumRecognizer.get_random_dir(neural_network,'Banana') path = VacuumRecognizer.get_random_dir(neural_network, "Banana")
world_adder(temp_x, temp_y, neural_network.recognize(path)) world_adder(temp_x, temp_y, neural_network.recognize(path))
for x in range(world.width): for x in range(world.width):
for y in range(world.height): for y in range(world.height):
if world.is_garbage_at(x, y): if world.is_garbage_at(x, y):