add neural network implementation to main

This commit is contained in:
s473558 2023-06-05 05:25:13 +02:00
parent da73e223e3
commit 86be72ba33
3 changed files with 24 additions and 6 deletions

20
main.py
View File

@ -5,6 +5,7 @@ import land
import tractor import tractor
import blocks import blocks
import astar_search import astar_search
import neural_network.inference
from pygame.locals import * from pygame.locals import *
@ -70,10 +71,13 @@ class Game:
clock = pygame.time.Clock() clock = pygame.time.Clock()
move_tractor_event = pygame.USEREVENT + 1 move_tractor_event = pygame.USEREVENT + 1
pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
tractor_next_moves = [] tractor_next_moves = []
astar_search_object = astar_search.Search(self.cell_size, self.cell_number) astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
veggies = dict()
veggies_debug = dict()
while running: while running:
clock.tick(60) # manual fps control not to overwork the computer clock.tick(60) # manual fps control not to overwork the computer
for event in pygame.event.get(): for event in pygame.event.get():
@ -109,6 +113,20 @@ class Game:
#bandaid to know about stones #bandaid to know about stones
tractor_next_moves = astar_search_object.astarsearch( tractor_next_moves = astar_search_object.astarsearch(
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body) [self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1]))]
if(current_veggie in veggies_debug):
veggies_debug[current_veggie]+=1
else:
veggies_debug[current_veggie] = 1
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2]))]
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
if predicted_veggie in veggies:
veggies[predicted_veggie]+=1
else:
veggies[predicted_veggie] = 1
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
else: else:
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number) self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
elif event.type == QUIT: elif event.type == QUIT:

View File

@ -2,7 +2,7 @@ import torch
import cv2 import cv2
import torchvision.transforms as transforms import torchvision.transforms as transforms
import argparse import argparse
from model import CNNModel from neural_network.model import CNNModel
# construct the argument parser # construct the argument parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', parser.add_argument('-i', '--input',
@ -22,7 +22,7 @@ def main(path):
# initialize the model and load the trained weights # initialize the model and load the trained weights
model = CNNModel().to(device) model = CNNModel().to(device)
checkpoint = torch.load('outputs/model.pth', map_location=device) checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint['model_state_dict'])
model.eval() model.eval()

View File

@ -4,9 +4,9 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import time import time
from tqdm.auto import tqdm from tqdm.auto import tqdm
from model import CNNModel from neural_network.model import CNNModel
from datasets import train_loader, valid_loader from neural_network.datasets import train_loader, valid_loader
from utils import save_model, save_plots from neural_network.utils import save_model, save_plots
# construct the argument parser # construct the argument parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()