From 82ab417bfc648edfe4fcace83f14b0a166791255 Mon Sep 17 00:00:00 2001 From: Zofia Lorenc Date: Sun, 26 May 2024 22:34:44 +0200 Subject: [PATCH] added photo recognition --- src/import torch.py | 3 + src/main.py | 1 - src/tile.py | 59 +++++++++++++++++- src/tractor.py | 12 +++- src/veggies_recognition/predict.py | 56 ++++++++--------- .../{ => veggies}/marchew_118.jpg | Bin 6 files changed, 98 insertions(+), 33 deletions(-) create mode 100644 src/import torch.py rename src/veggies_recognition/{ => veggies}/marchew_118.jpg (100%) diff --git a/src/import torch.py b/src/import torch.py new file mode 100644 index 00000000..a035a948 --- /dev/null +++ b/src/import torch.py @@ -0,0 +1,3 @@ +import torch +x = torch.rand(5, 3) +print(x) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 17495ae0..49196381 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,3 @@ -import sys import pygame from field import Field import os diff --git a/src/tile.py b/src/tile.py index da49e101..57523490 100644 --- a/src/tile.py +++ b/src/tile.py @@ -4,6 +4,10 @@ from kb import tractor_kb import pytholog as pl import random from config import TILE_SIZE, FREE_TILES +import torch +import torchvision.transforms as transforms +from PIL import Image + class Tile(pygame.sprite.Sprite): @@ -26,15 +30,40 @@ class Tile(pygame.sprite.Sprite): self.set_type(random_vegetable) self.water_level = random.randint(1, 5) * 10 self.stage = 'planted' # wczesniej to była self.faza = 'posadzono' ale stwierdzilem ze lepiej po angielsku??? + + classes = [ + "bób", "brokuł", "brukselka", "burak", "cebula", + "cukinia", "dynia", "fasola", "groch", "jarmuż", + "kalafior", "kalarepa", "kapusta", "marchew", + "ogórek", "papryka", "pietruszka", "pomidor", + "por", "rzepa", "rzodkiewka", "sałata", "seler", + "szpinak", "ziemniak"] + + model = torch.load("veggies_recognition/best_model.pth") + + mean = [0.5322, 0.5120, 0.3696] + std = [0.2487, 0.2436, 0.2531] + + image_transforms = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(torch.Tensor(mean),torch.Tensor(std)) + ]) + + self.prediction = self.predict(model, image_transforms, self.image_path, classes) + + else: if random.randint(1, 10) % 3 == 0: self.set_type('water') self.water_level = 100 self.stage = 'no_plant' + self.prediction = 'water' else: self.set_type('grass') self.water_level = random.randint(1, 5) * 10 self.stage = 'no_plant' + self.prediction = 'grass' self.rect = self.image.get_rect() @@ -43,6 +72,17 @@ class Tile(pygame.sprite.Sprite): def draw(self, surface): self.tiles.draw(surface) + + def get_random_image_from_folder(self): + folder_path = f"veggies_recognition/veggies/testing/{self.type}" + + files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))] + random_file = random.choice(files) + + #image_path = os.path.join(folder_path, random_file) + image_path = folder_path + "/" + random_file + #print(image_path) + return image_path def set_type(self, type): self.type = type @@ -51,9 +91,26 @@ class Tile(pygame.sprite.Sprite): elif self.type == 'water': image_path = "images/water.png" else: - image_path = f"images/vegetables/{self.type}.png" + #image_path = f"images/vegetables/{self.type}.png" + image_path = self.get_random_image_from_folder() if not os.path.exists(image_path): image_path = "images/question.jpg" + self.image_path = image_path self.image = pygame.image.load(image_path).convert() self.image = pygame.transform.scale(self.image, (TILE_SIZE, TILE_SIZE)) + + def predict(self, model, image_transforms, image_path, classes): + model = model.eval() + image = Image.open(image_path) + image = image.convert("RGB") + image = image_transforms(image).float() + image = image.unsqueeze(0) + + output = model(image) + _, predicted = torch.max(output.data, 1) + + #print("Rozpoznano: ", classes[predicted.item()]) + return classes[predicted.item()] + + diff --git a/src/tractor.py b/src/tractor.py index 218ea03f..3fe1029e 100644 --- a/src/tractor.py +++ b/src/tractor.py @@ -67,7 +67,9 @@ class Tractor(pygame.sprite.Sprite): neighbors.append('grass') input_data = { - 'tile_type': self.get_current_tile().type, + #tutaj będzie dostawał informację ze zdjęcia + 'tile_type': self.get_current_tile().prediction, + #'tile_type': self.get_current_tile().type, 'water_level': self.get_current_tile().water_level, "plant_stage": self.get_current_tile().stage, "neighbor_N": neighbors[0], @@ -180,6 +182,7 @@ class Tractor(pygame.sprite.Sprite): if (self.get_current_tile().type != 'grass' or self.get_current_tile().type == 'water'): action = 'move' self.prev_action = action + match (action): case ('move'): pass @@ -240,9 +243,12 @@ class Tractor(pygame.sprite.Sprite): self.get_current_tile().set_type('ziemniak') self.move_2() #self.action_index += 1 - print(action) + print("Rozpoznano: ", self.get_current_tile().prediction) + print("Co jest faktycznie: ", self.get_current_tile().type) + print("\n") + return - + def log_info(self): # print on what tile type the tractor is on x = self.rect.x // TILE_SIZE diff --git a/src/veggies_recognition/predict.py b/src/veggies_recognition/predict.py index 12d8aa76..a81d4732 100644 --- a/src/veggies_recognition/predict.py +++ b/src/veggies_recognition/predict.py @@ -1,36 +1,36 @@ -import torch -import torchvision -import torchvision.transforms as transforms -from PIL import Image +# import torch +# import torchvision.transforms as transforms +# from PIL import Image -classes = [ - "bób", "brokuł", "brukselka", "burak", "cebula", - "cukinia", "dynia", "fasola", "groch", "jarmuż", - "kalafior", "kalarepa", "kapusta", "marchew", - "ogórek", "papryka", "pietruszka", "pomidor", - "por", "rzepa", "rzodkiewka", "sałata", "seler", - "szpinak", "ziemniak"] +# classes = [ +# "bób", "brokuł", "brukselka", "burak", "cebula", +# "cukinia", "dynia", "fasola", "groch", "jarmuż", +# "kalafior", "kalarepa", "kapusta", "marchew", +# "ogórek", "papryka", "pietruszka", "pomidor", +# "por", "rzepa", "rzodkiewka", "sałata", "seler", +# "szpinak", "ziemniak"] -model = torch.load("best_model.pth") +# model = torch.load("best_model.pth") -mean = [0.5322, 0.5120, 0.3696] -std = [0.2487, 0.2436, 0.2531] +# mean = [0.5322, 0.5120, 0.3696] +# std = [0.2487, 0.2436, 0.2531] -image_transforms = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize(torch.Tensor(mean),torch.Tensor(std)) -]) +# image_transforms = transforms.Compose([ +# transforms.Resize((224, 224)), +# transforms.ToTensor(), +# transforms.Normalize(torch.Tensor(mean),torch.Tensor(std)) +# ]) -def predict(model, image_transforms, image_path, classes): - model = model.eval() - image = Image.open(image_path) - image = image_transforms(image).float() - image = image.unsqueeze(0) +# def predict(model, image_transforms, image_path, classes): +# model = model.eval() +# image = Image.open(image_path) +# print(image_path) +# image = image_transforms(image).float() +# image = image.unsqueeze(0) - output = model(image) - _, predicted = torch.max(output.data, 1) +# output = model(image) +# _, predicted = torch.max(output.data, 1) - print(classes[predicted.item()]) +# print(classes[predicted.item()]) -predict(model, image_transforms, "marchew_118.jpg", classes) \ No newline at end of file +# predict(model, image_transforms, "veggies/marchew_118.jpg", classes) \ No newline at end of file diff --git a/src/veggies_recognition/marchew_118.jpg b/src/veggies_recognition/veggies/marchew_118.jpg similarity index 100% rename from src/veggies_recognition/marchew_118.jpg rename to src/veggies_recognition/veggies/marchew_118.jpg