Compare commits
3 Commits
master
...
neural_net
Author | SHA1 | Date | |
---|---|---|---|
5b2a499631 | |||
3b2342a6b4 | |||
ebcecf4279 |
BIN
Tiles/Base.jpg
Normal file
After Width: | Height: | Size: 209 KiB |
BIN
Tiles/Bend.jpg
Normal file
After Width: | Height: | Size: 192 KiB |
BIN
Tiles/End.jpg
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
Tiles/Intersection.jpg
Normal file
After Width: | Height: | Size: 187 KiB |
BIN
Tiles/Junction.jpg
Normal file
After Width: | Height: | Size: 178 KiB |
BIN
Tiles/Straight.jpg
Normal file
After Width: | Height: | Size: 186 KiB |
Before Width: | Height: | Size: 9.3 KiB After Width: | Height: | Size: 9.3 KiB |
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 9.8 KiB After Width: | Height: | Size: 9.8 KiB |
22
collect
@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
|
||||
6 -> 10 ;
|
||||
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||
10 -> 11 ;
|
||||
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
12 [label="garbage_type <= 2.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
11 -> 12 ;
|
||||
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
12 -> 13 ;
|
||||
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
|
||||
15 -> 16 ;
|
||||
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||
15 -> 17 ;
|
||||
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||
17 -> 18 ;
|
||||
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
18 -> 19 ;
|
||||
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
|
||||
23 -> 25 ;
|
||||
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||
25 -> 26 ;
|
||||
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
27 [label="days_since_last_collection <= 22.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
25 -> 27 ;
|
||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
27 -> 28 ;
|
||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
27 -> 29 ;
|
||||
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||
@ -88,18 +88,14 @@ edge [fontname="helvetica"] ;
|
||||
40 -> 42 ;
|
||||
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||
42 -> 43 ;
|
||||
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||
42 -> 44 ;
|
||||
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
44 -> 45 ;
|
||||
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||
44 -> 46 ;
|
||||
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
46 -> 47 ;
|
||||
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
|
||||
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||
46 -> 48 ;
|
||||
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
48 -> 49 ;
|
||||
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||
48 -> 50 ;
|
||||
}
|
||||
|
BIN
collect.pdf
@ -1,11 +1,16 @@
|
||||
from heuristicfn import heuristicfn
|
||||
|
||||
|
||||
FIELDWIDTH = 50
|
||||
TURN_FUEL_COST = 10
|
||||
MOVE_FUEL_COST = 200
|
||||
MAX_FUEL = 20000
|
||||
MAX_SPACE = 5
|
||||
MAX_WEIGHT = 200
|
||||
MAX_WEIGHT = 400
|
||||
MAX_WEIGHT_GLASS = 100
|
||||
MAX_WEIGHT_MIXED = 100
|
||||
MAX_WEIGHT_PAPER = 100
|
||||
MAX_WEIGHT_PLASTIC = 100
|
||||
|
||||
|
||||
class GarbageTruck:
|
||||
@ -18,6 +23,10 @@ class GarbageTruck:
|
||||
self.fuel = MAX_FUEL
|
||||
self.free_space = MAX_SPACE
|
||||
self.weight_capacity = MAX_WEIGHT
|
||||
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
||||
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
||||
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
||||
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
||||
self.rect = rect
|
||||
self.orientation = orientation
|
||||
self.request_list = request_list #lista domów do odwiedzenia
|
||||
@ -78,10 +87,33 @@ class GarbageTruck:
|
||||
|
||||
|
||||
|
||||
def collect(self):
|
||||
def collect(self, garbage_type):
|
||||
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
||||
self.fuel = MAX_FUEL
|
||||
self.free_space = MAX_SPACE
|
||||
self.weight_capacity = MAX_WEIGHT
|
||||
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
|
||||
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
||||
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
||||
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
||||
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
||||
request = self.request_list[0]
|
||||
if garbage_type == "glass":
|
||||
if request.weight > self.weight_capacity_glass:
|
||||
return 1
|
||||
self.weight_capacity_glass -= request.weight
|
||||
elif garbage_type == "mixed":
|
||||
if request.weight > self.weight_capacity_mixed:
|
||||
return 1
|
||||
self.weight_capacity_mixed -= request.weight
|
||||
elif garbage_type == "paper":
|
||||
if request.weight > self.weight_capacity_paper:
|
||||
return 1
|
||||
self.weight_capacity_paper -= request.weight
|
||||
elif garbage_type == "plastic":
|
||||
if request.weight > self.weight_capacity_plastic:
|
||||
return 1
|
||||
self.weight_capacity_plastic -= request.weight
|
||||
|
||||
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}, glass_capacity: {self.weight_capacity_glass}, mixed_capacity: {self.weight_capacity_mixed}, paper_capacity: {self.weight_capacity_paper}, plastic_capacity: {self.weight_capacity_plastic}')
|
||||
return 0
|
||||
pass
|
@ -1,3 +1,2 @@
|
||||
def heuristicfn(startx, starty, goalx, goaly):
|
||||
return abs(startx - goalx) + abs(starty - goaly)
|
||||
# return pow(((startx//50)-(starty//50)),2) + pow(((goalx//50)-(goaly//50)),2)
|
44
loadmodel.py
Normal file
@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import PIL.Image as Image
|
||||
import os
|
||||
|
||||
|
||||
def classify(image_path):
|
||||
model = torch.load('./model_training/garbage_model.pth')
|
||||
mean = [0.6908, 0.6612, 0.6218]
|
||||
std = [0.1947, 0.1926, 0.2086]
|
||||
classes = [
|
||||
"glass",
|
||||
"mixed",
|
||||
"paper",
|
||||
"plastic",
|
||||
]
|
||||
image_transforms = transforms.Compose([
|
||||
transforms.Resize((128, 128)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
|
||||
])
|
||||
|
||||
model = model.eval()
|
||||
image = Image.open(image_path)
|
||||
image = image_transforms(image).float()
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
label = os.path.basename(os.path.dirname(image_path))
|
||||
prediction = classes[predicted.item()]
|
||||
print(f"predicted: {prediction}")
|
||||
if label == prediction:
|
||||
print("predicted correctly.")
|
||||
else:
|
||||
print("predicted incorrectly.")
|
||||
return prediction
|
||||
|
||||
|
||||
# classify("./model_training/test.jpg")
|
||||
|
||||
|
61
main.py
@ -1,7 +1,6 @@
|
||||
import pygame
|
||||
from treelearn import treelearn
|
||||
|
||||
|
||||
import loadmodel
|
||||
from astar import astar
|
||||
from state import State
|
||||
import time
|
||||
@ -9,6 +8,7 @@ from garbage_truck import GarbageTruck
|
||||
from heuristicfn import heuristicfn
|
||||
from map import randomize_map
|
||||
|
||||
|
||||
pygame.init()
|
||||
WIDTH, HEIGHT = 800, 800
|
||||
window = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
@ -18,14 +18,18 @@ AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
|
||||
FPS = 10
|
||||
FIELDCOUNT = 16
|
||||
FIELDWIDTH = 50
|
||||
BASE_IMG = pygame.image.load("Tiles/Base.jpg")
|
||||
BASE = pygame.transform.scale(BASE_IMG, (50, 50))
|
||||
|
||||
GRASS_IMG = pygame.image.load("grass.png")
|
||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||
def draw_window(agent, fields, flip):
|
||||
def draw_window(agent, fields, flip, turn):
|
||||
if flip:
|
||||
direction = pygame.transform.flip(AGENT, True, False)
|
||||
if turn:
|
||||
direction = pygame.transform.rotate(AGENT, -90)
|
||||
else:
|
||||
direction = pygame.transform.flip(AGENT, False, False)
|
||||
if turn:
|
||||
direction = pygame.transform.rotate(AGENT, 90)
|
||||
for i in range(16):
|
||||
for j in range(16):
|
||||
window.blit(fields[i][j], (i * 50, j * 50))
|
||||
@ -37,40 +41,63 @@ def main():
|
||||
clf = treelearn()
|
||||
clock = pygame.time.Clock()
|
||||
run = True
|
||||
fields, priority_array, request_list = randomize_map()
|
||||
fields, priority_array, request_list, imgpath_array = randomize_map()
|
||||
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
||||
low_space = 0
|
||||
while run:
|
||||
clock.tick(FPS)
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
run = False
|
||||
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
|
||||
draw_window(agent, fields, False, False) # false = kierunek east (domyslny), true = west
|
||||
x, y = agent.next_destination()
|
||||
if x == agent.rect.x and y == agent.rect.y:
|
||||
print('out of jobs')
|
||||
break
|
||||
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[agent.rect.x//50][agent.rect.y//50], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
||||
if low_space == 1:
|
||||
x, y = 0, 0
|
||||
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation,
|
||||
priority_array[agent.rect.x//50][agent.rect.y//50],
|
||||
heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
||||
for interm in steps:
|
||||
if interm.action == 'LEFT':
|
||||
agent.turn_left()
|
||||
draw_window(agent, fields, True)
|
||||
if agent.orientation == 0:
|
||||
draw_window(agent, fields, False, False)
|
||||
elif agent.orientation == 2:
|
||||
draw_window(agent, fields, True, False)
|
||||
elif agent.orientation == 1:
|
||||
draw_window(agent, fields, True, True)
|
||||
else:
|
||||
draw_window(agent, fields, False, True)
|
||||
elif interm.action == 'RIGHT':
|
||||
agent.turn_right()
|
||||
draw_window(agent, fields, False)
|
||||
if agent.orientation == 0:
|
||||
draw_window(agent, fields, False, False)
|
||||
elif agent.orientation == 2:
|
||||
draw_window(agent, fields, True, False)
|
||||
elif agent.orientation == 1:
|
||||
draw_window(agent, fields, True, True)
|
||||
else:
|
||||
draw_window(agent, fields, False, True)
|
||||
elif interm.action == 'FORWARD':
|
||||
agent.forward()
|
||||
if agent.orientation == 0:
|
||||
draw_window(agent, fields, False)
|
||||
draw_window(agent, fields, False, False)
|
||||
elif agent.orientation == 2:
|
||||
draw_window(agent, fields, True)
|
||||
draw_window(agent, fields, True, False)
|
||||
elif agent.orientation == 1:
|
||||
draw_window(agent, fields, True, True)
|
||||
else:
|
||||
draw_window(agent, fields, False)
|
||||
draw_window(agent, fields, False, True)
|
||||
time.sleep(0.3)
|
||||
agent.collect()
|
||||
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
|
||||
priority_array[agent.rect.x//50][agent.rect.y//50] = 1
|
||||
time.sleep(0.5)
|
||||
if (agent.rect.x // 50 != 0) or (agent.rect.y // 50 != 0):
|
||||
garbage_type = loadmodel.classify(imgpath_array[agent.rect.x // 50][agent.rect.y // 50])
|
||||
low_space = agent.collect(garbage_type)
|
||||
|
||||
fields[agent.rect.x//50][agent.rect.y//50] = BASE
|
||||
priority_array[agent.rect.x//50][agent.rect.y//50] = 100
|
||||
time.sleep(0.5)
|
||||
|
||||
pygame.quit()
|
||||
|
||||
|
115
map.py
@ -1,30 +1,113 @@
|
||||
import pygame, random
|
||||
import pygame as pg
|
||||
import random
|
||||
from request import Request
|
||||
|
||||
DIRT_IMG = pygame.image.load("dirt.jpg")
|
||||
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
|
||||
GRASS_IMG = pygame.image.load("grass.png")
|
||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||
SAND_IMG = pygame.image.load("sand.jpeg")
|
||||
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
|
||||
COBBLE_IMG = pygame.image.load("cobble.jpeg")
|
||||
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
|
||||
|
||||
STRAIGHT_IMG = pg.image.load("Tiles/Straight.jpg")
|
||||
STRAIGHT_VERTICAL = pg.transform.scale(STRAIGHT_IMG, (50, 50))
|
||||
STRAIGHT_HORIZONTAL = pg.transform.rotate(STRAIGHT_VERTICAL, 270)
|
||||
BASE_IMG = pg.image.load("Tiles/Base.jpg")
|
||||
BASE = pg.transform.scale(BASE_IMG, (50, 50))
|
||||
BEND_IMG = pg.image.load("Tiles/Bend.jpg")
|
||||
BEND1 = pg.transform.scale(BEND_IMG, (50, 50))
|
||||
BEND2 = pg.transform.rotate(BEND1, 90)
|
||||
BEND3 = pg.transform.rotate(pg.transform.flip(pg.transform.rotate(BEND1, 180), True, True), 180)
|
||||
BEND4 = pg.transform.rotate(BEND1, -90)
|
||||
INTERSECTION_IMG = pg.image.load("Tiles/Intersection.jpg")
|
||||
INTERSECTION = pg.transform.scale(INTERSECTION_IMG, (50, 50))
|
||||
JUNCTION_IMG = pg.image.load("Tiles/Junction.jpg")
|
||||
JUNCTION_SOUTH = pg.transform.scale(JUNCTION_IMG, (50, 50))
|
||||
JUNCTION_NORTH = pg.transform.rotate(pg.transform.flip(JUNCTION_SOUTH, True, False), 180)
|
||||
JUNCTION_EAST = pg.transform.rotate(JUNCTION_SOUTH, -90)
|
||||
JUNCTION_WEST = pg.transform.rotate(JUNCTION_SOUTH, 90)
|
||||
END_IMG = pg.image.load("Tiles/End.jpg")
|
||||
END1 = pg.transform.flip(pg.transform.rotate(pg.transform.scale(END_IMG, (50, 50)), 180), False, True)
|
||||
END2 = pg.transform.rotate(END1, 90)
|
||||
DIRT_IMG = pg.image.load("Tiles/dirt.jpg")
|
||||
DIRT = pg.transform.scale(DIRT_IMG, (50, 50))
|
||||
GRASS_IMG = pg.image.load("Tiles/grass.png")
|
||||
GRASS = pg.transform.scale(GRASS_IMG, (50, 50))
|
||||
SAND_IMG = pg.image.load("Tiles/sand.jpeg")
|
||||
SAND = pg.transform.scale(SAND_IMG, (50, 50))
|
||||
COBBLE_IMG = pg.image.load("Tiles/cobble.jpeg")
|
||||
COBBLE = pg.transform.scale(COBBLE_IMG, (50, 50))
|
||||
|
||||
|
||||
def randomize_map(): # tworzenie mapy z losowymi polami
|
||||
request_list = []
|
||||
field_array_1 = []
|
||||
field_array_2 = []
|
||||
imgpath_array = [[0 for x in range(16)] for x in range(16)]
|
||||
field_priority = []
|
||||
map_array = [['b', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'b3', 'g'],
|
||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'jn', 'g'],
|
||||
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['b1', 'sh', 'jw', 'sh', 'sh', 'jn', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['g', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'jn', 'g'],
|
||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['gr', 'g', 'js', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', ' g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||
['gr', 'g', 'b1', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'b4', 'g'],
|
||||
]
|
||||
|
||||
for i in range(16):
|
||||
temp_priority = []
|
||||
for j in range(16):
|
||||
if i in (0, 1) and j in (0, 1):
|
||||
field_array_2.append(GRASS)
|
||||
if map_array[i][j] == 'b':
|
||||
field_array_2.append(BASE)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'b3':
|
||||
field_array_2.append(BEND3)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'b4':
|
||||
field_array_2.append(BEND4)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'b1':
|
||||
field_array_2.append(BEND1)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'sh':
|
||||
field_array_2.append(STRAIGHT_VERTICAL)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'sv':
|
||||
field_array_2.append(STRAIGHT_HORIZONTAL)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'i':
|
||||
field_array_2.append(INTERSECTION)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'je':
|
||||
field_array_2.append(JUNCTION_EAST)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'jw':
|
||||
field_array_2.append(JUNCTION_WEST)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'js':
|
||||
field_array_2.append(JUNCTION_SOUTH)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'jn':
|
||||
field_array_2.append(JUNCTION_NORTH)
|
||||
temp_priority.append(1)
|
||||
elif map_array[i][j] == 'gr':
|
||||
field_array_2.append(BASE)
|
||||
temp_priority.append(1000)
|
||||
else:
|
||||
prob = random.uniform(0, 100)
|
||||
if 0 <= prob <= 12:
|
||||
field_array_2.append(COBBLE)
|
||||
if 0 <= prob <= 20:
|
||||
garbage_type = random.choice(['glass', 'mixed', 'paper', 'plastic'])
|
||||
garbage_image_number = random.randrange(1, 100)
|
||||
GARBAGE_IMG = pg.image.load(
|
||||
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
||||
GARBAGE = pg.transform.scale(GARBAGE_IMG, (50, 50))
|
||||
field_array_2.append(GARBAGE)
|
||||
imgpath_array[i][j] = (
|
||||
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
||||
|
||||
temp_priority.append(100)
|
||||
request_list.append(Request(
|
||||
i * 50, j * 50, # lokacja
|
||||
@ -36,9 +119,9 @@ def randomize_map(): # tworzenie mapy z losowymi polami
|
||||
random.random() * 50 # waga śmieci
|
||||
))
|
||||
else:
|
||||
field_array_2.append(GRASS)
|
||||
temp_priority.append(1)
|
||||
field_array_2.append(BASE)
|
||||
temp_priority.append(1000)
|
||||
field_array_1.append(field_array_2)
|
||||
field_array_2 = []
|
||||
field_priority.append(temp_priority)
|
||||
return field_array_1, field_priority, request_list
|
||||
return field_array_1, field_priority, request_list, imgpath_array
|
||||
|
BIN
model_training/garbage_model.pth
Normal file
177
model_training/main.py
Normal file
@ -0,0 +1,177 @@
|
||||
import os
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset, random_split, DataLoader
|
||||
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
# input_size = 49152
|
||||
# hidden_sizes = [64, 128]
|
||||
# output_size = 10
|
||||
|
||||
classes = os.listdir('./train_dataset')
|
||||
print(classes)
|
||||
mean = [0.6908, 0.6612, 0.6218]
|
||||
std = [0.1947, 0.1926, 0.2086]
|
||||
|
||||
training_dataset_path = './train_dataset'
|
||||
training_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
||||
train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
|
||||
|
||||
testing_dataset_path = './test_dataset'
|
||||
testing_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
||||
test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
|
||||
|
||||
# Mean and Standard Deviation approximations
|
||||
def get_mean_and_std(loader):
|
||||
mean = 0.
|
||||
std = 0.
|
||||
total_images_count = 0
|
||||
for images, _ in loader:
|
||||
image_count_in_a_batch = images.size(0)
|
||||
#print(images.shape)
|
||||
images = images.view(image_count_in_a_batch, images.size(1), -1)
|
||||
#print(images.shape)
|
||||
mean += images.mean(2).sum(0)
|
||||
std += images.std(2).sum(0)
|
||||
total_images_count += image_count_in_a_batch
|
||||
mean /= total_images_count
|
||||
std /= total_images_count
|
||||
return mean, std
|
||||
|
||||
print(get_mean_and_std(train_loader))
|
||||
|
||||
# Show images with applied transformations
|
||||
def show_transformed_images(dataset):
|
||||
loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
|
||||
batch = next(iter(loader))
|
||||
images, labels = batch
|
||||
|
||||
grid = torchvision.utils.make_grid(images, nrow=3)
|
||||
plt.figure(figsize=(11,11))
|
||||
plt.imshow(np.transpose(grid, (1,2,0)))
|
||||
print('labels: ', labels)
|
||||
plt.show()
|
||||
|
||||
show_transformed_images(train_dataset)
|
||||
|
||||
# Neural network training:
|
||||
def set_device():
|
||||
if torch.cuda.is_available():
|
||||
dev = "cuda:0"
|
||||
else:
|
||||
dev = "cpu"
|
||||
return torch.device(dev)
|
||||
|
||||
|
||||
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epochs):
|
||||
device = set_device()
|
||||
best_acc = 0
|
||||
|
||||
for epoch in range(n_epochs):
|
||||
print("Epoch number %d " % (epoch+1))
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
running_correct = 0.0
|
||||
total = 0
|
||||
|
||||
for data in train_loader:
|
||||
images, labels = data
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
total += labels.size(0)
|
||||
|
||||
# Back propagation
|
||||
optimizer.zero_grad()
|
||||
outputs = model(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss += loss.item()
|
||||
running_correct += (labels==predicted).sum().item()
|
||||
|
||||
epoch_loss = running_loss/len(train_loader)
|
||||
epoch_acc = 100.00 * running_correct / total
|
||||
|
||||
print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
|
||||
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
|
||||
|
||||
if(test_dataset_acc > best_acc):
|
||||
best_acc = test_dataset_acc
|
||||
save_checkpoint(model, epoch, optimizer, best_acc)
|
||||
|
||||
print("Finished")
|
||||
return model
|
||||
|
||||
|
||||
def evaluate_model_on_test_set(model, test_loader):
|
||||
model.eval()
|
||||
predicted_correctly_on_epoch = 0
|
||||
total = 0
|
||||
device = set_device()
|
||||
|
||||
with torch.no_grad():
|
||||
for data in test_loader:
|
||||
images, labels = data
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
total += labels.size(0)
|
||||
|
||||
outputs = model(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
predicted_correctly_on_epoch += (predicted == labels).sum().item()
|
||||
|
||||
epoch_acc = 100.0 * predicted_correctly_on_epoch / total
|
||||
print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
|
||||
|
||||
return epoch_acc
|
||||
|
||||
|
||||
# Saving the checkpoint:
|
||||
def save_checkpoint(model, epoch, optimizer, best_acc):
|
||||
state = {
|
||||
'epoch': epoch+1,
|
||||
'model': model.state_dict(),
|
||||
'best_accuracy': best_acc,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
}
|
||||
torch.save(state, 'model_best_checkpoint.zip')
|
||||
|
||||
|
||||
resnet18_model = models.resnet18(pretrained=True) #Increase n_epochs if False
|
||||
num_features = resnet18_model.fc.in_features
|
||||
number_of_classes = 4
|
||||
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||
device = set_device()
|
||||
resnet_18_model = resnet18_model.to(device)
|
||||
loss_fn = nn.CrossEntropyLoss() #criterion
|
||||
|
||||
optimizer = optim.SGD(resnet_18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
||||
train_nn(resnet_18_model, train_loader, test_loader, loss_fn, optimizer, 5)
|
||||
|
||||
|
||||
# Saving the model:
|
||||
checkpoint = torch.load('model_best_checkpoint.pth.zip')
|
||||
|
||||
resnet18_model = models.resnet18()
|
||||
num_features = resnet18_model.fc.in_features
|
||||
number_of_classes = 4
|
||||
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||
resnet18_model.load_state_dict(checkpoint['model'])
|
||||
|
||||
torch.save(resnet18_model, 'garbage_model.pth')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
BIN
model_training/test.jpg
Normal file
After Width: | Height: | Size: 162 KiB |
BIN
model_training/test_dataset/glass/glass (1).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (10).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (100).jpg
Normal file
After Width: | Height: | Size: 2.0 KiB |
BIN
model_training/test_dataset/glass/glass (11).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (12).jpg
Normal file
After Width: | Height: | Size: 2.4 KiB |
BIN
model_training/test_dataset/glass/glass (13).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (14).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (15).jpg
Normal file
After Width: | Height: | Size: 2.8 KiB |
BIN
model_training/test_dataset/glass/glass (16).jpg
Normal file
After Width: | Height: | Size: 1.8 KiB |
BIN
model_training/test_dataset/glass/glass (17).jpg
Normal file
After Width: | Height: | Size: 3.8 KiB |
BIN
model_training/test_dataset/glass/glass (18).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (19).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (2).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (20).jpg
Normal file
After Width: | Height: | Size: 4.9 KiB |
BIN
model_training/test_dataset/glass/glass (21).jpg
Normal file
After Width: | Height: | Size: 3.9 KiB |
BIN
model_training/test_dataset/glass/glass (22).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (23).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (24).jpg
Normal file
After Width: | Height: | Size: 2.2 KiB |
BIN
model_training/test_dataset/glass/glass (25).jpg
Normal file
After Width: | Height: | Size: 2.4 KiB |
BIN
model_training/test_dataset/glass/glass (26).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (27).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (28).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (29).jpg
Normal file
After Width: | Height: | Size: 4.8 KiB |
BIN
model_training/test_dataset/glass/glass (3).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (30).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (31).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (32).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (33).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (34).jpg
Normal file
After Width: | Height: | Size: 5.4 KiB |
BIN
model_training/test_dataset/glass/glass (35).jpg
Normal file
After Width: | Height: | Size: 2.9 KiB |
BIN
model_training/test_dataset/glass/glass (36).jpg
Normal file
After Width: | Height: | Size: 2.5 KiB |
BIN
model_training/test_dataset/glass/glass (37).jpg
Normal file
After Width: | Height: | Size: 4.4 KiB |
BIN
model_training/test_dataset/glass/glass (38).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (39).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (4).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (40).jpg
Normal file
After Width: | Height: | Size: 3.6 KiB |
BIN
model_training/test_dataset/glass/glass (41).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (42).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (43).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (44).jpg
Normal file
After Width: | Height: | Size: 7.2 KiB |
BIN
model_training/test_dataset/glass/glass (45).jpg
Normal file
After Width: | Height: | Size: 1.8 KiB |
BIN
model_training/test_dataset/glass/glass (46).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (47).jpg
Normal file
After Width: | Height: | Size: 5.2 KiB |
BIN
model_training/test_dataset/glass/glass (48).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (49).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (5).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (50).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (51).jpg
Normal file
After Width: | Height: | Size: 5.3 KiB |
BIN
model_training/test_dataset/glass/glass (52).jpg
Normal file
After Width: | Height: | Size: 2.8 KiB |
BIN
model_training/test_dataset/glass/glass (53).jpg
Normal file
After Width: | Height: | Size: 6.3 KiB |
BIN
model_training/test_dataset/glass/glass (54).jpg
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
model_training/test_dataset/glass/glass (55).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (56).jpg
Normal file
After Width: | Height: | Size: 5.9 KiB |
BIN
model_training/test_dataset/glass/glass (57).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (58).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (59).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (6).jpg
Normal file
After Width: | Height: | Size: 2.3 KiB |
BIN
model_training/test_dataset/glass/glass (60).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (61).jpg
Normal file
After Width: | Height: | Size: 3.8 KiB |
BIN
model_training/test_dataset/glass/glass (62).jpg
Normal file
After Width: | Height: | Size: 4.4 KiB |
BIN
model_training/test_dataset/glass/glass (63).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (64).jpg
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
model_training/test_dataset/glass/glass (65).jpg
Normal file
After Width: | Height: | Size: 4.3 KiB |
BIN
model_training/test_dataset/glass/glass (66).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (67).jpg
Normal file
After Width: | Height: | Size: 4.9 KiB |
BIN
model_training/test_dataset/glass/glass (68).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (69).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (7).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (70).jpg
Normal file
After Width: | Height: | Size: 1.9 KiB |
BIN
model_training/test_dataset/glass/glass (71).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (72).jpg
Normal file
After Width: | Height: | Size: 3.6 KiB |
BIN
model_training/test_dataset/glass/glass (73).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (74).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (75).jpg
Normal file
After Width: | Height: | Size: 2.9 KiB |
BIN
model_training/test_dataset/glass/glass (76).jpg
Normal file
After Width: | Height: | Size: 2.2 KiB |
BIN
model_training/test_dataset/glass/glass (77).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (78).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (79).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (8).jpg
Normal file
After Width: | Height: | Size: 2.0 KiB |
BIN
model_training/test_dataset/glass/glass (80).jpg
Normal file
After Width: | Height: | Size: 2.4 KiB |