add neuron network in project

This commit is contained in:
Yahor Haleznik 2023-06-05 14:59:45 +02:00
parent 2287b2b09f
commit 85b7f120ae
2 changed files with 83 additions and 28 deletions

View File

@ -1,15 +1,15 @@
import asyncio
import os
import random
import time
from heapq import *
from enum import Enum, IntEnum
from queue import PriorityQueue
from collections import deque
from threading import Thread
from IC3 import tree
from recognition_v1 import proverka
import pygame
#from recognition_v1.proverka import *
from IC3 import tree
pygame.init()
BLACK = (0, 0, 0)
@ -19,7 +19,7 @@ WINDOW_DIMENSIONS = 900
BLOCK_SIZE = 60
ROCKS_NUMBER = 20
VEGETABLES_NUMBER = 20
VEGETABLES = ('Potato', 'Broccoli', 'Carrot', 'Onion')
VEGETABLES = ('Potato', 'Broccoli', 'Carrot', 'Capsicum')
BOARD_SIZE = int(WINDOW_DIMENSIONS / BLOCK_SIZE)
WATER_TANK_CAPACITY = 10
GAS_TANK_CAPACITY = 250
@ -27,14 +27,19 @@ SPAWN_POINT = (0, 0)
SKLEP_POINT = (14, 14)
TIMEOUT = 1
uio = proverka.VegebatlesRecognizer()
tractor_image = pygame.transform.scale(pygame.image.load("images/tractor_image.png"), (BLOCK_SIZE, BLOCK_SIZE))
rock_image = pygame.transform.scale(pygame.image.load("images/rock_image.png"), (BLOCK_SIZE, BLOCK_SIZE))
potato_image = pygame.transform.scale(pygame.image.load("images/potato.png"), (BLOCK_SIZE, BLOCK_SIZE))
carrot_image = pygame.transform.scale(pygame.image.load("images/carrot.png"), (BLOCK_SIZE, BLOCK_SIZE))
broccoli_image = pygame.transform.scale(pygame.image.load("images/broccoli.png"), (BLOCK_SIZE, BLOCK_SIZE))
onion_image = pygame.transform.scale(pygame.image.load("images/onion.png"), (BLOCK_SIZE, BLOCK_SIZE))
capsicum_image = pygame.transform.scale(pygame.image.load("images/capsicum.png"), (BLOCK_SIZE, BLOCK_SIZE))
unknown_image = pygame.transform.scale(pygame.image.load("images/unknown.png"), (BLOCK_SIZE, BLOCK_SIZE))
gas_station_image = pygame.transform.scale(pygame.image.load("images/gas_station.png"), (BLOCK_SIZE, BLOCK_SIZE))
gas_station_closed_image = pygame.transform.scale(pygame.image.load("images/gas_station_closed.png"), (BLOCK_SIZE, BLOCK_SIZE))
gas_station_closed_image = pygame.transform.scale(pygame.image.load("images/gas_station_closed.png"),
(BLOCK_SIZE, BLOCK_SIZE))
sklep_station_image = pygame.transform.scale(pygame.image.load("images/storage_open.png"), (BLOCK_SIZE, BLOCK_SIZE))
sklep_closed_station_image = pygame.transform.scale(pygame.image.load("images/storage_closed.png"),
(BLOCK_SIZE, BLOCK_SIZE))
@ -76,7 +81,7 @@ def draw_interface():
tractor.gas = GAS_TANK_CAPACITY
if (tractor.x, tractor.y) == SKLEP_POINT:
tractor.collected_vegetables = {vegetables.POTATO: 0, vegetables.BROCCOLI: 0, vegetables.CARROT: 0,
vegetables.ONION: 0}
vegetables.CAPSICUM: 0}
global sc
sc = pygame.display.set_mode((WINDOW_DIMENSIONS, WINDOW_DIMENSIONS))
@ -94,6 +99,9 @@ def draw_interface():
t2.setDaemon(True)
t2.start()
fl_running = True
determine_thread = Thread(target=grid.determine)
determine_thread.setDaemon(True)
determine_thread.start()
while fl_running:
draw_grid()
# region events
@ -122,9 +130,6 @@ def draw_interface():
# graph1.initialize_graph(grid)
class Direction(IntEnum):
UP = 0
RIGHT = 1
@ -136,7 +141,8 @@ class vegetables(Enum):
POTATO = 3
BROCCOLI = 4
CARROT = 5
ONION = 6
CAPSICUM = 6
UNKNOWN = 7
class types(Enum):
@ -145,7 +151,8 @@ class types(Enum):
POTATO = 3
BROCCOLI = 4
CARROT = 5
ONION = 6
CAPSICUM = 6
UNKNOWN = 7
class Grid:
@ -154,11 +161,13 @@ class Grid:
self.height = height
self.block_size = block_size
self.grid = [[types.EMPTY for col in range(BOARD_SIZE)] for row in range(BOARD_SIZE)]
self.photo_paths = [['' for col in range(BOARD_SIZE)] for row in range(BOARD_SIZE)]
self.vegetables_locations = []
self.initialize_grid()
self.is_gas_station_closed = False
self.is_storage_closed = False
def add_object(self, x, y, type_of_object: types):
def add_object(self, x, y, type_of_object):
if self.grid[x][y] == types.EMPTY:
self.grid[x][y] = type_of_object
return True
@ -176,7 +185,9 @@ class Grid:
for i in range(VEGETABLES_NUMBER):
x, y = random.randrange(0, BOARD_SIZE), random.randrange(0, BOARD_SIZE)
if self.grid[x][y] == types.EMPTY and (x, y) != (0, 0):
self.add_object(x, y, random.choice(list(vegetables)))
if self.add_object(x, y, vegetables.UNKNOWN): # random.choice(list(vegetables)))
self.vegetables_locations.append((x, y))
self.photo_paths[x][y] = get_random_photo_path()
else:
i -= 1
for i in range(ROCKS_NUMBER):
@ -186,6 +197,33 @@ class Grid:
else:
i -= 1
def determine(self):
for x, y in self.vegetables_locations:
if self.grid[x][y] == vegetables.UNKNOWN:
timeout = time.time() + 0.5
while True:
sc.blit(pygame.transform.scale(pygame.image.load(self.photo_paths[x][y]), (BLOCK_SIZE*2, BLOCK_SIZE*2)),
((x * BLOCK_SIZE) - BLOCK_SIZE/2, (y * BLOCK_SIZE) - BLOCK_SIZE/2))
# time.sleep(1/63)
if time.time() > timeout:
break
# time.sleep(1)
random_veg = random.choice(list(vegetables))
while random_veg == vegetables.UNKNOWN:
random_veg = random.choice(list(vegetables))
# self.grid[x][y] = random_veg
# self.grid[x][y] = uio.recognize(self.photo_paths[x][y])
aiResult = uio.recognize(self.photo_paths[x][y])
if aiResult == 'Broccoli':
self.grid[x][y] = vegetables.BROCCOLI
if aiResult == 'Capsicum':
self.grid[x][y] = vegetables.CAPSICUM
if aiResult == 'Carrot':
self.grid[x][y] = vegetables.CARROT
if aiResult == 'Potato':
self.grid[x][y] = vegetables.POTATO
class Graph:
def __init__(self, grid: Grid):
@ -237,7 +275,7 @@ class Tractor:
self.gas = GAS_TANK_CAPACITY
self.water = WATER_TANK_CAPACITY
self.collected_vegetables = {vegetables.POTATO: 0, vegetables.BROCCOLI: 0, vegetables.CARROT: 0,
vegetables.ONION: 0}
vegetables.CAPSICUM: 0}
self.image = pygame.transform.scale(pygame.image.load("images/tractor_image.png"), (BLOCK_SIZE, BLOCK_SIZE))
def rot_center(self, direc: Direction):
@ -286,7 +324,7 @@ def get_next_nodes(x, y, direction: Direction, grid: Grid):
else:
next_nodes.append((2, (x + way[0], y + way[1], new_direction)))
# print(x,y, direction, next_nodes, '\n')
# print(x,y, direction, next_nodes, '/n')
return next_nodes
@ -352,8 +390,10 @@ def updateDisplay(tractor: Tractor, grid: Grid):
sc.blit(carrot_image, (x * BLOCK_SIZE + 5, y * BLOCK_SIZE + 5))
elif grid.grid[x][y] == vegetables.BROCCOLI:
sc.blit(broccoli_image, (x * BLOCK_SIZE + 5, y * BLOCK_SIZE + 5))
elif grid.grid[x][y] == vegetables.ONION:
sc.blit(onion_image, (x * BLOCK_SIZE + 5, y * BLOCK_SIZE + 5))
elif grid.grid[x][y] == vegetables.CAPSICUM:
sc.blit(capsicum_image, (x * BLOCK_SIZE + 5, y * BLOCK_SIZE + 5))
elif grid.grid[x][y] == vegetables.UNKNOWN:
sc.blit(unknown_image, (x * BLOCK_SIZE + 5, y * BLOCK_SIZE + 5))
elif grid.grid[x][y] == types.ROCK:
sc.blit(rock_image, (x * BLOCK_SIZE, y * BLOCK_SIZE))
sc.blit(gas_station_image, (SPAWN_POINT[0] * BLOCK_SIZE, SPAWN_POINT[1] * BLOCK_SIZE))
@ -370,8 +410,8 @@ def updateDisplay(tractor: Tractor, grid: Grid):
vegetables_text = font.render(
'Potato: ' + str(tractor.collected_vegetables[vegetables.POTATO]) + ' Broccoli: ' + str(
tractor.collected_vegetables[vegetables.BROCCOLI]) + ' Carrot: ' + str(
tractor.collected_vegetables[vegetables.CARROT]) + ' Onion: ' + str(
tractor.collected_vegetables[vegetables.ONION]), True, WHITE, BLACK)
tractor.collected_vegetables[vegetables.CARROT]) + ' Capsicum: ' + str(
tractor.collected_vegetables[vegetables.CAPSICUM]), True, WHITE, BLACK)
vegetables_textrect = vegetables_text.get_rect()
vegetables_textrect.center = (WINDOW_DIMENSIONS // 2, WINDOW_DIMENSIONS - 30)
sc.blit(vegetables_text, vegetables_textrect)
@ -503,3 +543,18 @@ def close_open(grid: Grid):
time.sleep(TIMEOUT)
grid.is_gas_station_closed = bool(random.getrandbits(1))
grid.is_storage_closed = bool(random.getrandbits(1))
def get_random_photo_path():
dir_num = random.randint(1, 4)
image_dir = "C:/Users/KimD/PycharmProjects/sztin_gr.234798/Vegetable Images/train"
if dir_num == 1:
image_dir = "C:/Users/KimD/PycharmProjects/sztin_gr.234798/Vegetable Images/train/Broccoli"
if dir_num == 2:
image_dir = "C:/Users/KimD/PycharmProjects/sztin_gr.234798/Vegetable Images/train/Capsicum"
if dir_num == 3:
image_dir = "C:/Users/KimD/PycharmProjects/sztin_gr.234798/Vegetable Images/train/Carrot"
if dir_num == 4:
image_dir = "C:/Users/KimD/PycharmProjects/sztin_gr.234798/Vegetable Images/train/Potato"
return image_dir + '/' + random.choice(os.listdir(image_dir))

View File

@ -9,21 +9,21 @@ directory = "C:/Users/KimD/PycharmProjects/Traktor_V1/Vegetable Images/test"
class VegebatlesRecognizer:
model = keras.models.load_model("C:/Users/KimD/PycharmProjects/sztin_gr.234798/recognition_v1/mode1.h5")
def recognize(self, image_path) -> str:
model = keras.models.load_model("C:/Users/KimD/PycharmProjects/Traktor_V1/mode2.h5")
class_names = ['Bean', 'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cucumber', 'Potato', 'Pumpkin', 'Tomato']
class_names = ['Broccoli', 'Capsicum', 'Carrot', 'Potato']
img = cv2.imread(image_path)
# cv2.imshow("lala", img)
# cv2.waitKey(0)
img = (np.expand_dims(img, 0))
predictions = model.predict(img)[0].tolist()
predictions = self.model.predict(img)[0].tolist()
print(class_names)
print(predictions)
print(max(predictions))
print(predictions.index(max(predictions)))
# print(class_names)
# print(predictions)
# print(max(predictions))
# print(predictions.index(max(predictions)))
return class_names[predictions.index(max(predictions))]