Compare commits
No commits in common. "neural_network" and "master" have entirely different histories.
neural_net
...
master
BIN
Tiles/Base.jpg
Before Width: | Height: | Size: 209 KiB |
BIN
Tiles/Bend.jpg
Before Width: | Height: | Size: 192 KiB |
BIN
Tiles/End.jpg
Before Width: | Height: | Size: 193 KiB |
Before Width: | Height: | Size: 187 KiB |
Before Width: | Height: | Size: 178 KiB |
Before Width: | Height: | Size: 186 KiB |
Before Width: | Height: | Size: 9.3 KiB After Width: | Height: | Size: 9.3 KiB |
22
collect
@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
|
|||||||
6 -> 10 ;
|
6 -> 10 ;
|
||||||
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||||
10 -> 11 ;
|
10 -> 11 ;
|
||||||
12 [label="garbage_type <= 2.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
11 -> 12 ;
|
11 -> 12 ;
|
||||||
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
12 -> 13 ;
|
12 -> 13 ;
|
||||||
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
|
|||||||
15 -> 16 ;
|
15 -> 16 ;
|
||||||
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||||
15 -> 17 ;
|
15 -> 17 ;
|
||||||
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||||
17 -> 18 ;
|
17 -> 18 ;
|
||||||
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
18 -> 19 ;
|
18 -> 19 ;
|
||||||
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
|
|||||||
23 -> 25 ;
|
23 -> 25 ;
|
||||||
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||||
25 -> 26 ;
|
25 -> 26 ;
|
||||||
27 [label="days_since_last_collection <= 22.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
25 -> 27 ;
|
25 -> 27 ;
|
||||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
27 -> 28 ;
|
27 -> 28 ;
|
||||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
27 -> 29 ;
|
27 -> 29 ;
|
||||||
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||||
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||||
@ -88,14 +88,18 @@ edge [fontname="helvetica"] ;
|
|||||||
40 -> 42 ;
|
40 -> 42 ;
|
||||||
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||||
42 -> 43 ;
|
42 -> 43 ;
|
||||||
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||||
42 -> 44 ;
|
42 -> 44 ;
|
||||||
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
44 -> 45 ;
|
44 -> 45 ;
|
||||||
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||||
44 -> 46 ;
|
44 -> 46 ;
|
||||||
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
46 -> 47 ;
|
46 -> 47 ;
|
||||||
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
|
||||||
46 -> 48 ;
|
46 -> 48 ;
|
||||||
|
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
|
48 -> 49 ;
|
||||||
|
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||||
|
48 -> 50 ;
|
||||||
}
|
}
|
||||||
|
BIN
collect.pdf
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
@ -1,16 +1,11 @@
|
|||||||
from heuristicfn import heuristicfn
|
from heuristicfn import heuristicfn
|
||||||
|
|
||||||
|
|
||||||
FIELDWIDTH = 50
|
FIELDWIDTH = 50
|
||||||
TURN_FUEL_COST = 10
|
TURN_FUEL_COST = 10
|
||||||
MOVE_FUEL_COST = 200
|
MOVE_FUEL_COST = 200
|
||||||
MAX_FUEL = 20000
|
MAX_FUEL = 20000
|
||||||
MAX_SPACE = 5
|
MAX_SPACE = 5
|
||||||
MAX_WEIGHT = 400
|
MAX_WEIGHT = 200
|
||||||
MAX_WEIGHT_GLASS = 100
|
|
||||||
MAX_WEIGHT_MIXED = 100
|
|
||||||
MAX_WEIGHT_PAPER = 100
|
|
||||||
MAX_WEIGHT_PLASTIC = 100
|
|
||||||
|
|
||||||
|
|
||||||
class GarbageTruck:
|
class GarbageTruck:
|
||||||
@ -23,10 +18,6 @@ class GarbageTruck:
|
|||||||
self.fuel = MAX_FUEL
|
self.fuel = MAX_FUEL
|
||||||
self.free_space = MAX_SPACE
|
self.free_space = MAX_SPACE
|
||||||
self.weight_capacity = MAX_WEIGHT
|
self.weight_capacity = MAX_WEIGHT
|
||||||
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
|
||||||
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
|
||||||
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
|
||||||
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
|
||||||
self.rect = rect
|
self.rect = rect
|
||||||
self.orientation = orientation
|
self.orientation = orientation
|
||||||
self.request_list = request_list #lista domów do odwiedzenia
|
self.request_list = request_list #lista domów do odwiedzenia
|
||||||
@ -87,33 +78,10 @@ class GarbageTruck:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def collect(self, garbage_type):
|
def collect(self):
|
||||||
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
||||||
self.fuel = MAX_FUEL
|
self.fuel = MAX_FUEL
|
||||||
self.free_space = MAX_SPACE
|
self.free_space = MAX_SPACE
|
||||||
self.weight_capacity = MAX_WEIGHT
|
self.weight_capacity = MAX_WEIGHT
|
||||||
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
|
||||||
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
|
||||||
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
|
||||||
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
|
||||||
request = self.request_list[0]
|
|
||||||
if garbage_type == "glass":
|
|
||||||
if request.weight > self.weight_capacity_glass:
|
|
||||||
return 1
|
|
||||||
self.weight_capacity_glass -= request.weight
|
|
||||||
elif garbage_type == "mixed":
|
|
||||||
if request.weight > self.weight_capacity_mixed:
|
|
||||||
return 1
|
|
||||||
self.weight_capacity_mixed -= request.weight
|
|
||||||
elif garbage_type == "paper":
|
|
||||||
if request.weight > self.weight_capacity_paper:
|
|
||||||
return 1
|
|
||||||
self.weight_capacity_paper -= request.weight
|
|
||||||
elif garbage_type == "plastic":
|
|
||||||
if request.weight > self.weight_capacity_plastic:
|
|
||||||
return 1
|
|
||||||
self.weight_capacity_plastic -= request.weight
|
|
||||||
|
|
||||||
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}, glass_capacity: {self.weight_capacity_glass}, mixed_capacity: {self.weight_capacity_mixed}, paper_capacity: {self.weight_capacity_paper}, plastic_capacity: {self.weight_capacity_plastic}')
|
|
||||||
return 0
|
|
||||||
pass
|
pass
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 26 KiB |
@ -1,2 +1,3 @@
|
|||||||
def heuristicfn(startx, starty, goalx, goaly):
|
def heuristicfn(startx, starty, goalx, goaly):
|
||||||
return abs(startx - goalx) + abs(starty - goaly)
|
return abs(startx - goalx) + abs(starty - goaly)
|
||||||
|
# return pow(((startx//50)-(starty//50)),2) + pow(((goalx//50)-(goaly//50)),2)
|
44
loadmodel.py
@ -1,44 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
import PIL.Image as Image
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def classify(image_path):
|
|
||||||
model = torch.load('./model_training/garbage_model.pth')
|
|
||||||
mean = [0.6908, 0.6612, 0.6218]
|
|
||||||
std = [0.1947, 0.1926, 0.2086]
|
|
||||||
classes = [
|
|
||||||
"glass",
|
|
||||||
"mixed",
|
|
||||||
"paper",
|
|
||||||
"plastic",
|
|
||||||
]
|
|
||||||
image_transforms = transforms.Compose([
|
|
||||||
transforms.Resize((128, 128)),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
|
|
||||||
])
|
|
||||||
|
|
||||||
model = model.eval()
|
|
||||||
image = Image.open(image_path)
|
|
||||||
image = image_transforms(image).float()
|
|
||||||
image = image.unsqueeze(0)
|
|
||||||
|
|
||||||
output = model(image)
|
|
||||||
_, predicted = torch.max(output.data, 1)
|
|
||||||
|
|
||||||
label = os.path.basename(os.path.dirname(image_path))
|
|
||||||
prediction = classes[predicted.item()]
|
|
||||||
print(f"predicted: {prediction}")
|
|
||||||
if label == prediction:
|
|
||||||
print("predicted correctly.")
|
|
||||||
else:
|
|
||||||
print("predicted incorrectly.")
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
|
|
||||||
# classify("./model_training/test.jpg")
|
|
||||||
|
|
||||||
|
|
61
main.py
@ -1,6 +1,7 @@
|
|||||||
import pygame
|
import pygame
|
||||||
from treelearn import treelearn
|
from treelearn import treelearn
|
||||||
import loadmodel
|
|
||||||
|
|
||||||
from astar import astar
|
from astar import astar
|
||||||
from state import State
|
from state import State
|
||||||
import time
|
import time
|
||||||
@ -8,7 +9,6 @@ from garbage_truck import GarbageTruck
|
|||||||
from heuristicfn import heuristicfn
|
from heuristicfn import heuristicfn
|
||||||
from map import randomize_map
|
from map import randomize_map
|
||||||
|
|
||||||
|
|
||||||
pygame.init()
|
pygame.init()
|
||||||
WIDTH, HEIGHT = 800, 800
|
WIDTH, HEIGHT = 800, 800
|
||||||
window = pygame.display.set_mode((WIDTH, HEIGHT))
|
window = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||||
@ -18,18 +18,14 @@ AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
|
|||||||
FPS = 10
|
FPS = 10
|
||||||
FIELDCOUNT = 16
|
FIELDCOUNT = 16
|
||||||
FIELDWIDTH = 50
|
FIELDWIDTH = 50
|
||||||
BASE_IMG = pygame.image.load("Tiles/Base.jpg")
|
|
||||||
BASE = pygame.transform.scale(BASE_IMG, (50, 50))
|
|
||||||
|
|
||||||
def draw_window(agent, fields, flip, turn):
|
GRASS_IMG = pygame.image.load("grass.png")
|
||||||
|
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||||
|
def draw_window(agent, fields, flip):
|
||||||
if flip:
|
if flip:
|
||||||
direction = pygame.transform.flip(AGENT, True, False)
|
direction = pygame.transform.flip(AGENT, True, False)
|
||||||
if turn:
|
|
||||||
direction = pygame.transform.rotate(AGENT, -90)
|
|
||||||
else:
|
else:
|
||||||
direction = pygame.transform.flip(AGENT, False, False)
|
direction = pygame.transform.flip(AGENT, False, False)
|
||||||
if turn:
|
|
||||||
direction = pygame.transform.rotate(AGENT, 90)
|
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
for j in range(16):
|
for j in range(16):
|
||||||
window.blit(fields[i][j], (i * 50, j * 50))
|
window.blit(fields[i][j], (i * 50, j * 50))
|
||||||
@ -41,64 +37,41 @@ def main():
|
|||||||
clf = treelearn()
|
clf = treelearn()
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
run = True
|
run = True
|
||||||
fields, priority_array, request_list, imgpath_array = randomize_map()
|
fields, priority_array, request_list = randomize_map()
|
||||||
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
||||||
low_space = 0
|
|
||||||
while run:
|
while run:
|
||||||
clock.tick(FPS)
|
clock.tick(FPS)
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
if event.type == pygame.QUIT:
|
if event.type == pygame.QUIT:
|
||||||
run = False
|
run = False
|
||||||
draw_window(agent, fields, False, False) # false = kierunek east (domyslny), true = west
|
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
|
||||||
x, y = agent.next_destination()
|
x, y = agent.next_destination()
|
||||||
if x == agent.rect.x and y == agent.rect.y:
|
if x == agent.rect.x and y == agent.rect.y:
|
||||||
print('out of jobs')
|
print('out of jobs')
|
||||||
break
|
break
|
||||||
if low_space == 1:
|
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[agent.rect.x//50][agent.rect.y//50], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
||||||
x, y = 0, 0
|
|
||||||
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation,
|
|
||||||
priority_array[agent.rect.x//50][agent.rect.y//50],
|
|
||||||
heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
|
||||||
for interm in steps:
|
for interm in steps:
|
||||||
if interm.action == 'LEFT':
|
if interm.action == 'LEFT':
|
||||||
agent.turn_left()
|
agent.turn_left()
|
||||||
if agent.orientation == 0:
|
draw_window(agent, fields, True)
|
||||||
draw_window(agent, fields, False, False)
|
|
||||||
elif agent.orientation == 2:
|
|
||||||
draw_window(agent, fields, True, False)
|
|
||||||
elif agent.orientation == 1:
|
|
||||||
draw_window(agent, fields, True, True)
|
|
||||||
else:
|
|
||||||
draw_window(agent, fields, False, True)
|
|
||||||
elif interm.action == 'RIGHT':
|
elif interm.action == 'RIGHT':
|
||||||
agent.turn_right()
|
agent.turn_right()
|
||||||
if agent.orientation == 0:
|
draw_window(agent, fields, False)
|
||||||
draw_window(agent, fields, False, False)
|
|
||||||
elif agent.orientation == 2:
|
|
||||||
draw_window(agent, fields, True, False)
|
|
||||||
elif agent.orientation == 1:
|
|
||||||
draw_window(agent, fields, True, True)
|
|
||||||
else:
|
|
||||||
draw_window(agent, fields, False, True)
|
|
||||||
elif interm.action == 'FORWARD':
|
elif interm.action == 'FORWARD':
|
||||||
agent.forward()
|
agent.forward()
|
||||||
if agent.orientation == 0:
|
if agent.orientation == 0:
|
||||||
draw_window(agent, fields, False, False)
|
draw_window(agent, fields, False)
|
||||||
elif agent.orientation == 2:
|
elif agent.orientation == 2:
|
||||||
draw_window(agent, fields, True, False)
|
draw_window(agent, fields, True)
|
||||||
elif agent.orientation == 1:
|
|
||||||
draw_window(agent, fields, True, True)
|
|
||||||
else:
|
else:
|
||||||
draw_window(agent, fields, False, True)
|
draw_window(agent, fields, False)
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
if (agent.rect.x // 50 != 0) or (agent.rect.y // 50 != 0):
|
agent.collect()
|
||||||
garbage_type = loadmodel.classify(imgpath_array[agent.rect.x // 50][agent.rect.y // 50])
|
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
|
||||||
low_space = agent.collect(garbage_type)
|
priority_array[agent.rect.x//50][agent.rect.y//50] = 1
|
||||||
|
|
||||||
fields[agent.rect.x//50][agent.rect.y//50] = BASE
|
|
||||||
priority_array[agent.rect.x//50][agent.rect.y//50] = 100
|
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
|
|
||||||
|
115
map.py
@ -1,113 +1,30 @@
|
|||||||
import pygame as pg
|
import pygame, random
|
||||||
import random
|
|
||||||
from request import Request
|
from request import Request
|
||||||
|
|
||||||
|
DIRT_IMG = pygame.image.load("dirt.jpg")
|
||||||
STRAIGHT_IMG = pg.image.load("Tiles/Straight.jpg")
|
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
|
||||||
STRAIGHT_VERTICAL = pg.transform.scale(STRAIGHT_IMG, (50, 50))
|
GRASS_IMG = pygame.image.load("grass.png")
|
||||||
STRAIGHT_HORIZONTAL = pg.transform.rotate(STRAIGHT_VERTICAL, 270)
|
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||||
BASE_IMG = pg.image.load("Tiles/Base.jpg")
|
SAND_IMG = pygame.image.load("sand.jpeg")
|
||||||
BASE = pg.transform.scale(BASE_IMG, (50, 50))
|
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
|
||||||
BEND_IMG = pg.image.load("Tiles/Bend.jpg")
|
COBBLE_IMG = pygame.image.load("cobble.jpeg")
|
||||||
BEND1 = pg.transform.scale(BEND_IMG, (50, 50))
|
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
|
||||||
BEND2 = pg.transform.rotate(BEND1, 90)
|
|
||||||
BEND3 = pg.transform.rotate(pg.transform.flip(pg.transform.rotate(BEND1, 180), True, True), 180)
|
|
||||||
BEND4 = pg.transform.rotate(BEND1, -90)
|
|
||||||
INTERSECTION_IMG = pg.image.load("Tiles/Intersection.jpg")
|
|
||||||
INTERSECTION = pg.transform.scale(INTERSECTION_IMG, (50, 50))
|
|
||||||
JUNCTION_IMG = pg.image.load("Tiles/Junction.jpg")
|
|
||||||
JUNCTION_SOUTH = pg.transform.scale(JUNCTION_IMG, (50, 50))
|
|
||||||
JUNCTION_NORTH = pg.transform.rotate(pg.transform.flip(JUNCTION_SOUTH, True, False), 180)
|
|
||||||
JUNCTION_EAST = pg.transform.rotate(JUNCTION_SOUTH, -90)
|
|
||||||
JUNCTION_WEST = pg.transform.rotate(JUNCTION_SOUTH, 90)
|
|
||||||
END_IMG = pg.image.load("Tiles/End.jpg")
|
|
||||||
END1 = pg.transform.flip(pg.transform.rotate(pg.transform.scale(END_IMG, (50, 50)), 180), False, True)
|
|
||||||
END2 = pg.transform.rotate(END1, 90)
|
|
||||||
DIRT_IMG = pg.image.load("Tiles/dirt.jpg")
|
|
||||||
DIRT = pg.transform.scale(DIRT_IMG, (50, 50))
|
|
||||||
GRASS_IMG = pg.image.load("Tiles/grass.png")
|
|
||||||
GRASS = pg.transform.scale(GRASS_IMG, (50, 50))
|
|
||||||
SAND_IMG = pg.image.load("Tiles/sand.jpeg")
|
|
||||||
SAND = pg.transform.scale(SAND_IMG, (50, 50))
|
|
||||||
COBBLE_IMG = pg.image.load("Tiles/cobble.jpeg")
|
|
||||||
COBBLE = pg.transform.scale(COBBLE_IMG, (50, 50))
|
|
||||||
|
|
||||||
|
|
||||||
def randomize_map(): # tworzenie mapy z losowymi polami
|
def randomize_map(): # tworzenie mapy z losowymi polami
|
||||||
request_list = []
|
request_list = []
|
||||||
field_array_1 = []
|
field_array_1 = []
|
||||||
field_array_2 = []
|
field_array_2 = []
|
||||||
imgpath_array = [[0 for x in range(16)] for x in range(16)]
|
|
||||||
field_priority = []
|
field_priority = []
|
||||||
map_array = [['b', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'b3', 'g'],
|
|
||||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'jn', 'g'],
|
|
||||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['b1', 'sh', 'jw', 'sh', 'sh', 'jn', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['g', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'jn', 'g'],
|
|
||||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['gr', 'g', 'js', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', ' g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
|
||||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
|
||||||
['gr', 'g', 'b1', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'b4', 'g'],
|
|
||||||
]
|
|
||||||
|
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
temp_priority = []
|
temp_priority = []
|
||||||
for j in range(16):
|
for j in range(16):
|
||||||
if map_array[i][j] == 'b':
|
if i in (0, 1) and j in (0, 1):
|
||||||
field_array_2.append(BASE)
|
field_array_2.append(GRASS)
|
||||||
temp_priority.append(1)
|
temp_priority.append(1)
|
||||||
elif map_array[i][j] == 'b3':
|
|
||||||
field_array_2.append(BEND3)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'b4':
|
|
||||||
field_array_2.append(BEND4)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'b1':
|
|
||||||
field_array_2.append(BEND1)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'sh':
|
|
||||||
field_array_2.append(STRAIGHT_VERTICAL)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'sv':
|
|
||||||
field_array_2.append(STRAIGHT_HORIZONTAL)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'i':
|
|
||||||
field_array_2.append(INTERSECTION)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'je':
|
|
||||||
field_array_2.append(JUNCTION_EAST)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'jw':
|
|
||||||
field_array_2.append(JUNCTION_WEST)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'js':
|
|
||||||
field_array_2.append(JUNCTION_SOUTH)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'jn':
|
|
||||||
field_array_2.append(JUNCTION_NORTH)
|
|
||||||
temp_priority.append(1)
|
|
||||||
elif map_array[i][j] == 'gr':
|
|
||||||
field_array_2.append(BASE)
|
|
||||||
temp_priority.append(1000)
|
|
||||||
else:
|
else:
|
||||||
prob = random.uniform(0, 100)
|
prob = random.uniform(0, 100)
|
||||||
if 0 <= prob <= 20:
|
if 0 <= prob <= 12:
|
||||||
garbage_type = random.choice(['glass', 'mixed', 'paper', 'plastic'])
|
field_array_2.append(COBBLE)
|
||||||
garbage_image_number = random.randrange(1, 100)
|
|
||||||
GARBAGE_IMG = pg.image.load(
|
|
||||||
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
|
||||||
GARBAGE = pg.transform.scale(GARBAGE_IMG, (50, 50))
|
|
||||||
field_array_2.append(GARBAGE)
|
|
||||||
imgpath_array[i][j] = (
|
|
||||||
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
|
||||||
|
|
||||||
temp_priority.append(100)
|
temp_priority.append(100)
|
||||||
request_list.append(Request(
|
request_list.append(Request(
|
||||||
i*50,j*50, #lokacja
|
i*50,j*50, #lokacja
|
||||||
@ -119,9 +36,9 @@ def randomize_map(): # tworzenie mapy z losowymi polami
|
|||||||
random.random() * 50 #waga śmieci
|
random.random() * 50 #waga śmieci
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
field_array_2.append(BASE)
|
field_array_2.append(GRASS)
|
||||||
temp_priority.append(1000)
|
temp_priority.append(1)
|
||||||
field_array_1.append(field_array_2)
|
field_array_1.append(field_array_2)
|
||||||
field_array_2 = []
|
field_array_2 = []
|
||||||
field_priority.append(temp_priority)
|
field_priority.append(temp_priority)
|
||||||
return field_array_1, field_priority, request_list, imgpath_array
|
return field_array_1, field_priority, request_list
|
@ -1,177 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from torch.utils.data import Dataset, random_split, DataLoader
|
|
||||||
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import torchvision.models as models
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
def main():
|
|
||||||
torch.manual_seed(42)
|
|
||||||
# input_size = 49152
|
|
||||||
# hidden_sizes = [64, 128]
|
|
||||||
# output_size = 10
|
|
||||||
|
|
||||||
classes = os.listdir('./train_dataset')
|
|
||||||
print(classes)
|
|
||||||
mean = [0.6908, 0.6612, 0.6218]
|
|
||||||
std = [0.1947, 0.1926, 0.2086]
|
|
||||||
|
|
||||||
training_dataset_path = './train_dataset'
|
|
||||||
training_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
|
||||||
train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
|
|
||||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
|
|
||||||
|
|
||||||
testing_dataset_path = './test_dataset'
|
|
||||||
testing_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
|
||||||
test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
|
|
||||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
|
|
||||||
|
|
||||||
# Mean and Standard Deviation approximations
|
|
||||||
def get_mean_and_std(loader):
|
|
||||||
mean = 0.
|
|
||||||
std = 0.
|
|
||||||
total_images_count = 0
|
|
||||||
for images, _ in loader:
|
|
||||||
image_count_in_a_batch = images.size(0)
|
|
||||||
#print(images.shape)
|
|
||||||
images = images.view(image_count_in_a_batch, images.size(1), -1)
|
|
||||||
#print(images.shape)
|
|
||||||
mean += images.mean(2).sum(0)
|
|
||||||
std += images.std(2).sum(0)
|
|
||||||
total_images_count += image_count_in_a_batch
|
|
||||||
mean /= total_images_count
|
|
||||||
std /= total_images_count
|
|
||||||
return mean, std
|
|
||||||
|
|
||||||
print(get_mean_and_std(train_loader))
|
|
||||||
|
|
||||||
# Show images with applied transformations
|
|
||||||
def show_transformed_images(dataset):
|
|
||||||
loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
|
|
||||||
batch = next(iter(loader))
|
|
||||||
images, labels = batch
|
|
||||||
|
|
||||||
grid = torchvision.utils.make_grid(images, nrow=3)
|
|
||||||
plt.figure(figsize=(11,11))
|
|
||||||
plt.imshow(np.transpose(grid, (1,2,0)))
|
|
||||||
print('labels: ', labels)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
show_transformed_images(train_dataset)
|
|
||||||
|
|
||||||
# Neural network training:
|
|
||||||
def set_device():
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
dev = "cuda:0"
|
|
||||||
else:
|
|
||||||
dev = "cpu"
|
|
||||||
return torch.device(dev)
|
|
||||||
|
|
||||||
|
|
||||||
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epochs):
|
|
||||||
device = set_device()
|
|
||||||
best_acc = 0
|
|
||||||
|
|
||||||
for epoch in range(n_epochs):
|
|
||||||
print("Epoch number %d " % (epoch+1))
|
|
||||||
model.train()
|
|
||||||
running_loss = 0.0
|
|
||||||
running_correct = 0.0
|
|
||||||
total = 0
|
|
||||||
|
|
||||||
for data in train_loader:
|
|
||||||
images, labels = data
|
|
||||||
images = images.to(device)
|
|
||||||
labels = labels.to(device)
|
|
||||||
total += labels.size(0)
|
|
||||||
|
|
||||||
# Back propagation
|
|
||||||
optimizer.zero_grad()
|
|
||||||
outputs = model(images)
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
running_loss += loss.item()
|
|
||||||
running_correct += (labels==predicted).sum().item()
|
|
||||||
|
|
||||||
epoch_loss = running_loss/len(train_loader)
|
|
||||||
epoch_acc = 100.00 * running_correct / total
|
|
||||||
|
|
||||||
print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
|
|
||||||
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
|
|
||||||
|
|
||||||
if(test_dataset_acc > best_acc):
|
|
||||||
best_acc = test_dataset_acc
|
|
||||||
save_checkpoint(model, epoch, optimizer, best_acc)
|
|
||||||
|
|
||||||
print("Finished")
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model_on_test_set(model, test_loader):
|
|
||||||
model.eval()
|
|
||||||
predicted_correctly_on_epoch = 0
|
|
||||||
total = 0
|
|
||||||
device = set_device()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for data in test_loader:
|
|
||||||
images, labels = data
|
|
||||||
images = images.to(device)
|
|
||||||
labels = labels.to(device)
|
|
||||||
total += labels.size(0)
|
|
||||||
|
|
||||||
outputs = model(images)
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
|
||||||
predicted_correctly_on_epoch += (predicted == labels).sum().item()
|
|
||||||
|
|
||||||
epoch_acc = 100.0 * predicted_correctly_on_epoch / total
|
|
||||||
print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
|
|
||||||
|
|
||||||
return epoch_acc
|
|
||||||
|
|
||||||
|
|
||||||
# Saving the checkpoint:
|
|
||||||
def save_checkpoint(model, epoch, optimizer, best_acc):
|
|
||||||
state = {
|
|
||||||
'epoch': epoch+1,
|
|
||||||
'model': model.state_dict(),
|
|
||||||
'best_accuracy': best_acc,
|
|
||||||
'optimizer': optimizer.state_dict(),
|
|
||||||
}
|
|
||||||
torch.save(state, 'model_best_checkpoint.zip')
|
|
||||||
|
|
||||||
|
|
||||||
resnet18_model = models.resnet18(pretrained=True) #Increase n_epochs if False
|
|
||||||
num_features = resnet18_model.fc.in_features
|
|
||||||
number_of_classes = 4
|
|
||||||
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
|
||||||
device = set_device()
|
|
||||||
resnet_18_model = resnet18_model.to(device)
|
|
||||||
loss_fn = nn.CrossEntropyLoss() #criterion
|
|
||||||
|
|
||||||
optimizer = optim.SGD(resnet_18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
|
||||||
train_nn(resnet_18_model, train_loader, test_loader, loss_fn, optimizer, 5)
|
|
||||||
|
|
||||||
|
|
||||||
# Saving the model:
|
|
||||||
checkpoint = torch.load('model_best_checkpoint.pth.zip')
|
|
||||||
|
|
||||||
resnet18_model = models.resnet18()
|
|
||||||
num_features = resnet18_model.fc.in_features
|
|
||||||
number_of_classes = 4
|
|
||||||
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
|
||||||
resnet18_model.load_state_dict(checkpoint['model'])
|
|
||||||
|
|
||||||
torch.save(resnet18_model, 'garbage_model.pth')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
Before Width: | Height: | Size: 162 KiB |
Before Width: | Height: | Size: 3.0 KiB |
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 2.0 KiB |
Before Width: | Height: | Size: 3.1 KiB |
Before Width: | Height: | Size: 2.4 KiB |
Before Width: | Height: | Size: 4.0 KiB |
Before Width: | Height: | Size: 4.5 KiB |
Before Width: | Height: | Size: 2.8 KiB |
Before Width: | Height: | Size: 1.8 KiB |
Before Width: | Height: | Size: 3.8 KiB |
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 3.1 KiB |
Before Width: | Height: | Size: 4.9 KiB |
Before Width: | Height: | Size: 3.9 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 3.4 KiB |
Before Width: | Height: | Size: 2.2 KiB |
Before Width: | Height: | Size: 2.4 KiB |
Before Width: | Height: | Size: 4.0 KiB |
Before Width: | Height: | Size: 3.4 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 4.8 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 4.5 KiB |
Before Width: | Height: | Size: 4.7 KiB |
Before Width: | Height: | Size: 4.5 KiB |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 5.4 KiB |
Before Width: | Height: | Size: 2.9 KiB |
Before Width: | Height: | Size: 2.5 KiB |
Before Width: | Height: | Size: 4.4 KiB |
Before Width: | Height: | Size: 3.7 KiB |
Before Width: | Height: | Size: 4.7 KiB |
Before Width: | Height: | Size: 3.0 KiB |
Before Width: | Height: | Size: 3.6 KiB |
Before Width: | Height: | Size: 4.7 KiB |
Before Width: | Height: | Size: 3.7 KiB |
Before Width: | Height: | Size: 3.4 KiB |
Before Width: | Height: | Size: 7.2 KiB |
Before Width: | Height: | Size: 1.8 KiB |
Before Width: | Height: | Size: 4.0 KiB |
Before Width: | Height: | Size: 5.2 KiB |
Before Width: | Height: | Size: 4.5 KiB |
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 3.1 KiB |
Before Width: | Height: | Size: 5.0 KiB |
Before Width: | Height: | Size: 5.3 KiB |
Before Width: | Height: | Size: 2.8 KiB |
Before Width: | Height: | Size: 6.3 KiB |
Before Width: | Height: | Size: 4.2 KiB |
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 5.9 KiB |
Before Width: | Height: | Size: 3.1 KiB |
Before Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 2.3 KiB |
Before Width: | Height: | Size: 5.0 KiB |
Before Width: | Height: | Size: 3.8 KiB |
Before Width: | Height: | Size: 4.4 KiB |
Before Width: | Height: | Size: 5.0 KiB |
Before Width: | Height: | Size: 4.2 KiB |
Before Width: | Height: | Size: 4.3 KiB |
Before Width: | Height: | Size: 3.0 KiB |
Before Width: | Height: | Size: 4.9 KiB |
Before Width: | Height: | Size: 3.0 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 1.9 KiB |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 3.6 KiB |
Before Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 2.9 KiB |
Before Width: | Height: | Size: 2.2 KiB |
Before Width: | Height: | Size: 3.7 KiB |
Before Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 3.1 KiB |
Before Width: | Height: | Size: 2.0 KiB |
Before Width: | Height: | Size: 2.4 KiB |
Before Width: | Height: | Size: 2.2 KiB |