nn update 3

This commit is contained in:
Aliaksei Brown 2023-06-05 16:18:20 +02:00
parent 50292376e7
commit 3c5b05a7bb
7 changed files with 85 additions and 0 deletions

20
main.py
View File

@ -4,8 +4,10 @@ import random
import land import land
import tractor import tractor
import blocks import blocks
import nn
import astar_search import astar_search
from pygame.locals import * from pygame.locals import *
import numpy as np
class Game: class Game:
@ -23,6 +25,8 @@ class Game:
self.grass_body = [] self.grass_body = []
self.red_block = [] #aim block self.red_block = [] #aim block
#self.one_body = []
self.fawn_seed_body = [] self.fawn_seed_body = []
self.fawn_wheat_body = [] self.fawn_wheat_body = []
@ -59,6 +63,15 @@ class Game:
# self.potato = blocks.Blocks(self.surface, self.cell_size) # self.potato = blocks.Blocks(self.surface, self.cell_size)
# self.potato.locate_soil('black earth', 6, 1, []) # self.potato.locate_soil('black earth', 6, 1, [])
#class_names = ['Pumpkin', 'Tomato', 'Carrot']
self.neural_network = nn.NNModel("neural_network/save/second_model.pth")
# self.pumpkin_batch = self.neural_network.input_image("resources/pampkin.png")
# self.tomato_batch = self.neural_network.input_image("resources/tomato.png")
# self.carrot_batch = self.neural_network.input_image("resources/carrot.png")
self.tractor = tractor.Tractor(self.surface, self.cell_size) self.tractor = tractor.Tractor(self.surface, self.cell_size)
self.tractor.draw() self.tractor.draw()
@ -101,9 +114,16 @@ class Game:
random_x = random.randrange(0, self.cell_number * self.cell_size, 50) random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
random_y = random.randrange(0, self.cell_number * self.cell_size, 50) random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
print("Generated target: ",random_x, random_y) print("Generated target: ",random_x, random_y)
#aim-blue block
if self.red_block: if self.red_block:
self.red_block.pop() self.red_block.pop()
self.red_block.append([random_x/50, random_y/50]) self.red_block.append([random_x/50, random_y/50])
self.path_image = "resources/2.png"
self.aim_batch = self.neural_network.input_image(self.path_image)
self.predicate = self.neural_network.predicte(self.aim_batch)
# below line should be later moved into tractor.py # below line should be later moved into tractor.py
angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'} angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
#bandaid to know about stones #bandaid to know about stones

View File

@ -35,7 +35,9 @@ def main():
for x in ["train", "validation"]} for x in ["train", "validation"]}
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]} dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
class_names = image_datasets["train"].classes class_names = image_datasets["train"].classes
print(class_names)
num_classes = len(class_names) num_classes = len(class_names)
print(num_classes)
# Load a pre-trained ResNet model # Load a pre-trained ResNet model
model = models.resnet18(pretrained=True) model = models.resnet18(pretrained=True)
@ -47,11 +49,17 @@ def main():
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Load the previously trained model state
#checkpoint = torch.load("neural_network/save/trained_model.pth")
#model.load_state_dict(checkpoint)
# Train the model # Train the model
def train_model(model, criterion, optimizer, num_epochs=2): def train_model(model, criterion, optimizer, num_epochs=2):
best_model_wts = None # Initialize the variable best_model_wts = None # Initialize the variable
best_acc = 0.0 best_acc = 0.0
for epoch in range(num_epochs): for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}") print(f"Epoch {epoch+1}/{num_epochs}")
print("-" * 10) print("-" * 10)

Binary file not shown.

Binary file not shown.

57
nn.py Normal file
View File

@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Load the saved model
class NNModel:
#load model
def __init__(self, path):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.class_names = ['Bean', 'Bitter_Gourd', 'Bottle_Gourd', 'Brinjal',
'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cauliflower',
'Cucumber', 'Papaya', 'Potato', 'Pumpkin', 'Radish', 'Tomato']
self.model = models.resnet18(pretrained=False)
self.num_classes = len(self.class_names)
self.model.fc = nn.Linear(self.model.fc.in_features, self.num_classes)
self.model.load_state_dict(torch.load(path)) #"neural_network/save/first_model.pth"
#self.model.to(self.device)
self.model.eval()
print(self.class_names)
print(self.num_classes)
def input_image(self, path): #"resources/image.jpg"
# Define the image transformations
preprocess = transforms.Compose([
transforms.Resize(224),
#transforms.CenterCrop(224),
transforms.ToTensor(),
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Preprocess the input image
self.input_image = Image.open(path).convert("RGB")
self.input_tensor = preprocess(self.input_image)
print("Input image shape:", self.input_image.size)
input_batch = self.input_tensor.unsqueeze(0)
return input_batch
def predicte(self, input_batch):
with torch.no_grad():
self.input_batch = input_batch.to(self.device)
self.output = self.model(self.input_batch)
print("Output shape:", self.output.shape)
print("Number of classes:", self.num_classes)
# Get the predicted class probabilities and labels
self.probabilities = torch.nn.functional.softmax(self.output[0], dim=0)
self.predicted_class_index = torch.argmax(self.probabilities).item()
self.predicted_class = self.class_names[self.predicted_class_index]
# Use the predicted class in your game logic
print(f"The predicted class is: {self.predicted_class}")

BIN
resources/1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

BIN
resources/2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB