added photo recognition

Zofia Lorenc 2024-05-26 22:34:44 +02:00
parent 4955e737c5
commit 82ab417bfc
6 changed files with 98 additions and 33 deletions

src/import torch.py (new file)

@@ -0,0 +1,3 @@
+import torch
+x = torch.rand(5, 3)
+print(x)
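The new src/import torch.py file is a quick smoke test that PyTorch is installed. A slightly more informative variant (a sketch, not part of this commit) would also report the installed version and whether CUDA is usable before the classifier is loaded elsewhere:

import torch

# Report the environment before trusting the model to load.
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print(torch.rand(5, 3))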


@@ -1,4 +1,3 @@
-import sys
 import pygame
 from field import Field
 import os


@@ -4,6 +4,10 @@ from kb import tractor_kb
 import pytholog as pl
 import random
 from config import TILE_SIZE, FREE_TILES
+import torch
+import torchvision.transforms as transforms
+from PIL import Image

 class Tile(pygame.sprite.Sprite):
@@ -26,15 +30,40 @@ class Tile(pygame.sprite.Sprite):
             self.set_type(random_vegetable)
             self.water_level = random.randint(1, 5) * 10
             self.stage = 'planted'  # earlier this was self.faza = 'posadzono', but I decided English is better???
+
+            classes = [
+                "bób", "brokuł", "brukselka", "burak", "cebula",
+                "cukinia", "dynia", "fasola", "groch", "jarmuż",
+                "kalafior", "kalarepa", "kapusta", "marchew",
+                "ogórek", "papryka", "pietruszka", "pomidor",
+                "por", "rzepa", "rzodkiewka", "sałata", "seler",
+                "szpinak", "ziemniak"]
+
+            model = torch.load("veggies_recognition/best_model.pth")
+
+            mean = [0.5322, 0.5120, 0.3696]
+            std = [0.2487, 0.2436, 0.2531]
+
+            image_transforms = transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.ToTensor(),
+                transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
+            ])
+
+            self.prediction = self.predict(model, image_transforms, self.image_path, classes)
         else:
             if random.randint(1, 10) % 3 == 0:
                 self.set_type('water')
                 self.water_level = 100
                 self.stage = 'no_plant'
+                self.prediction = 'water'
             else:
                 self.set_type('grass')
                 self.water_level = random.randint(1, 5) * 10
                 self.stage = 'no_plant'
+                self.prediction = 'grass'

         self.rect = self.image.get_rect()
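The hunk above loads best_model.pth and rebuilds the transform pipeline inside every Tile.__init__, so each vegetable tile pays the full torch.load cost again. A minimal sketch of an alternative (an assumption, not what the commit does) is to cache the model and transforms once at module level and reuse them:

import torch
import torchvision.transforms as transforms

MEAN = [0.5322, 0.5120, 0.3696]
STD = [0.2487, 0.2436, 0.2531]

# Built once when the module is imported, shared by every tile.
IMAGE_TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(MEAN), torch.Tensor(STD)),
])

_MODEL = None

def get_model(path="veggies_recognition/best_model.pth"):
    # Load the pickled model lazily, once, and keep it in eval mode.
    global _MODEL
    if _MODEL is None:
        _MODEL = torch.load(path, map_location="cpu")
        _MODEL.eval()
    return _MODEL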
@@ -43,6 +72,17 @@ class Tile(pygame.sprite.Sprite):
     def draw(self, surface):
         self.tiles.draw(surface)

+    def get_random_image_from_folder(self):
+        folder_path = f"veggies_recognition/veggies/testing/{self.type}"
+        files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]
+        random_file = random.choice(files)
+        #image_path = os.path.join(folder_path, random_file)
+        image_path = folder_path + "/" + random_file
+        #print(image_path)
+        return image_path
+
     def set_type(self, type):
         self.type = type
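get_random_image_from_folder picks any file from the per-vegetable test folder. A small sketch (an assumed variant, not the committed helper) that builds the path with os.path.join, as the commented-out line suggests, and filters to image extensions:

import os
import random

def random_image_path(folder_path):
    # Keep only regular files with a typical image extension.
    files = [f for f in os.listdir(folder_path)
             if os.path.isfile(os.path.join(folder_path, f))
             and f.lower().endswith((".jpg", ".jpeg", ".png"))]
    return os.path.join(folder_path, random.choice(files))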
@@ -51,9 +91,26 @@ class Tile(pygame.sprite.Sprite):
         elif self.type == 'water':
             image_path = "images/water.png"
         else:
-            image_path = f"images/vegetables/{self.type}.png"
+            #image_path = f"images/vegetables/{self.type}.png"
+            image_path = self.get_random_image_from_folder()
         if not os.path.exists(image_path):
             image_path = "images/question.jpg"
+        self.image_path = image_path
         self.image = pygame.image.load(image_path).convert()
         self.image = pygame.transform.scale(self.image, (TILE_SIZE, TILE_SIZE))
+
+    def predict(self, model, image_transforms, image_path, classes):
+        model = model.eval()
+        image = Image.open(image_path)
+        image = image.convert("RGB")
+        image = image_transforms(image).float()
+        image = image.unsqueeze(0)
+        output = model(image)
+        _, predicted = torch.max(output.data, 1)
+        #print("Recognized: ", classes[predicted.item()])
+        return classes[predicted.item()]
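Tile.predict runs the forward pass with gradient tracking still enabled and returns only the label. A minimal sketch of the same step (an assumed variant, not the committed method) wrapped in torch.no_grad() and returning a softmax confidence as well:

import torch
import torch.nn.functional as F
from PIL import Image

def predict_with_confidence(model, image_transforms, image_path, classes):
    model.eval()
    image = Image.open(image_path).convert("RGB")
    batch = image_transforms(image).unsqueeze(0)   # add the batch dimension
    with torch.no_grad():                          # inference only, no gradients
        probabilities = F.softmax(model(batch), dim=1)
    confidence, index = torch.max(probabilities, 1)
    return classes[index.item()], confidence.item()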


@@ -67,7 +67,9 @@ class Tractor(pygame.sprite.Sprite):
                 neighbors.append('grass')

         input_data = {
-            'tile_type': self.get_current_tile().type,
+            # here it will get the information from the photo
+            'tile_type': self.get_current_tile().prediction,
+            #'tile_type': self.get_current_tile().type,
             'water_level': self.get_current_tile().water_level,
             "plant_stage": self.get_current_tile().stage,
             "neighbor_N": neighbors[0],
@@ -180,6 +182,7 @@ class Tractor(pygame.sprite.Sprite):
         if (self.get_current_tile().type != 'grass' or self.get_current_tile().type == 'water'): action = 'move'
         self.prev_action = action

         match (action):
             case ('move'):
                 pass
@@ -240,9 +243,12 @@ class Tractor(pygame.sprite.Sprite):
                 self.get_current_tile().set_type('ziemniak')

         self.move_2()
         #self.action_index += 1
-        print(action)
+        print("Recognized: ", self.get_current_tile().prediction)
+        print("What it actually is: ", self.get_current_tile().type)
+        print("\n")
         return

     def log_info(self):
         # print on what tile type the tractor is on
         x = self.rect.x // TILE_SIZE
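The prints above compare the classifier's guess with the tile's actual type one step at a time. A small sketch (a hypothetical helper, not in this commit) that tallies those comparisons into a running accuracy:

class PredictionStats:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def record(self, tile):
        # tile.prediction comes from the CNN, tile.type is the ground truth.
        self.total += 1
        if tile.prediction == tile.type:
            self.correct += 1

    def accuracy(self):
        return self.correct / self.total if self.total else 0.0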


@@ -1,36 +1,36 @@
-import torch
-import torchvision
-import torchvision.transforms as transforms
-from PIL import Image
-classes = [
-    "bób", "brokuł", "brukselka", "burak", "cebula",
-    "cukinia", "dynia", "fasola", "groch", "jarmuż",
-    "kalafior", "kalarepa", "kapusta", "marchew",
-    "ogórek", "papryka", "pietruszka", "pomidor",
-    "por", "rzepa", "rzodkiewka", "sałata", "seler",
-    "szpinak", "ziemniak"]
-model = torch.load("best_model.pth")
-mean = [0.5322, 0.5120, 0.3696]
-std = [0.2487, 0.2436, 0.2531]
-image_transforms = transforms.Compose([
-    transforms.Resize((224, 224)),
-    transforms.ToTensor(),
-    transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
-])
-def predict(model, image_transforms, image_path, classes):
-    model = model.eval()
-    image = Image.open(image_path)
-    image = image_transforms(image).float()
-    image = image.unsqueeze(0)
-    output = model(image)
-    _, predicted = torch.max(output.data, 1)
-    print(classes[predicted.item()])
-predict(model, image_transforms, "marchew_118.jpg", classes)
+# import torch
+# import torchvision.transforms as transforms
+# from PIL import Image
+# classes = [
+#     "bób", "brokuł", "brukselka", "burak", "cebula",
+#     "cukinia", "dynia", "fasola", "groch", "jarmuż",
+#     "kalafior", "kalarepa", "kapusta", "marchew",
+#     "ogórek", "papryka", "pietruszka", "pomidor",
+#     "por", "rzepa", "rzodkiewka", "sałata", "seler",
+#     "szpinak", "ziemniak"]
+# model = torch.load("best_model.pth")
+# mean = [0.5322, 0.5120, 0.3696]
+# std = [0.2487, 0.2436, 0.2531]
+# image_transforms = transforms.Compose([
+#     transforms.Resize((224, 224)),
+#     transforms.ToTensor(),
+#     transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
+# ])
+# def predict(model, image_transforms, image_path, classes):
+#     model = model.eval()
+#     image = Image.open(image_path)
+#     print(image_path)
+#     image = image_transforms(image).float()
+#     image = image.unsqueeze(0)
+#     output = model(image)
+#     _, predicted = torch.max(output.data, 1)
+#     print(classes[predicted.item()])
+# predict(model, image_transforms, "veggies/marchew_118.jpg", classes)
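The old standalone prediction script is now commented out; essentially the same logic lives in Tile.predict. If a quick command-line check of the saved model is still wanted, a minimal sketch (an assumption, not part of this commit) could be kept instead; note that torch.load on a pickled full model still needs the model's class definition to be importable:

import torch

def main(path="veggies_recognition/best_model.pth"):
    # map_location="cpu" lets the check run on machines without a GPU.
    model = torch.load(path, map_location="cpu")
    model.eval()
    print(type(model).__name__, "loaded from", path)

if __name__ == "__main__":
    main()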

(binary change: an image file, 9.2 KiB before and 9.2 KiB after)