From 86be72ba33b374e512a716d2b4465fc1eb0ddb0a Mon Sep 17 00:00:00 2001 From: s473558 Date: Mon, 5 Jun 2023 05:25:13 +0200 Subject: [PATCH] add neural network implementation to main --- main.py | 20 +++++++++++++++++++- neural_network/inference.py | 4 ++-- neural_network/train.py | 6 +++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index e3f233f5..40b4b49d 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import land import tractor import blocks import astar_search +import neural_network.inference from pygame.locals import * @@ -70,10 +71,13 @@ class Game: clock = pygame.time.Clock() move_tractor_event = pygame.USEREVENT + 1 - pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms + pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms tractor_next_moves = [] astar_search_object = astar_search.Search(self.cell_size, self.cell_number) + veggies = dict() + veggies_debug = dict() + while running: clock.tick(60) # manual fps control not to overwork the computer for event in pygame.event.get(): @@ -109,6 +113,20 @@ class Game: #bandaid to know about stones tractor_next_moves = astar_search_object.astarsearch( [self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body) + current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1]))] + if(current_veggie in veggies_debug): + veggies_debug[current_veggie]+=1 + else: + veggies_debug[current_veggie] = 1 + + current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2]))] + predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}") + if predicted_veggie in veggies: + veggies[predicted_veggie]+=1 + else: + veggies[predicted_veggie] = 1 + print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies) + else: self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number) elif event.type == QUIT: diff --git a/neural_network/inference.py b/neural_network/inference.py index 3b3d1488..4de53890 100644 --- a/neural_network/inference.py +++ b/neural_network/inference.py @@ -2,7 +2,7 @@ import torch import cv2 import torchvision.transforms as transforms import argparse -from model import CNNModel +from neural_network.model import CNNModel # construct the argument parser parser = argparse.ArgumentParser() parser.add_argument('-i', '--input', @@ -22,7 +22,7 @@ def main(path): # initialize the model and load the trained weights model = CNNModel().to(device) - checkpoint = torch.load('outputs/model.pth', map_location=device) + checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() diff --git a/neural_network/train.py b/neural_network/train.py index 6a9dc119..0ca51dc1 100644 --- a/neural_network/train.py +++ b/neural_network/train.py @@ -4,9 +4,9 @@ import torch.nn as nn import torch.optim as optim import time from tqdm.auto import tqdm -from model import CNNModel -from datasets import train_loader, valid_loader -from utils import save_model, save_plots +from neural_network.model import CNNModel +from neural_network.datasets import train_loader, valid_loader +from neural_network.utils import save_model, save_plots # construct the argument parser parser = argparse.ArgumentParser()