add neural network implementation to main
This commit is contained in:
parent
da73e223e3
commit
86be72ba33
20
main.py
20
main.py
@ -5,6 +5,7 @@ import land
|
|||||||
import tractor
|
import tractor
|
||||||
import blocks
|
import blocks
|
||||||
import astar_search
|
import astar_search
|
||||||
|
import neural_network.inference
|
||||||
from pygame.locals import *
|
from pygame.locals import *
|
||||||
|
|
||||||
|
|
||||||
@ -70,10 +71,13 @@ class Game:
|
|||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
move_tractor_event = pygame.USEREVENT + 1
|
move_tractor_event = pygame.USEREVENT + 1
|
||||||
pygame.time.set_timer(move_tractor_event, 100) # tractor moves every 1000 ms
|
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
|
||||||
tractor_next_moves = []
|
tractor_next_moves = []
|
||||||
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
||||||
|
|
||||||
|
veggies = dict()
|
||||||
|
veggies_debug = dict()
|
||||||
|
|
||||||
while running:
|
while running:
|
||||||
clock.tick(60) # manual fps control not to overwork the computer
|
clock.tick(60) # manual fps control not to overwork the computer
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
@ -109,6 +113,20 @@ class Game:
|
|||||||
#bandaid to know about stones
|
#bandaid to know about stones
|
||||||
tractor_next_moves = astar_search_object.astarsearch(
|
tractor_next_moves = astar_search_object.astarsearch(
|
||||||
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
||||||
|
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1]))]
|
||||||
|
if(current_veggie in veggies_debug):
|
||||||
|
veggies_debug[current_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies_debug[current_veggie] = 1
|
||||||
|
|
||||||
|
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2]))]
|
||||||
|
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
|
||||||
|
if predicted_veggie in veggies:
|
||||||
|
veggies[predicted_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies[predicted_veggie] = 1
|
||||||
|
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
||||||
elif event.type == QUIT:
|
elif event.type == QUIT:
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import cv2
|
import cv2
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
import argparse
|
import argparse
|
||||||
from model import CNNModel
|
from neural_network.model import CNNModel
|
||||||
# construct the argument parser
|
# construct the argument parser
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--input',
|
parser.add_argument('-i', '--input',
|
||||||
@ -22,7 +22,7 @@ def main(path):
|
|||||||
|
|
||||||
# initialize the model and load the trained weights
|
# initialize the model and load the trained weights
|
||||||
model = CNNModel().to(device)
|
model = CNNModel().to(device)
|
||||||
checkpoint = torch.load('outputs/model.pth', map_location=device)
|
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@ import torch.nn as nn
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import time
|
import time
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from model import CNNModel
|
from neural_network.model import CNNModel
|
||||||
from datasets import train_loader, valid_loader
|
from neural_network.datasets import train_loader, valid_loader
|
||||||
from utils import save_model, save_plots
|
from neural_network.utils import save_model, save_plots
|
||||||
|
|
||||||
# construct the argument parser
|
# construct the argument parser
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
Loading…
Reference in New Issue
Block a user