nn update 3

This commit is contained in:
parent 50292376e7
commit 3c5b05a7bb

main.py (20 lines changed)
@@ -4,8 +4,10 @@ import random
import land
import tractor
import blocks
import nn
import astar_search
from pygame.locals import *
import numpy as np


class Game:

@@ -23,6 +25,8 @@ class Game:
        self.grass_body = []
        self.red_block = [] #aim block

        #self.one_body = []

        self.fawn_seed_body = []
        self.fawn_wheat_body = []

@@ -59,6 +63,15 @@ class Game:
        # self.potato = blocks.Blocks(self.surface, self.cell_size)
        # self.potato.locate_soil('black earth', 6, 1, [])

        #class_names = ['Pumpkin', 'Tomato', 'Carrot']

        self.neural_network = nn.NNModel("neural_network/save/second_model.pth")

        # self.pumpkin_batch = self.neural_network.input_image("resources/pampkin.png")
        # self.tomato_batch = self.neural_network.input_image("resources/tomato.png")
        # self.carrot_batch = self.neural_network.input_image("resources/carrot.png")


        self.tractor = tractor.Tractor(self.surface, self.cell_size)
        self.tractor.draw()

@@ -101,9 +114,16 @@ class Game:
        random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
        random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
        print("Generated target: ", random_x, random_y)
        #aim-blue block
        if self.red_block:
            self.red_block.pop()
        self.red_block.append([random_x/50, random_y/50])

        self.path_image = "resources/2.png"
        self.aim_batch = self.neural_network.input_image(self.path_image)
        self.predicate = self.neural_network.predicte(self.aim_batch)


        # below line should be later moved into tractor.py
        angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
        #bandaid to know about stones
(The hunks below belong to the neural-network training script; its file header is missing from this capture.)

@@ -35,7 +35,9 @@ def main():
                                                      for x in ["train", "validation"]}
    dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
    class_names = image_datasets["train"].classes
    print(class_names)
    num_classes = len(class_names)
    print(num_classes)

    # Load a pre-trained ResNet model
    model = models.resnet18(pretrained=True)
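The diff context skips a few lines of the script between this hunk and the next, so the step that replaces the ResNet head is not shown. nn.py below rebuilds the model the same way before loading weights, so the omitted step presumably looks roughly like this sketch (num_classes is written out only for illustration; the script derives it from class_names above):

import torch.nn as nn
import torchvision.models as models

model = models.resnet18(pretrained=True)
num_classes = 15  # illustrative only; the script uses len(class_names)
# Swap the 1000-way ImageNet head for a layer sized to this dataset,
# matching how nn.py reconstructs the model when loading the saved weights.
model.fc = nn.Linear(model.fc.in_features, num_classes)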
@@ -47,11 +49,17 @@ def main():
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Load the previously trained model state
    #checkpoint = torch.load("neural_network/save/trained_model.pth")
    #model.load_state_dict(checkpoint)

    # Train the model
    def train_model(model, criterion, optimizer, num_epochs=2):
        best_model_wts = None  # Initialize the variable
        best_acc = 0.0



        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")
            print("-" * 10)
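The hunk ends just inside the epoch loop, so the rest of train_model is not visible here. Under the usual transfer-learning recipe, one epoch pass would look roughly like the sketch below; run_epoch is an illustrative name, and dataloaders and device are assumed to be defined earlier in the script (the hunk itself only shows image_datasets and dataset_sizes):

import torch

def run_epoch(model, criterion, optimizer, dataloaders, dataset_sizes, device, phase):
    # One training or validation pass; returns the accuracy for this phase.
    model.train() if phase == "train" else model.eval()
    running_loss, running_corrects = 0.0, 0
    for inputs, labels in dataloaders[phase]:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == "train"):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            if phase == "train":
                loss.backward()
                optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects / dataset_sizes[phase]
    print(f"{phase} loss: {epoch_loss:.4f} acc: {epoch_acc:.4f}")
    return epoch_acc

train_model would then presumably compare the validation accuracy against best_acc and keep a copy of model.state_dict() in best_model_wts whenever it improves.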
neural_network/save/first_model.pth (BIN, new file; binary file not shown)
neural_network/save/second_model.pth (BIN, new file; binary file not shown)
nn.py (57 lines, new file)

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load the saved model
class NNModel:
    #load model
    def __init__(self, path):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_names = ['Bean', 'Bitter_Gourd', 'Bottle_Gourd', 'Brinjal',
                            'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cauliflower',
                            'Cucumber', 'Papaya', 'Potato', 'Pumpkin', 'Radish', 'Tomato']

        self.model = models.resnet18(pretrained=False)
        self.num_classes = len(self.class_names)
        self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
        self.model.load_state_dict(torch.load(path)) #"neural_network/save/first_model.pth"

        #self.model.to(self.device)
        self.model.eval()
        print(self.class_names)
        print(self.num_classes)

    def input_image(self, path): #"resources/image.jpg"
        # Define the image transformations
        preprocess = transforms.Compose([
            transforms.Resize(224),
            #transforms.CenterCrop(224),
            transforms.ToTensor(),
            #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Preprocess the input image
        self.input_image = Image.open(path).convert("RGB")
        self.input_tensor = preprocess(self.input_image)
        print("Input image shape:", self.input_image.size)
        input_batch = self.input_tensor.unsqueeze(0)
        return input_batch

    def predicte(self, input_batch):
        with torch.no_grad():
            self.input_batch = input_batch.to(self.device)
            self.output = self.model(self.input_batch)

            print("Output shape:", self.output.shape)
            print("Number of classes:", self.num_classes)

            # Get the predicted class probabilities and labels
            self.probabilities = torch.nn.functional.softmax(self.output[0], dim=0)
            self.predicted_class_index = torch.argmax(self.probabilities).item()
            self.predicted_class = self.class_names[self.predicted_class_index]

            # Use the predicted class in your game logic
            print(f"The predicted class is: {self.predicted_class}")
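For reference, this is how main.py above exercises the class, with the paths taken from this commit; predicte() prints the predicted class name and has no return value, so the value main.py assigns to self.predicate is None:

import nn

model = nn.NNModel("neural_network/save/second_model.pth")  # load the trained weights
batch = model.input_image("resources/2.png")                # preprocess one image into a 1-image batch
model.predicte(batch)                                       # prints the predicted class name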
resources/1.png (BIN, new file; binary file not shown, 57 KiB)
resources/2.png (BIN, new file; binary file not shown, 76 KiB)