implementacja sieci neuronowych

This commit is contained in:
Zuzanna Wójcik 2023-06-05 00:45:29 +02:00
parent 51d8327040
commit bb8f28ad40
8 changed files with 158 additions and 8 deletions

147
NeuralNetwork/prediction.py Normal file
View File

@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import numpy as np
from torch.autograd import Variable
from torchvision.models import squeezenet1_1
import torch.functional as F
from io import open
import os
from PIL import Image
import pathlib
import glob
from tkinter import Tk, Label
from PIL import Image, ImageTk
absolute_path = os.path.abspath('NeuralNetwork/src/train_images')
train_path = absolute_path
absolute_path = os.path.abspath('Images/Items_test')
pred_path = absolute_path
root=pathlib.Path(train_path)
classes=sorted([j.name.split('/')[-1] for j in root.iterdir()])
class DataModel(nn.Module):
def __init__(self, num_classes):
super(DataModel, self).__init__()
#input (batch=256, nr of channels rgb=3 , size=244x244)
# convolution
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
#shape (256, 12, 224x224)
# batch normalization
self.bn1 = nn.BatchNorm2d(num_features=12)
#shape (256, 12, 224x224)
self.reul1 = nn.ReLU()
self.pool=nn.MaxPool2d(kernel_size=2, stride=2)
# reduce image size by factor 2
# pooling window moves by 2 pixels at a time instead of 1
# shape (256, 12, 112x112)
self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(num_features=24)
self.reul2 = nn.ReLU()
# shape (256, 24, 112x112)
self.conv3 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3, stride=1, padding=1)
#shape (256, 48, 112x112)
self.bn3 = nn.BatchNorm2d(num_features=48)
#shape (256, 48, 112x112)
self.reul3 = nn.ReLU()
# connected layer
self.fc = nn.Linear(in_features=48*112*112, out_features=num_classes)
def forward(self, input):
output = self.conv1(input)
output = self.bn1(output)
output = self.reul1(output)
output = self.pool(output)
output = self.conv2(output)
output = self.bn2(output)
output = self.reul2(output)
output = self.conv3(output)
output = self.bn3(output)
output = self.reul3(output)
# output shape matrix (256, 48, 112x112)
#print(output.shape)
#print(self.fc.weight.shape)
output = output.view(-1, 48*112*112)
output = self.fc(output)
return output
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, 'best_model.pth')
checkpoint=torch.load(file_path)
model = DataModel(num_classes=2)
model.load_state_dict(checkpoint)
model.eval()
transformer = transforms.Compose([
transforms.Resize((224, 224)), # Resize images to (224, 224)
transforms.ToTensor(), # Convert images to tensors, 0-255 to 0-1
# transforms.RandomHorizontalFlip(), # 0.5 chance to flip the image
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
])
def prediction(img_path,transformer):
image=Image.open(img_path)
image_tensor=transformer(image).float()
image_tensor=image_tensor.unsqueeze_(0)
if torch.cuda.is_available():
image_tensor.cuda()
input=Variable(image_tensor)
output=model(input)
index=output.data.numpy().argmax()
pred=classes[index]
return pred
def prediction_keys():
#funkcja zwracajaca sciezki do kazdego pliku w folderze w postaci listy
images_path=glob.glob(pred_path+'/*.jpg')
pred_list=[]
for i in images_path:
pred_list.append(i)
return pred_list
def predict_one(path):
#wyswietlanie obrazka po kazdym podniesieniu itemu
root = Tk()
root.title("Okno z obrazkiem")
image = Image.open(path)
photo = ImageTk.PhotoImage(image)
label = Label(root, image=photo)
label.pack()
root.mainloop()
#uruchamia sie funkcja spr czy obrazek to paczka czy list
pred_print = prediction(path,transformer)
print('Zdjecie jest: '+pred_print)
return pred_print

BIN
images/Items_test/test1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

BIN
images/Items_test/test2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

BIN
images/Items_test/test3.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
images/Items_test/test4.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

View File

@ -3,9 +3,9 @@ import pygame
letter_pic = pygame.image.load("images/letter.png")
class Letter(pygame.sprite.Sprite):
def __init__(self, id):
def __init__(self, img_path):
super().__init__()
self.id = id
self.img_path = img_path
self.image = pygame.transform.scale(letter_pic, (40, 40))
self.rect = self.image.get_rect()
self.x = 430

12
main.py
View File

@ -9,6 +9,7 @@ import wyszukiwanie
import ekran
from grid import GridCellType, SearchGrid
import plansza
import NeuralNetwork.prediction as prediction
from plansza import a_pix, b_pix
@ -17,10 +18,11 @@ pygame.init()
def main():
wozek = Wozek()
p1 = Paczka('duzy', 12, 'narzedzia', False, True, False, any, any, any, any, any)
p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any)
l1 = Letter(1)
l2 = Letter(2)
pred_list = prediction.prediction_keys()
p1 = Paczka('duzy', 12, 'narzedzia', False, True, False, any, any, any, any, any, pred_list[3])
p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any, pred_list[1])
l1 = Letter(pred_list[0])
l2 = Letter(pred_list[2])
ekran.dodaj_na_rampe(p2, l1, p1, l2)
grid_points = SearchGrid()
@ -54,7 +56,7 @@ def main():
wozek.dynamic_wozek_picture()
przenoszony_item = wozek.storage[-1]
if isinstance(przenoszony_item,Paczka):
if (prediction.predict_one(przenoszony_item.img_path)=='package'):
## wozek jedzie odlozyc paczke na regal
przenoszona_paczka = przenoszony_item

View File

@ -4,7 +4,7 @@ import ekran
class Paczka(pygame.sprite.Sprite):
def __init__(self, rozmiar, waga, kategoria, priorytet, ksztalt, kruchosc, nadawca, adres, imie, nazwisko, telefon):
def __init__(self, rozmiar, waga, kategoria, priorytet, ksztalt, kruchosc, nadawca, adres, imie, nazwisko, telefon, img_path):
super().__init__()
self.rozmiar = rozmiar
self.image = pygame.image.load("images/paczka.png")
@ -31,6 +31,7 @@ class Paczka(pygame.sprite.Sprite):
self.priorytet = priorytet
self.ksztalt = ksztalt
self.kruchosc = kruchosc
self.img_path = img_path
self.x = 430
self.y = 400
self.label = Etykieta(nadawca, adres, imie, nazwisko, telefon, priorytet)