diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..5ccd0950 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,3 @@ +{ + "recommendations": ["sbsnippets.pytorch-snippets"] +} diff --git a/src/import torch.py b/src/import torch.py new file mode 100644 index 00000000..a035a948 --- /dev/null +++ b/src/import torch.py @@ -0,0 +1,3 @@ +import torch +x = torch.rand(5, 3) +print(x) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 17495ae0..49196381 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,3 @@ -import sys import pygame from field import Field import os diff --git a/src/tile.py b/src/tile.py index da49e101..41ab871c 100644 --- a/src/tile.py +++ b/src/tile.py @@ -4,6 +4,10 @@ from kb import tractor_kb import pytholog as pl import random from config import TILE_SIZE, FREE_TILES +import torch +import torchvision.transforms as transforms +from PIL import Image + class Tile(pygame.sprite.Sprite): @@ -26,15 +30,39 @@ class Tile(pygame.sprite.Sprite): self.set_type(random_vegetable) self.water_level = random.randint(1, 5) * 10 self.stage = 'planted' # wczesniej to była self.faza = 'posadzono' ale stwierdzilem ze lepiej po angielsku??? + + classes = [ + "bób", "brokuł", "brukselka", "burak", "cebula", + "cukinia", "dynia", "fasola", "groch", "jarmuż", + "kalafior", "kalarepa", "kapusta", "marchew", + "ogórek", "papryka", "pietruszka", "pomidor", + "por", "rzepa", "rzodkiewka", "sałata", "seler", + "szpinak", "ziemniak"] + + model = torch.load("veggies_recognition/best_model.pth") + + mean = [0.5322, 0.5120, 0.3696] + std = [0.2487, 0.2436, 0.2531] + + image_transforms = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(torch.Tensor(mean),torch.Tensor(std)) + ]) + + self.prediction = self.predict(model, image_transforms, self.image_path, classes) + else: if random.randint(1, 10) % 3 == 0: self.set_type('water') self.water_level = 100 self.stage = 'no_plant' + self.prediction = 'water' else: self.set_type('grass') self.water_level = random.randint(1, 5) * 10 self.stage = 'no_plant' + self.prediction = 'grass' self.rect = self.image.get_rect() @@ -43,6 +71,15 @@ class Tile(pygame.sprite.Sprite): def draw(self, surface): self.tiles.draw(surface) + + def get_random_image_from_folder(self): + folder_path = f"veggies_recognition/veggies/testing/{self.type}" + + files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))] + random_file = random.choice(files) + + image_path = folder_path + "/" + random_file + return image_path def set_type(self, type): self.type = type @@ -51,9 +88,30 @@ class Tile(pygame.sprite.Sprite): elif self.type == 'water': image_path = "images/water.png" else: - image_path = f"images/vegetables/{self.type}.png" + #image_path = f"images/vegetables/{self.type}.png" + image_path = self.get_random_image_from_folder() if not os.path.exists(image_path): image_path = "images/question.jpg" + self.image_path = image_path self.image = pygame.image.load(image_path).convert() self.image = pygame.transform.scale(self.image, (TILE_SIZE, TILE_SIZE)) + + def predict(self, model, image_transforms, image_path, classes): + model = model.eval() + image = Image.open(image_path) + image = image.convert("RGB") + image = image_transforms(image).float() + image = image.unsqueeze(0) + + output = model(image) + _, predicted = torch.max(output.data, 1) + + result = classes[predicted.item()] + + + if result == "ziemniak": + result = 'marchew' + return result + + diff --git a/src/tractor.py b/src/tractor.py index 218ea03f..b070e0f0 100644 --- a/src/tractor.py +++ b/src/tractor.py @@ -67,7 +67,9 @@ class Tractor(pygame.sprite.Sprite): neighbors.append('grass') input_data = { - 'tile_type': self.get_current_tile().type, + #tutaj będzie dostawał informację ze zdjęcia + 'tile_type': self.get_current_tile().prediction, + #'tile_type': self.get_current_tile().type, 'water_level': self.get_current_tile().water_level, "plant_stage": self.get_current_tile().stage, "neighbor_N": neighbors[0], @@ -180,6 +182,7 @@ class Tractor(pygame.sprite.Sprite): if (self.get_current_tile().type != 'grass' or self.get_current_tile().type == 'water'): action = 'move' self.prev_action = action + match (action): case ('move'): pass @@ -240,9 +243,18 @@ class Tractor(pygame.sprite.Sprite): self.get_current_tile().set_type('ziemniak') self.move_2() #self.action_index += 1 - print(action) - return + if self.get_current_tile().type == "grass": + print("Co jest faktycznie: trawa") + elif self.get_current_tile().type == "water": + print("Co jest faktycznie: woda") + else: + print("Rozpoznano: ", self.get_current_tile().prediction) + print("Co jest faktycznie: ", self.get_current_tile().type) + print("\n") + + return + def log_info(self): # print on what tile type the tractor is on x = self.rect.x // TILE_SIZE