diff --git a/NeuralNetwork/prediction.py b/NeuralNetwork/prediction.py new file mode 100644 index 0000000..b2a63a4 --- /dev/null +++ b/NeuralNetwork/prediction.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +from torchvision.transforms import transforms +import numpy as np +from torch.autograd import Variable +from torchvision.models import squeezenet1_1 +import torch.functional as F +from io import open +import os +from PIL import Image +import pathlib +import glob +from tkinter import Tk, Label +from PIL import Image, ImageTk + +absolute_path = os.path.abspath('NeuralNetwork/src/train_images') +train_path = absolute_path +absolute_path = os.path.abspath('Images/Items_test') +pred_path = absolute_path + +root=pathlib.Path(train_path) +classes=sorted([j.name.split('/')[-1] for j in root.iterdir()]) + + +class DataModel(nn.Module): + def __init__(self, num_classes): + super(DataModel, self).__init__() + #input (batch=256, nr of channels rgb=3 , size=244x244) + + # convolution + self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1) + #shape (256, 12, 224x224) + + # batch normalization + self.bn1 = nn.BatchNorm2d(num_features=12) + #shape (256, 12, 224x224) + self.reul1 = nn.ReLU() + + self.pool=nn.MaxPool2d(kernel_size=2, stride=2) + # reduce image size by factor 2 + # pooling window moves by 2 pixels at a time instead of 1 + # shape (256, 12, 112x112) + + + + self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(num_features=24) + self.reul2 = nn.ReLU() + # shape (256, 24, 112x112) + + self.conv3 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3, stride=1, padding=1) + #shape (256, 48, 112x112) + self.bn3 = nn.BatchNorm2d(num_features=48) + #shape (256, 48, 112x112) + self.reul3 = nn.ReLU() + + # connected layer + self.fc = nn.Linear(in_features=48*112*112, out_features=num_classes) + + def forward(self, input): + output = self.conv1(input) + output = self.bn1(output) + output = self.reul1(output) + + output = self.pool(output) + output = self.conv2(output) + output = self.bn2(output) + output = self.reul2(output) + + output = self.conv3(output) + output = self.bn3(output) + output = self.reul3(output) + + # output shape matrix (256, 48, 112x112) + #print(output.shape) + #print(self.fc.weight.shape) + + output = output.view(-1, 48*112*112) + output = self.fc(output) + + return output + +script_dir = os.path.dirname(os.path.abspath(__file__)) +file_path = os.path.join(script_dir, 'best_model.pth') +checkpoint=torch.load(file_path) +model = DataModel(num_classes=2) +model.load_state_dict(checkpoint) +model.eval() + +transformer = transforms.Compose([ + transforms.Resize((224, 224)), # Resize images to (224, 224) + transforms.ToTensor(), # Convert images to tensors, 0-255 to 0-1 + # transforms.RandomHorizontalFlip(), # 0.5 chance to flip the image + transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]) +]) + +def prediction(img_path,transformer): + + image=Image.open(img_path) + + image_tensor=transformer(image).float() + + image_tensor=image_tensor.unsqueeze_(0) + + if torch.cuda.is_available(): + image_tensor.cuda() + + input=Variable(image_tensor) + + output=model(input) + + index=output.data.numpy().argmax() + + pred=classes[index] + + return pred + +def prediction_keys(): + + #funkcja zwracajaca sciezki do kazdego pliku w folderze w postaci listy + + images_path=glob.glob(pred_path+'/*.jpg') + + pred_list=[] + + for i in images_path: + pred_list.append(i) + + return pred_list + +def predict_one(path): + + #wyswietlanie obrazka po kazdym podniesieniu itemu + root = Tk() + root.title("Okno z obrazkiem") + + image = Image.open(path) + photo = ImageTk.PhotoImage(image) + label = Label(root, image=photo) + label.pack() + + root.mainloop() + + #uruchamia sie funkcja spr czy obrazek to paczka czy list + pred_print = prediction(path,transformer) + print('Zdjecie jest: '+pred_print) + return pred_print \ No newline at end of file diff --git a/images/Items_test/test1.jpg b/images/Items_test/test1.jpg new file mode 100644 index 0000000..633b417 Binary files /dev/null and b/images/Items_test/test1.jpg differ diff --git a/images/Items_test/test2.jpg b/images/Items_test/test2.jpg new file mode 100644 index 0000000..4fe00d6 Binary files /dev/null and b/images/Items_test/test2.jpg differ diff --git a/images/Items_test/test3.jpg b/images/Items_test/test3.jpg new file mode 100644 index 0000000..2e4d2e5 Binary files /dev/null and b/images/Items_test/test3.jpg differ diff --git a/images/Items_test/test4.jpg b/images/Items_test/test4.jpg new file mode 100644 index 0000000..63c28eb Binary files /dev/null and b/images/Items_test/test4.jpg differ diff --git a/letter.py b/letter.py index 98c1deb..b5b24da 100644 --- a/letter.py +++ b/letter.py @@ -3,9 +3,9 @@ import pygame letter_pic = pygame.image.load("images/letter.png") class Letter(pygame.sprite.Sprite): - def __init__(self, id): + def __init__(self, img_path): super().__init__() - self.id = id + self.img_path = img_path self.image = pygame.transform.scale(letter_pic, (40, 40)) self.rect = self.image.get_rect() self.x = 430 diff --git a/main.py b/main.py index b78c014..eede3df 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ import wyszukiwanie import ekran from grid import GridCellType, SearchGrid import plansza +import NeuralNetwork.prediction as prediction from plansza import a_pix, b_pix @@ -17,10 +18,11 @@ pygame.init() def main(): wozek = Wozek() - p1 = Paczka('duzy', 12, 'narzedzia', False, True, False, any, any, any, any, any) - p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any) - l1 = Letter(1) - l2 = Letter(2) + pred_list = prediction.prediction_keys() + p1 = Paczka('duzy', 12, 'narzedzia', False, True, False, any, any, any, any, any, pred_list[3]) + p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any, pred_list[1]) + l1 = Letter(pred_list[0]) + l2 = Letter(pred_list[2]) ekran.dodaj_na_rampe(p2, l1, p1, l2) grid_points = SearchGrid() @@ -54,7 +56,7 @@ def main(): wozek.dynamic_wozek_picture() przenoszony_item = wozek.storage[-1] - if isinstance(przenoszony_item,Paczka): + if (prediction.predict_one(przenoszony_item.img_path)=='package'): ## wozek jedzie odlozyc paczke na regal przenoszona_paczka = przenoszony_item diff --git a/paczka.py b/paczka.py index aea865f..51e0595 100644 --- a/paczka.py +++ b/paczka.py @@ -4,7 +4,7 @@ import ekran class Paczka(pygame.sprite.Sprite): - def __init__(self, rozmiar, waga, kategoria, priorytet, ksztalt, kruchosc, nadawca, adres, imie, nazwisko, telefon): + def __init__(self, rozmiar, waga, kategoria, priorytet, ksztalt, kruchosc, nadawca, adres, imie, nazwisko, telefon, img_path): super().__init__() self.rozmiar = rozmiar self.image = pygame.image.load("images/paczka.png") @@ -31,6 +31,7 @@ class Paczka(pygame.sprite.Sprite): self.priorytet = priorytet self.ksztalt = ksztalt self.kruchosc = kruchosc + self.img_path = img_path self.x = 430 self.y = 400 self.label = Etykieta(nadawca, adres, imie, nazwisko, telefon, priorytet)