add neural network implementation to main
This commit is contained in:
parent
da73e223e3
commit
86be72ba33
20
main.py
20
main.py
@ -5,6 +5,7 @@ import land
|
||||
import tractor
|
||||
import blocks
|
||||
import astar_search
|
||||
import neural_network.inference
|
||||
from pygame.locals import *
|
||||
|
||||
|
||||
@ -70,10 +71,13 @@ class Game:
|
||||
clock = pygame.time.Clock()
|
||||
|
||||
move_tractor_event = pygame.USEREVENT + 1
|
||||
pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms
|
||||
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
|
||||
tractor_next_moves = []
|
||||
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
||||
|
||||
veggies = dict()
|
||||
veggies_debug = dict()
|
||||
|
||||
while running:
|
||||
clock.tick(60) # manual fps control not to overwork the computer
|
||||
for event in pygame.event.get():
|
||||
@ -109,6 +113,20 @@ class Game:
|
||||
#bandaid to know about stones
|
||||
tractor_next_moves = astar_search_object.astarsearch(
|
||||
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
||||
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1]))]
|
||||
if(current_veggie in veggies_debug):
|
||||
veggies_debug[current_veggie]+=1
|
||||
else:
|
||||
veggies_debug[current_veggie] = 1
|
||||
|
||||
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2]))]
|
||||
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
|
||||
if predicted_veggie in veggies:
|
||||
veggies[predicted_veggie]+=1
|
||||
else:
|
||||
veggies[predicted_veggie] = 1
|
||||
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
|
||||
|
||||
else:
|
||||
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
||||
elif event.type == QUIT:
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import cv2
|
||||
import torchvision.transforms as transforms
|
||||
import argparse
|
||||
from model import CNNModel
|
||||
from neural_network.model import CNNModel
|
||||
# construct the argument parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input',
|
||||
@ -22,7 +22,7 @@ def main(path):
|
||||
|
||||
# initialize the model and load the trained weights
|
||||
model = CNNModel().to(device)
|
||||
checkpoint = torch.load('outputs/model.pth', map_location=device)
|
||||
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
|
@ -4,9 +4,9 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
from model import CNNModel
|
||||
from datasets import train_loader, valid_loader
|
||||
from utils import save_model, save_plots
|
||||
from neural_network.model import CNNModel
|
||||
from neural_network.datasets import train_loader, valid_loader
|
||||
from neural_network.utils import save_model, save_plots
|
||||
|
||||
# construct the argument parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
Loading…
Reference in New Issue
Block a user