refactor #26
21
App.py
21
App.py
@ -9,6 +9,7 @@ import Osprzet
|
|||||||
import Ui
|
import Ui
|
||||||
import BFS
|
import BFS
|
||||||
import AStar
|
import AStar
|
||||||
|
import neuralnetwork
|
||||||
|
|
||||||
|
|
||||||
bfs1_flag=False
|
bfs1_flag=False
|
||||||
@ -18,7 +19,9 @@ Astar = False
|
|||||||
Astar2 = False
|
Astar2 = False
|
||||||
if bfs3_flag or Astar or Astar2:
|
if bfs3_flag or Astar or Astar2:
|
||||||
Pole.stoneFlag = True
|
Pole.stoneFlag = True
|
||||||
TreeFlag=True
|
TreeFlag=False
|
||||||
|
nnFlag=True
|
||||||
|
newModel=False
|
||||||
|
|
||||||
pygame.init()
|
pygame.init()
|
||||||
show_console=True
|
show_console=True
|
||||||
@ -29,7 +32,7 @@ image_loader=Image.Image()
|
|||||||
image_loader.load_images()
|
image_loader.load_images()
|
||||||
goalTreasure = AStar.getRandomGoalTreasure() # nie wiem czy to najlepsze miejsce, obecnie pole zawiera pole gasStation, które służy do renderowania odpowiedniego zdjęcia
|
goalTreasure = AStar.getRandomGoalTreasure() # nie wiem czy to najlepsze miejsce, obecnie pole zawiera pole gasStation, które służy do renderowania odpowiedniego zdjęcia
|
||||||
pole=Pole.Pole(screen,image_loader, goalTreasure)
|
pole=Pole.Pole(screen,image_loader, goalTreasure)
|
||||||
pole.draw_grid() #musi byc tutaj wywołane ponieważ inicjalizuje sloty do slownika
|
pole.draw_grid(nnFlag) #musi byc tutaj wywołane ponieważ inicjalizuje sloty do slownika
|
||||||
ui=Ui.Ui(screen)
|
ui=Ui.Ui(screen)
|
||||||
#Tractor creation
|
#Tractor creation
|
||||||
traktor_slot = pole.get_slot_from_cord((0, 0))
|
traktor_slot = pole.get_slot_from_cord((0, 0))
|
||||||
@ -40,7 +43,7 @@ def init_demo(): #Demo purpose
|
|||||||
old_info=""
|
old_info=""
|
||||||
traktor.draw_tractor()
|
traktor.draw_tractor()
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
pole.randomize_colors()
|
pole.randomize_colors(nnFlag)
|
||||||
traktor.draw_tractor()
|
traktor.draw_tractor()
|
||||||
start_flag=True
|
start_flag=True
|
||||||
while True:
|
while True:
|
||||||
@ -116,6 +119,18 @@ def init_demo(): #Demo purpose
|
|||||||
if(TreeFlag):
|
if(TreeFlag):
|
||||||
traktor.move_forward(pole)
|
traktor.move_forward(pole)
|
||||||
traktor.tree_move(pole)
|
traktor.tree_move(pole)
|
||||||
|
if(nnFlag):
|
||||||
|
global model
|
||||||
|
if (newModel):
|
||||||
|
print_to_console("uczenie sieci neuronowej")
|
||||||
|
model = neuralnetwork.trainNewModel()
|
||||||
|
neuralnetwork.saveModel(model)
|
||||||
|
print('model został wygenerowany')
|
||||||
|
else:
|
||||||
|
model = neuralnetwork.loadModel('model.pth')
|
||||||
|
print_to_console("model został załądowny")
|
||||||
|
testset = neuralnetwork.getDataset(False)
|
||||||
|
print(neuralnetwork.accuracy(model, testset))
|
||||||
start_flag=False
|
start_flag=False
|
||||||
# demo_move()
|
# demo_move()
|
||||||
old_info=get_info(old_info)
|
old_info=get_info(old_info)
|
||||||
|
12
Image.py
12
Image.py
@ -1,6 +1,8 @@
|
|||||||
import pygame
|
import pygame
|
||||||
import displayControler as dCon
|
import displayControler as dCon
|
||||||
import random
|
import random
|
||||||
|
import neuralnetwork
|
||||||
|
import os
|
||||||
|
|
||||||
class Image:
|
class Image:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -53,3 +55,13 @@ class Image:
|
|||||||
|
|
||||||
def return_gasStation(self):
|
def return_gasStation(self):
|
||||||
return self.gasStation_image
|
return self.gasStation_image
|
||||||
|
|
||||||
|
def getRandomImageFromDataBase():
|
||||||
|
label = random.choice(neuralnetwork.labels)
|
||||||
|
|
||||||
|
folderPath = f"dataset/test/{label}"
|
||||||
|
files = os.listdir(folderPath)
|
||||||
|
random_image = random.choice(files)
|
||||||
|
imgPath = os.path.join(folderPath, random_image)
|
||||||
|
|
||||||
|
return pygame.image.load(imgPath), label, imgPath
|
||||||
|
8
Pole.py
8
Pole.py
@ -6,6 +6,8 @@ import time
|
|||||||
import Ui
|
import Ui
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
import neuralnetwork
|
||||||
|
import Image
|
||||||
|
|
||||||
stoneList = [(3,3), (3,4), (3,5), (3,6), (4,6), (5,6), (6,6), (7,6), (8,6), (9,6), (10,6), (11,6), (12,6), (13,6), (14,6), (15,6), (16,6), (16,7), (16,8), (16,9)]
|
stoneList = [(3,3), (3,4), (3,5), (3,6), (4,6), (5,6), (6,6), (7,6), (8,6), (9,6), (10,6), (11,6), (12,6), (13,6), (14,6), (15,6), (16,6), (16,7), (16,8), (16,9)]
|
||||||
stoneFlag = False
|
stoneFlag = False
|
||||||
@ -30,7 +32,7 @@ class Pole:
|
|||||||
return self.slot_dict
|
return self.slot_dict
|
||||||
|
|
||||||
#Draw grid and tractor (new one)
|
#Draw grid and tractor (new one)
|
||||||
def draw_grid(self):
|
def draw_grid(self, nn=False):
|
||||||
for x in range(0,dCon.NUM_X): #Draw all cubes in X axis
|
for x in range(0,dCon.NUM_X): #Draw all cubes in X axis
|
||||||
for y in range(0,dCon.NUM_Y): #Draw all cubes in Y axis
|
for y in range(0,dCon.NUM_Y): #Draw all cubes in Y axis
|
||||||
new_slot=Slot.Slot(x,y,Colors.BROWN,self.screen,self.image_loader) #Creation of empty slot
|
new_slot=Slot.Slot(x,y,Colors.BROWN,self.screen,self.image_loader) #Creation of empty slot
|
||||||
@ -48,7 +50,7 @@ class Pole:
|
|||||||
st=self.slot_dict[self.gasStation]
|
st=self.slot_dict[self.gasStation]
|
||||||
st.set_gasStation_image()
|
st.set_gasStation_image()
|
||||||
|
|
||||||
def randomize_colors(self):
|
def randomize_colors(self, nn = False):
|
||||||
pygame.display.update()
|
pygame.display.update()
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
#self.ui.render_text("Randomizing Crops")
|
#self.ui.render_text("Randomizing Crops")
|
||||||
@ -59,7 +61,7 @@ class Pole:
|
|||||||
if(coordinates==(0,0)):
|
if(coordinates==(0,0)):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
self.slot_dict[coordinates].set_random_plant()
|
self.slot_dict[coordinates].set_random_plant(nn)
|
||||||
|
|
||||||
def change_color_of_slot(self,coordinates,color): #Coordinates must be tuple (x,y) (left top slot has cord (0,0) ), color has to be from defined in Colors.py or custom in RGB value (R,G,B)
|
def change_color_of_slot(self,coordinates,color): #Coordinates must be tuple (x,y) (left top slot has cord (0,0) ), color has to be from defined in Colors.py or custom in RGB value (R,G,B)
|
||||||
self.get_slot_from_cord(coordinates).color_change(color)
|
self.get_slot_from_cord(coordinates).color_change(color)
|
||||||
|
10
Slot.py
10
Slot.py
@ -16,6 +16,8 @@ class Slot:
|
|||||||
self.field=pygame.Rect(self.x_axis*dCon.CUBE_SIZE,self.y_axis*dCon.CUBE_SIZE,dCon.CUBE_SIZE,dCon.CUBE_SIZE)
|
self.field=pygame.Rect(self.x_axis*dCon.CUBE_SIZE,self.y_axis*dCon.CUBE_SIZE,dCon.CUBE_SIZE,dCon.CUBE_SIZE)
|
||||||
self.image_loader=image_loader
|
self.image_loader=image_loader
|
||||||
self.garage_image=None
|
self.garage_image=None
|
||||||
|
self.label = None
|
||||||
|
self.imagePath = None
|
||||||
|
|
||||||
def draw(self):
|
def draw(self):
|
||||||
pygame.draw.rect(self.screen,Colors.BROWN,self.field,0) #Draw field
|
pygame.draw.rect(self.screen,Colors.BROWN,self.field,0) #Draw field
|
||||||
@ -38,9 +40,13 @@ class Slot:
|
|||||||
self.plant=color
|
self.plant=color
|
||||||
self.draw()
|
self.draw()
|
||||||
|
|
||||||
def set_random_plant(self):
|
def set_random_plant(self, nn=False):
|
||||||
|
if not nn:
|
||||||
(plant_name,self.plant_image)=self.random_plant()
|
(plant_name,self.plant_image)=self.random_plant()
|
||||||
self.plant=Roslina.Roslina(plant_name)
|
self.plant=Roslina.Roslina(plant_name)
|
||||||
|
else:
|
||||||
|
self.plant_image, self.label, self.imagePath = self.random_plant_dataset()
|
||||||
|
self.plant=Roslina.Roslina(self.label)
|
||||||
self.set_image()
|
self.set_image()
|
||||||
|
|
||||||
def set_image(self):
|
def set_image(self):
|
||||||
@ -66,6 +72,8 @@ class Slot:
|
|||||||
|
|
||||||
def random_plant(self): #Probably will not be used later only for demo purpouse
|
def random_plant(self): #Probably will not be used later only for demo purpouse
|
||||||
return self.image_loader.return_random_plant()
|
return self.image_loader.return_random_plant()
|
||||||
|
def random_plant_dataset(self):
|
||||||
|
return Image.getRandomImageFromDataBase()
|
||||||
|
|
||||||
def return_plant(self):
|
def return_plant(self):
|
||||||
return self.plant
|
return self.plant
|
||||||
|
111
neuralnetwork.py
Normal file
111
neuralnetwork.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import datasets
|
||||||
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import random
|
||||||
|
|
||||||
|
imageSize = (128, 128)
|
||||||
|
labels = ['beetroot', 'potato', 'carrot']
|
||||||
|
labels.sort()
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
device = torch.device('cuda') if torch.cuda.is_available () else torch.device('cpu')
|
||||||
|
|
||||||
|
def getTransformation():
|
||||||
|
transform=transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
transforms.Resize(imageSize),
|
||||||
|
Lambda(lambda x: x.flatten())])
|
||||||
|
return transform
|
||||||
|
|
||||||
|
def getDataset(train=True):
|
||||||
|
transform = getTransformation()
|
||||||
|
if (train):
|
||||||
|
trainset = datasets.ImageFolder(root='dataset/train', transform=transform)
|
||||||
|
return trainset
|
||||||
|
else:
|
||||||
|
testset = datasets.ImageFolder(root='dataset/test', transform=transform)
|
||||||
|
return testset
|
||||||
|
|
||||||
|
|
||||||
|
def train(model, dataset, n_iter=100, batch_size=256):
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
|
criterion = nn.NLLLoss()
|
||||||
|
dl = DataLoader(dataset, batch_size=batch_size)
|
||||||
|
model.train()
|
||||||
|
for epoch in range(n_iter):
|
||||||
|
for images, targets in dl:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
out = model(images.to(device))
|
||||||
|
loss = criterion(out, targets.to(device))
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
if epoch % 10 == 0:
|
||||||
|
print('epoch: %3d loss: %.4f' % (epoch, loss))
|
||||||
|
return model
|
||||||
|
|
||||||
|
def accuracy(model, dataset):
|
||||||
|
model.eval()
|
||||||
|
correct = sum([(model(images.to(device)).argmax(dim=1) == targets.to(device)).sum()
|
||||||
|
for images, targets in DataLoader(dataset, batch_size=256)])
|
||||||
|
return correct.float() / len(dataset)
|
||||||
|
|
||||||
|
def getModel():
|
||||||
|
hidden_size = 300
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(imageSize[0] * imageSize[1] * 3, hidden_size),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_size, len(labels)),
|
||||||
|
nn.LogSoftmax(dim=-1)
|
||||||
|
).to(device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def saveModel(model):
|
||||||
|
torch.save(model.state_dict(), 'model.pth')
|
||||||
|
|
||||||
|
def loadModel(path):
|
||||||
|
model = getModel()
|
||||||
|
model.load_state_dict(torch.load(path))
|
||||||
|
return model
|
||||||
|
|
||||||
|
def trainNewModel(n_iter=100, batch_size=256):
|
||||||
|
trainset = getDataset(True)
|
||||||
|
model = getModel()
|
||||||
|
model = train(model, trainset)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def predictLabel(imagePath, model):
|
||||||
|
image = Image.open(imagePath).convert("RGB")
|
||||||
|
image = preprocess_image(image)
|
||||||
|
with torch.no_grad():
|
||||||
|
model.eval() # Ustawienie modelu w tryb ewaluacji
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
|
||||||
|
predicted_class = torch.argmax(output).item()
|
||||||
|
return labels[predicted_class]
|
||||||
|
|
||||||
|
def predictLabel(image, model):
|
||||||
|
image = preprocess_image(image)
|
||||||
|
with torch.no_grad():
|
||||||
|
model.eval() # Ustawienie modelu w tryb ewaluacji
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
# Znalezienie indeksu klasy o największej wartości prawdopodobieństwa
|
||||||
|
predicted_class = torch.argmax(output).item()
|
||||||
|
return labels[predicted_class]
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image):
|
||||||
|
transform = getTransformation()
|
||||||
|
image = transform(image).unsqueeze(0) # Dodanie wymiaru batch_size
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user