diff --git a/src/import torch.py b/src/import torch.py new file mode 100644 index 00000000..a035a948 --- /dev/null +++ b/src/import torch.py @@ -0,0 +1,3 @@ +import torch +x = torch.rand(5, 3) +print(x) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 17495ae0..77fe2cdf 100644 --- a/src/main.py +++ b/src/main.py @@ -1,8 +1,7 @@ -import sys import pygame from field import Field import os -from config import TILE_SIZE, TICK_RATE +from config import TILE_SIZE, TICK_RATE, FINAL_X, FINAL_Y if __name__ == "__main__": pygame.init() @@ -14,10 +13,19 @@ if __name__ == "__main__": field = Field() running = True + while running: for event in pygame.event.get(): if event.type == pygame.QUIT: running = False + if event.type == pygame.MOUSEBUTTONDOWN: + x, y = pygame.mouse.get_pos() + print(f"Mouse clicked at: ({x}, {y})") + + grid_x = x // TILE_SIZE + grid_y = y // TILE_SIZE + + field.tractor.set_new_goal((grid_x, grid_y)) field.tractor.update() screen.fill(WHITE) diff --git a/src/tile.py b/src/tile.py index e4defae5..93079f9a 100644 --- a/src/tile.py +++ b/src/tile.py @@ -4,6 +4,10 @@ from kb import tractor_kb import pytholog as pl import random from config import TILE_SIZE, FREE_TILES +import torch +import torchvision.transforms as transforms +from PIL import Image + class Tile(pygame.sprite.Sprite): @@ -15,6 +19,7 @@ class Tile(pygame.sprite.Sprite): self.field = field self.set_type(tile_type) + print('tile type set as', tile_type) if self.type == 'water': self.stage = 'no_plant' self.water_level = 100 @@ -23,13 +28,44 @@ class Tile(pygame.sprite.Sprite): self.water_level = random.randint(1, 5) * 10 else: self.stage = 'planted' + self.stage = 'planted' # wczesniej to była self.faza = 'posadzono' ale stwierdzilem ze lepiej po angielsku??? + classes = [ + "bób", "brokuł", "brukselka", "burak", "cebula", + "cukinia", "dynia", "fasola", "groch", "jarmuż", + "kalafior", "kalarepa", "kapusta", "marchew", + "ogórek", "papryka", "pietruszka", "pomidor", + "por", "rzepa", "rzodkiewka", "sałata", "seler", + "szpinak", "ziemniak"] + + model = torch.load("veggies_recognition/best_model.pth") + + mean = [0.5322, 0.5120, 0.3696] + std = [0.2487, 0.2436, 0.2531] + + image_transforms = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(torch.Tensor(mean),torch.Tensor(std)) + ]) + + self.prediction = self.predict(model, image_transforms, self.image_path, classes) + self.rect = self.image.get_rect() self.rect.topleft = (x * TILE_SIZE, y * TILE_SIZE) def draw(self, surface): self.tiles.draw(surface) + + def get_random_image_from_folder(self): + folder_path = f"veggies_recognition/veggies/testing/{self.type}" + + files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))] + random_file = random.choice(files) + + image_path = folder_path + "/" + random_file + return image_path def set_type(self, type): self.type = type @@ -38,9 +74,30 @@ class Tile(pygame.sprite.Sprite): elif self.type == 'water': image_path = "images/water.png" else: - image_path = f"images/vegetables/{self.type}.png" + #image_path = f"images/vegetables/{self.type}.png" + image_path = self.get_random_image_from_folder() if not os.path.exists(image_path): image_path = "images/question.jpg" + self.image_path = image_path self.image = pygame.image.load(image_path).convert() self.image = pygame.transform.scale(self.image, (TILE_SIZE, TILE_SIZE)) + + def predict(self, model, image_transforms, image_path, classes): + model = model.eval() + image = Image.open(image_path) + image = image.convert("RGB") + image = image_transforms(image).float() + image = image.unsqueeze(0) + + output = model(image) + _, predicted = torch.max(output.data, 1) + + result = classes[predicted.item()] + + + if result == "ziemniak": + result = 'marchew' + return result + + diff --git a/src/tractor.py b/src/tractor.py index 218ea03f..ed2704b0 100644 --- a/src/tractor.py +++ b/src/tractor.py @@ -18,30 +18,33 @@ class Tractor(pygame.sprite.Sprite): def __init__(self, field): super().__init__ self.field = field - + self.water = 50 self.image = pygame.image.load('images/tractor/east.png').convert_alpha() self.image = pygame.transform.scale(self.image, (TILE_SIZE, TILE_SIZE)) self.rect = self.image.get_rect() - - self.direction = STARTING_DIRECTION - # TODO: enable tractor to start on other tile than (0,0) - self.start = (START_X, START_Y) - self.final = (FINAL_X, FINAL_Y) + self.direction = 'east' + self.start = (0, 0) + self.final = (0, 0) print('destination @', self.final[0], self.final[1]) self.rect.topleft = (self.start[0] * TILE_SIZE, self.start[1] * TILE_SIZE) - self.water = 50 - - # A-STAR - # came_from, total_cost = self.a_star() - # path = self.reconstruct_path(came_from) - # self.actions = self.recreate_actions(path) - # self.action_index = 0 + self.rect.topleft = (self.start[0] * TILE_SIZE, self.start[1] * TILE_SIZE) + self.actions = [] + self.action_index = 0 # DECISION TREE: self.label_encoders = {} self.load_decision_tree_model() + def set_new_goal(self, goal): + self.start = self.get_coordinates() + self.final = goal + came_from, total_cost = self.a_star() + path = self.reconstruct_path(came_from) + self.actions = self.recreate_actions(path) + self.action_index = 0 + print(f"New goal set to: {self.final}") + def load_decision_tree_model(self): data = pd.read_csv('tree.csv') @@ -67,7 +70,9 @@ class Tractor(pygame.sprite.Sprite): neighbors.append('grass') input_data = { - 'tile_type': self.get_current_tile().type, + #tutaj będzie dostawał informację ze zdjęcia + 'tile_type': self.get_current_tile().prediction, + #'tile_type': self.get_current_tile().type, 'water_level': self.get_current_tile().water_level, "plant_stage": self.get_current_tile().stage, "neighbor_N": neighbors[0], @@ -92,13 +97,11 @@ class Tractor(pygame.sprite.Sprite): def draw(self, surface): surface.blit(self.image, self.rect) - def get_coordinates(self): x = self.rect.x // TILE_SIZE y = self.rect.y // TILE_SIZE return (x,y) - def move(self): if self.direction == "north" and self.rect.y > 0: self.rect.y -= TILE_SIZE @@ -160,28 +163,16 @@ class Tractor(pygame.sprite.Sprite): self.move() else: self.move() - - def update(self): - # A STAR: - # if self.action_index == len(self.actions): - # return - # action = self.actions[self.action_index] - - # match (action): - # case ('move'): - # self.move() - # case ('left'): - # self.rotate('left') - # case ('right'): - # self.rotate('right') - - # DECISION TREE: + + def decision_tree(self): action = self.make_decision() - if (self.get_current_tile().type != 'grass' or self.get_current_tile().type == 'water'): action = 'move' + if (self.get_current_tile().type != 'grass' or self.get_current_tile().type == 'water'): action = 'nothing' self.prev_action = action + print("Decyzja podjęta przez drzewo decyzyjne: ", action) + match (action): - case ('move'): + case ('nothing'): pass #self.move_rotating() case ('harvest'): @@ -238,10 +229,37 @@ class Tractor(pygame.sprite.Sprite): self.get_current_tile().set_type('szpinak') case ('plant(ziemniak)'): self.get_current_tile().set_type('ziemniak') - self.move_2() - #self.action_index += 1 - print(action) + + def update(self): + # A STAR: + if self.action_index == len(self.actions): + return + action = self.actions[self.action_index] + + match (action): + case ('move'): + self.move() + case ('left'): + self.rotate('left') + case ('right'): + self.rotate('right') + + self.action_index += 1 + + if self.get_current_tile().type == "grass": + print("Co jest faktycznie: trawa") + elif self.get_current_tile().type == "water": + print("Co jest faktycznie: woda") + else: + print("Rozpoznano: ", self.get_current_tile().prediction) + print("Co jest faktycznie: ", self.get_current_tile().type) + print("\n") + + if self.get_coordinates() == self.final: + self.decision_tree() + return + def log_info(self): # print on what tile type the tractor is on @@ -355,13 +373,12 @@ class Tractor(pygame.sprite.Sprite): if current == self.final: break - # next_node: tuple[int, int] for next_node in self.neighboring_nodes(coordinates=current): enter_cost = self.cost_of_entering_node(coordinates=next_node) - new_cost: int = cost_so_far[current] + enter_cost + new_cost = cost_so_far[current] + enter_cost if next_node not in cost_so_far or new_cost < cost_so_far[next_node]: cost_so_far[next_node] = new_cost - priority = new_cost + self.manhattan_cost(current) + priority = new_cost + self.manhattan_cost(next_node) heapq.heappush(fringe, (priority, next_node)) came_from[next_node] = current