add neural network implementation to main

This commit is contained in:
s473558 2023-06-05 05:25:13 +02:00
parent da73e223e3
commit 86be72ba33
3 changed files with 24 additions and 6 deletions

20
main.py
View File

@ -5,6 +5,7 @@ import land
import tractor
import blocks
import astar_search
import neural_network.inference
from pygame.locals import *
@ -70,10 +71,13 @@ class Game:
clock = pygame.time.Clock()
move_tractor_event = pygame.USEREVENT + 1
pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
tractor_next_moves = []
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
veggies = dict()
veggies_debug = dict()
while running:
clock.tick(60) # manual fps control not to overwork the computer
for event in pygame.event.get():
@ -109,6 +113,20 @@ class Game:
#bandaid to know about stones
tractor_next_moves = astar_search_object.astarsearch(
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1]))]
if(current_veggie in veggies_debug):
veggies_debug[current_veggie]+=1
else:
veggies_debug[current_veggie] = 1
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2]))]
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
if predicted_veggie in veggies:
veggies[predicted_veggie]+=1
else:
veggies[predicted_veggie] = 1
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
else:
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
elif event.type == QUIT:

View File

@ -2,7 +2,7 @@ import torch
import cv2
import torchvision.transforms as transforms
import argparse
from model import CNNModel
from neural_network.model import CNNModel
# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input',
@ -22,7 +22,7 @@ def main(path):
# initialize the model and load the trained weights
model = CNNModel().to(device)
checkpoint = torch.load('outputs/model.pth', map_location=device)
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

View File

@ -4,9 +4,9 @@ import torch.nn as nn
import torch.optim as optim
import time
from tqdm.auto import tqdm
from model import CNNModel
from datasets import train_loader, valid_loader
from utils import save_model, save_plots
from neural_network.model import CNNModel
from neural_network.datasets import train_loader, valid_loader
from neural_network.utils import save_model, save_plots
# construct the argument parser
parser = argparse.ArgumentParser()