Compare commits
3 Commits
9e978d6032
...
9667655a2a
Author | SHA1 | Date | |
---|---|---|---|
|
9667655a2a | ||
|
b45c2e0f1f | ||
|
fb0ec5057c |
Binary file not shown.
BIN
source/NN/__pycache__/neural_network.cpython-311.pyc
Normal file
BIN
source/NN/__pycache__/neural_network.cpython-311.pyc
Normal file
Binary file not shown.
@ -3,16 +3,26 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Neural_Network_Model(nn.Module):
|
||||
def __init__(self, num_classes=5,hidden_layer1 = 100,hidden_layer2 = 100):
|
||||
super(Neural_Network_Model, self).__init__()
|
||||
self.fc1 = nn.Linear(150*150*3,hidden_layer1)
|
||||
class Conv_Neural_Network_Model(nn.Module):
|
||||
def __init__(self, num_classes=5,hidden_layer1 = 512,hidden_layer2 = 256):
|
||||
super(Conv_Neural_Network_Model, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
|
||||
self.fc1 = nn.Linear(128*12*12,hidden_layer1)
|
||||
self.fc2 = nn.Linear(hidden_layer1,hidden_layer2)
|
||||
self.out = nn.Linear(hidden_layer2,num_classes)
|
||||
# two hidden layers
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 150*150*3)
|
||||
x = self.pool1(F.relu(self.conv1(x)))
|
||||
x = self.pool2(F.relu(self.conv2(x)))
|
||||
x = self.pool3(F.relu(self.conv3(x)))
|
||||
x = x.view(-1, 128*12*12) #<----flattening the image
|
||||
x = self.fc1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.fc2(x)
|
||||
|
@ -4,7 +4,6 @@ from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms, utils
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from model import *
|
||||
from PIL import Image
|
||||
|
||||
@ -12,9 +11,9 @@ device = torch.device('cuda')
|
||||
|
||||
#data transform to tensors:
|
||||
data_transformer = transforms.Compose([
|
||||
transforms.Resize((150, 150)),
|
||||
transforms.Resize((100, 100)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
transforms.Normalize((0.5, 0.5, 0.5 ), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
|
||||
@ -24,18 +23,13 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
|
||||
|
||||
|
||||
#to mozna nawet przerzucic do funkcji train:
|
||||
#train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
|
||||
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True, num_workers=2)
|
||||
|
||||
#test if classes work properly:
|
||||
#print(train_set.classes)
|
||||
#print(train_set.class_to_idx)
|
||||
#print(train_set.targets[3002])
|
||||
# train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
|
||||
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
|
||||
|
||||
|
||||
#function for training model
|
||||
def train(model, dataset, iter=100, batch_size=64):
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
|
||||
criterion = nn.NLLLoss()
|
||||
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
model.train()
|
||||
@ -62,14 +56,12 @@ def accuracy(model, dataset):
|
||||
return correct.float() / len(dataset)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model = Neural_Network_Model()
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.to(device)
|
||||
|
||||
model.load_state_dict(torch.load('model.pth'))
|
||||
model.eval()
|
||||
#loading the already saved model:
|
||||
# model.load_state_dict(torch.load('model.pth'))
|
||||
# model.eval()
|
||||
|
||||
#training the model:
|
||||
# train(model, train_set)
|
||||
@ -77,19 +69,46 @@ model.eval()
|
||||
# torch.save(model.state_dict(), 'model.pth')
|
||||
|
||||
|
||||
|
||||
def load_model():
|
||||
model = Conv_Neural_Network_Model()
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
testImage = Image.open(image_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
return testImage
|
||||
|
||||
def guess_image(model, image_tensor):
|
||||
with torch.no_grad():
|
||||
testOutput = model(image_tensor)
|
||||
_, predicted = torch.max(testOutput, 1)
|
||||
predicted_class = train_set.classes[predicted.item()]
|
||||
return predicted_class
|
||||
|
||||
|
||||
|
||||
# image_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
|
||||
# image_tensor = load_image(image_path)
|
||||
# prediction = guess_image(load_model(), image_tensor)
|
||||
# print(f"The predicted image is: {prediction}")
|
||||
|
||||
#TEST - loading the image and getting results:
|
||||
testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg'
|
||||
testImage_path = 'resources/images/plant_photos/pexels-justus-menke-3490295-5213970.jpg'
|
||||
testImage = Image.open(testImage_path)
|
||||
testImage = data_transformer(testImage)
|
||||
testImage = testImage.unsqueeze(0)
|
||||
testImage = testImage.to(device)
|
||||
|
||||
model.load_state_dict(torch.load('model.pth'))
|
||||
model.load_state_dict(torch.load('CNN_model.pth'))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
testOutput = model(testImage)
|
||||
_, predicted = torch.max(testOutput, 1)
|
||||
predicted_class = train_set.classes[predicted.item()]
|
||||
print(f'The predicted class is: {predicted_class}')
|
||||
|
||||
print(f'The predicted class is: {predicted_class}')
|
BIN
source/__pycache__/astar.cpython-311.pyc
Normal file
BIN
source/__pycache__/astar.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -52,4 +52,11 @@ def get_tile_coordinates(index):
|
||||
tile = tiles[index]
|
||||
return tile.x, tile.y
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_tile_index():
|
||||
valid_indices = []
|
||||
for index, tile in enumerate(tiles):
|
||||
if tile.image=="resources/images/sampling.png":
|
||||
valid_indices.append(index)
|
||||
return random.choice(valid_indices)
|
@ -3,16 +3,17 @@ import time
|
||||
import random
|
||||
import pandas as pd
|
||||
import joblib
|
||||
|
||||
from area.constants import WIDTH, HEIGHT, TILE_SIZE
|
||||
from area.field import drawWindow
|
||||
from area.tractor import Tractor, do_actions
|
||||
from area.field import tiles, fieldX, fieldY
|
||||
from area.field import get_tile_coordinates
|
||||
from area.field import get_tile_coordinates, get_tile_index
|
||||
from ground import Dirt
|
||||
from plant import Plant
|
||||
from bfs import graphsearch, Istate, succ
|
||||
from astar import a_star
|
||||
from NN.neural_network import load_model, load_image, guess_image
|
||||
|
||||
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
pygame.display.set_caption('Intelligent tractor')
|
||||
|
||||
@ -23,7 +24,7 @@ def main():
|
||||
pygame.display.update()
|
||||
|
||||
#getting coordinates of our "goal tile":
|
||||
tile_index=127
|
||||
tile_index = get_tile_index()
|
||||
tile_x, tile_y = get_tile_coordinates(tile_index)
|
||||
if tile_x is not None and tile_y is not None:
|
||||
print(f"Coordinates of tile {tile_index} are: ({tile_x}, {tile_y})")
|
||||
@ -128,6 +129,13 @@ def main():
|
||||
print(predykcje)
|
||||
if predykcje == 'work':
|
||||
tractor.work_on_field(tile1, d1, p1)
|
||||
|
||||
#guessing the image under the tile:
|
||||
tiles[tile_index].display_photo()
|
||||
image_path = tiles[tile_index].photo
|
||||
image_tensor = load_image(image_path)
|
||||
prediction = guess_image(load_model(), image_tensor)
|
||||
print(f"The predicted image is: {prediction}")
|
||||
time.sleep(30)
|
||||
print("\n")
|
||||
|
||||
|
BIN
source/model.pth
BIN
source/model.pth
Binary file not shown.
@ -1,5 +1,8 @@
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
# path to plant images folder (used in randomize_photo function)
|
||||
folder_path = "resources/images/plant_photos"
|
||||
@ -48,4 +51,9 @@ class Tile:
|
||||
self.image = "resources/images/rock_dirt.png"
|
||||
|
||||
|
||||
# DISCLAMER check column and choose plant type ("potato","wheat" etc.)
|
||||
def display_photo(self):
|
||||
image_path = self.photo
|
||||
img = Image.open(image_path)
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
Loading…
Reference in New Issue
Block a user