nn update 3
This commit is contained in:
parent
50292376e7
commit
3c5b05a7bb
20
main.py
20
main.py
@ -4,8 +4,10 @@ import random
|
|||||||
import land
|
import land
|
||||||
import tractor
|
import tractor
|
||||||
import blocks
|
import blocks
|
||||||
|
import nn
|
||||||
import astar_search
|
import astar_search
|
||||||
from pygame.locals import *
|
from pygame.locals import *
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
@ -23,6 +25,8 @@ class Game:
|
|||||||
self.grass_body = []
|
self.grass_body = []
|
||||||
self.red_block = [] #aim block
|
self.red_block = [] #aim block
|
||||||
|
|
||||||
|
#self.one_body = []
|
||||||
|
|
||||||
self.fawn_seed_body = []
|
self.fawn_seed_body = []
|
||||||
self.fawn_wheat_body = []
|
self.fawn_wheat_body = []
|
||||||
|
|
||||||
@ -59,6 +63,15 @@ class Game:
|
|||||||
# self.potato = blocks.Blocks(self.surface, self.cell_size)
|
# self.potato = blocks.Blocks(self.surface, self.cell_size)
|
||||||
# self.potato.locate_soil('black earth', 6, 1, [])
|
# self.potato.locate_soil('black earth', 6, 1, [])
|
||||||
|
|
||||||
|
#class_names = ['Pumpkin', 'Tomato', 'Carrot']
|
||||||
|
|
||||||
|
self.neural_network = nn.NNModel("neural_network/save/second_model.pth")
|
||||||
|
|
||||||
|
# self.pumpkin_batch = self.neural_network.input_image("resources/pampkin.png")
|
||||||
|
# self.tomato_batch = self.neural_network.input_image("resources/tomato.png")
|
||||||
|
# self.carrot_batch = self.neural_network.input_image("resources/carrot.png")
|
||||||
|
|
||||||
|
|
||||||
self.tractor = tractor.Tractor(self.surface, self.cell_size)
|
self.tractor = tractor.Tractor(self.surface, self.cell_size)
|
||||||
self.tractor.draw()
|
self.tractor.draw()
|
||||||
|
|
||||||
@ -101,9 +114,16 @@ class Game:
|
|||||||
random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
|
random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
|
||||||
random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
|
random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
|
||||||
print("Generated target: ",random_x, random_y)
|
print("Generated target: ",random_x, random_y)
|
||||||
|
#aim-blue block
|
||||||
if self.red_block:
|
if self.red_block:
|
||||||
self.red_block.pop()
|
self.red_block.pop()
|
||||||
self.red_block.append([random_x/50, random_y/50])
|
self.red_block.append([random_x/50, random_y/50])
|
||||||
|
|
||||||
|
self.path_image = "resources/2.png"
|
||||||
|
self.aim_batch = self.neural_network.input_image(self.path_image)
|
||||||
|
self.predicate = self.neural_network.predicte(self.aim_batch)
|
||||||
|
|
||||||
|
|
||||||
# below line should be later moved into tractor.py
|
# below line should be later moved into tractor.py
|
||||||
angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
|
angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
|
||||||
#bandaid to know about stones
|
#bandaid to know about stones
|
||||||
|
@ -35,7 +35,9 @@ def main():
|
|||||||
for x in ["train", "validation"]}
|
for x in ["train", "validation"]}
|
||||||
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
|
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
|
||||||
class_names = image_datasets["train"].classes
|
class_names = image_datasets["train"].classes
|
||||||
|
print(class_names)
|
||||||
num_classes = len(class_names)
|
num_classes = len(class_names)
|
||||||
|
print(num_classes)
|
||||||
|
|
||||||
# Load a pre-trained ResNet model
|
# Load a pre-trained ResNet model
|
||||||
model = models.resnet18(pretrained=True)
|
model = models.resnet18(pretrained=True)
|
||||||
@ -47,11 +49,17 @@ def main():
|
|||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
||||||
|
|
||||||
|
# Load the previously trained model state
|
||||||
|
#checkpoint = torch.load("neural_network/save/trained_model.pth")
|
||||||
|
#model.load_state_dict(checkpoint)
|
||||||
|
|
||||||
# Train the model
|
# Train the model
|
||||||
def train_model(model, criterion, optimizer, num_epochs=2):
|
def train_model(model, criterion, optimizer, num_epochs=2):
|
||||||
best_model_wts = None # Initialize the variable
|
best_model_wts = None # Initialize the variable
|
||||||
best_acc = 0.0
|
best_acc = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||||
print("-" * 10)
|
print("-" * 10)
|
||||||
|
BIN
neural_network/save/first_model.pth
Normal file
BIN
neural_network/save/first_model.pth
Normal file
Binary file not shown.
BIN
neural_network/save/second_model.pth
Normal file
BIN
neural_network/save/second_model.pth
Normal file
Binary file not shown.
57
nn.py
Normal file
57
nn.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.models as models
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Load the saved model
|
||||||
|
class NNModel:
|
||||||
|
#load model
|
||||||
|
def __init__(self, path):
|
||||||
|
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.class_names = ['Bean', 'Bitter_Gourd', 'Bottle_Gourd', 'Brinjal',
|
||||||
|
'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cauliflower',
|
||||||
|
'Cucumber', 'Papaya', 'Potato', 'Pumpkin', 'Radish', 'Tomato']
|
||||||
|
|
||||||
|
self.model = models.resnet18(pretrained=False)
|
||||||
|
self.num_classes = len(self.class_names)
|
||||||
|
self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
|
||||||
|
self.model.load_state_dict(torch.load(path)) #"neural_network/save/first_model.pth"
|
||||||
|
|
||||||
|
#self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
print(self.class_names)
|
||||||
|
print(self.num_classes)
|
||||||
|
|
||||||
|
def input_image(self, path): #"resources/image.jpg"
|
||||||
|
# Define the image transformations
|
||||||
|
preprocess = transforms.Compose([
|
||||||
|
transforms.Resize(224),
|
||||||
|
#transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# Preprocess the input image
|
||||||
|
self.input_image = Image.open(path).convert("RGB")
|
||||||
|
self.input_tensor = preprocess(self.input_image)
|
||||||
|
print("Input image shape:", self.input_image.size)
|
||||||
|
input_batch = self.input_tensor.unsqueeze(0)
|
||||||
|
return input_batch
|
||||||
|
|
||||||
|
def predicte(self, input_batch):
|
||||||
|
with torch.no_grad():
|
||||||
|
self.input_batch = input_batch.to(self.device)
|
||||||
|
self.output = self.model(self.input_batch)
|
||||||
|
|
||||||
|
print("Output shape:", self.output.shape)
|
||||||
|
print("Number of classes:", self.num_classes)
|
||||||
|
|
||||||
|
# Get the predicted class probabilities and labels
|
||||||
|
self.probabilities = torch.nn.functional.softmax(self.output[0], dim=0)
|
||||||
|
self.predicted_class_index = torch.argmax(self.probabilities).item()
|
||||||
|
self.predicted_class = self.class_names[self.predicted_class_index]
|
||||||
|
|
||||||
|
# Use the predicted class in your game logic
|
||||||
|
print(f"The predicted class is: {self.predicted_class}")
|
BIN
resources/1.png
Normal file
BIN
resources/1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 57 KiB |
BIN
resources/2.png
Normal file
BIN
resources/2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 76 KiB |
Loading…
Reference in New Issue
Block a user