Compare commits
4 Commits
master
...
genetic_al
Author | SHA1 | Date | |
---|---|---|---|
2318e6ba50 | |||
|
1c26edad6c | ||
3b2342a6b4 | |||
ebcecf4279 |
BIN
Tiles/Base.jpg
Normal file
After Width: | Height: | Size: 209 KiB |
BIN
Tiles/Bend.jpg
Normal file
After Width: | Height: | Size: 192 KiB |
BIN
Tiles/End.jpg
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
Tiles/Intersection.jpg
Normal file
After Width: | Height: | Size: 187 KiB |
BIN
Tiles/Junction.jpg
Normal file
After Width: | Height: | Size: 178 KiB |
BIN
Tiles/Straight.jpg
Normal file
After Width: | Height: | Size: 186 KiB |
Before Width: | Height: | Size: 9.3 KiB After Width: | Height: | Size: 9.3 KiB |
Before Width: | Height: | Size: 3.5 KiB After Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 9.8 KiB After Width: | Height: | Size: 9.8 KiB |
28
collect
@ -24,11 +24,11 @@ edge [fontname="helvetica"] ;
|
|||||||
6 -> 10 ;
|
6 -> 10 ;
|
||||||
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||||
10 -> 11 ;
|
10 -> 11 ;
|
||||||
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
12 [label="space_occupied <= 0.382\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
11 -> 12 ;
|
11 -> 12 ;
|
||||||
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
13 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
12 -> 13 ;
|
12 -> 13 ;
|
||||||
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
14 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
12 -> 14 ;
|
12 -> 14 ;
|
||||||
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
|
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
|
||||||
11 -> 15 ;
|
11 -> 15 ;
|
||||||
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
|
|||||||
15 -> 16 ;
|
15 -> 16 ;
|
||||||
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||||
15 -> 17 ;
|
15 -> 17 ;
|
||||||
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||||
17 -> 18 ;
|
17 -> 18 ;
|
||||||
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
18 -> 19 ;
|
18 -> 19 ;
|
||||||
@ -50,15 +50,15 @@ edge [fontname="helvetica"] ;
|
|||||||
5 -> 23 ;
|
5 -> 23 ;
|
||||||
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
||||||
23 -> 24 ;
|
23 -> 24 ;
|
||||||
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
|
25 [label="days_since_last_collection <= 22.0\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
|
||||||
23 -> 25 ;
|
23 -> 25 ;
|
||||||
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||||
25 -> 26 ;
|
25 -> 26 ;
|
||||||
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
27 [label="odour_intensity <= 8.841\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
25 -> 27 ;
|
25 -> 27 ;
|
||||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
27 -> 28 ;
|
27 -> 28 ;
|
||||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
27 -> 29 ;
|
27 -> 29 ;
|
||||||
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||||
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||||
@ -88,18 +88,14 @@ edge [fontname="helvetica"] ;
|
|||||||
40 -> 42 ;
|
40 -> 42 ;
|
||||||
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||||
42 -> 43 ;
|
42 -> 43 ;
|
||||||
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||||
42 -> 44 ;
|
42 -> 44 ;
|
||||||
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
44 -> 45 ;
|
44 -> 45 ;
|
||||||
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||||
44 -> 46 ;
|
44 -> 46 ;
|
||||||
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
46 -> 47 ;
|
46 -> 47 ;
|
||||||
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
|
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||||
46 -> 48 ;
|
46 -> 48 ;
|
||||||
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
|
||||||
48 -> 49 ;
|
|
||||||
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
|
||||||
48 -> 50 ;
|
|
||||||
}
|
}
|
||||||
|
BIN
collect.pdf
@ -1,11 +1,16 @@
|
|||||||
from heuristicfn import heuristicfn
|
from heuristicfn import heuristicfn
|
||||||
|
|
||||||
|
|
||||||
FIELDWIDTH = 50
|
FIELDWIDTH = 50
|
||||||
TURN_FUEL_COST = 10
|
TURN_FUEL_COST = 10
|
||||||
MOVE_FUEL_COST = 200
|
MOVE_FUEL_COST = 200
|
||||||
MAX_FUEL = 20000
|
MAX_FUEL = 20000
|
||||||
MAX_SPACE = 5
|
MAX_SPACE = 5
|
||||||
MAX_WEIGHT = 200
|
MAX_WEIGHT = 400
|
||||||
|
MAX_WEIGHT_GLASS = 100
|
||||||
|
MAX_WEIGHT_MIXED = 100
|
||||||
|
MAX_WEIGHT_PAPER = 100
|
||||||
|
MAX_WEIGHT_PLASTIC = 100
|
||||||
|
|
||||||
|
|
||||||
class GarbageTruck:
|
class GarbageTruck:
|
||||||
@ -18,6 +23,10 @@ class GarbageTruck:
|
|||||||
self.fuel = MAX_FUEL
|
self.fuel = MAX_FUEL
|
||||||
self.free_space = MAX_SPACE
|
self.free_space = MAX_SPACE
|
||||||
self.weight_capacity = MAX_WEIGHT
|
self.weight_capacity = MAX_WEIGHT
|
||||||
|
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
||||||
|
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
||||||
|
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
||||||
|
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
||||||
self.rect = rect
|
self.rect = rect
|
||||||
self.orientation = orientation
|
self.orientation = orientation
|
||||||
self.request_list = request_list #lista domów do odwiedzenia
|
self.request_list = request_list #lista domów do odwiedzenia
|
||||||
@ -45,6 +54,8 @@ class GarbageTruck:
|
|||||||
def next_destination(self):
|
def next_destination(self):
|
||||||
|
|
||||||
for i in range(len(self.request_list)):
|
for i in range(len(self.request_list)):
|
||||||
|
if(self.request_list==[]):
|
||||||
|
break
|
||||||
request = self.request_list[i]
|
request = self.request_list[i]
|
||||||
|
|
||||||
#nie ma miejsca w zbiorniku lub za ciężkie śmieci
|
#nie ma miejsca w zbiorniku lub za ciężkie śmieci
|
||||||
@ -55,33 +66,44 @@ class GarbageTruck:
|
|||||||
if heuristicfn(request.x_pos, request.y_pos, self.dump_x, self.dump_y) // 50 * 200 > self.fuel:
|
if heuristicfn(request.x_pos, request.y_pos, self.dump_x, self.dump_y) // 50 * 200 > self.fuel:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
self.request_list.pop(i)
|
||||||
|
self.free_space -= request.volume
|
||||||
distance = heuristicfn(self.rect.x, self.rect.y, request.x_pos, request.y_pos) // 50
|
self.weight_capacity -= request.weight
|
||||||
|
return request.x_pos, request.y_pos
|
||||||
r = [
|
|
||||||
self.fuel,
|
|
||||||
distance,
|
|
||||||
request.volume,
|
|
||||||
request.last_collection,
|
|
||||||
request.is_paid,
|
|
||||||
request.odour_intensity,
|
|
||||||
request.weight,
|
|
||||||
request.type
|
|
||||||
]
|
|
||||||
if self.clf.predict([r]) == True:
|
|
||||||
self.request_list.pop(i)
|
|
||||||
self.free_space -= request.volume
|
|
||||||
self.weight_capacity -= request.weight
|
|
||||||
return request.x_pos, request.y_pos
|
|
||||||
return self.dump_x, self.dump_y
|
return self.dump_x, self.dump_y
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def collect(self):
|
def collect(self, garbage_type):
|
||||||
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
||||||
self.fuel = MAX_FUEL
|
self.fuel = MAX_FUEL
|
||||||
self.free_space = MAX_SPACE
|
self.free_space = MAX_SPACE
|
||||||
self.weight_capacity = MAX_WEIGHT
|
self.weight_capacity = MAX_WEIGHT
|
||||||
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
|
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
|
||||||
|
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
|
||||||
|
self.weight_capacity_glass = MAX_WEIGHT_GLASS
|
||||||
|
self.weight_capacity_paper = MAX_WEIGHT_PAPER
|
||||||
|
if self.request_list==[]:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
request = self.request_list[0]
|
||||||
|
if garbage_type == "glass":
|
||||||
|
if request.weight > self.weight_capacity_glass:
|
||||||
|
return 1
|
||||||
|
self.weight_capacity_glass -= request.weight
|
||||||
|
elif garbage_type == "mixed":
|
||||||
|
if request.weight > self.weight_capacity_mixed:
|
||||||
|
return 1
|
||||||
|
self.weight_capacity_mixed -= request.weight
|
||||||
|
elif garbage_type == "paper":
|
||||||
|
if request.weight > self.weight_capacity_paper:
|
||||||
|
return 1
|
||||||
|
self.weight_capacity_paper -= request.weight
|
||||||
|
elif garbage_type == "plastic":
|
||||||
|
if request.weight > self.weight_capacity_plastic:
|
||||||
|
return 1
|
||||||
|
self.weight_capacity_plastic -= request.weight
|
||||||
|
|
||||||
|
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}, glass_capacity: {self.weight_capacity_glass}, mixed_capacity: {self.weight_capacity_mixed}, paper_capacity: {self.weight_capacity_paper}, plastic_capacity: {self.weight_capacity_plastic}')
|
||||||
|
return 0
|
||||||
pass
|
pass
|
162
genetic.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import pygame
|
||||||
|
from treelearn import treelearn
|
||||||
|
import loadmodel
|
||||||
|
from astar import astar
|
||||||
|
from state import State
|
||||||
|
import time
|
||||||
|
from garbage_truck import GarbageTruck
|
||||||
|
from heuristicfn import heuristicfn
|
||||||
|
from map import randomize_map
|
||||||
|
from heuristicfn import heuristicfn
|
||||||
|
import pygame as pg
|
||||||
|
import random
|
||||||
|
from request import Request
|
||||||
|
|
||||||
|
def determine_fitness(requests_list):
|
||||||
|
distances = []
|
||||||
|
for i in range(len(requests_list)+1): #from: request_list[i].x_pos and .y_pos
|
||||||
|
temp = []
|
||||||
|
for j in range(len(requests_list)+1):
|
||||||
|
if j<i:
|
||||||
|
temp.append('-')
|
||||||
|
elif j==i:
|
||||||
|
temp.append(0)
|
||||||
|
elif j>i:
|
||||||
|
if i==0:
|
||||||
|
dist = heuristicfn(0, 0, requests_list[j-1].x_pos, requests_list[j-1].y_pos)
|
||||||
|
temp.append(dist)
|
||||||
|
else:
|
||||||
|
dist = heuristicfn(requests_list[i-1].x_pos, requests_list[i-1].y_pos, requests_list[j-1].x_pos, requests_list[j-1].y_pos)
|
||||||
|
temp.append(dist)
|
||||||
|
distances.append(temp)
|
||||||
|
return(distances)
|
||||||
|
|
||||||
|
|
||||||
|
def perform_permutation(obj_list, perm_list):
|
||||||
|
result = [None] * len(obj_list)
|
||||||
|
|
||||||
|
for i, index in enumerate(perm_list):
|
||||||
|
result[int(index)-1] = obj_list[i-1]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def apply_genetic(request_list):
|
||||||
|
print("Genetic algorithm started")
|
||||||
|
|
||||||
|
distances = determine_fitness(request_list)
|
||||||
|
population_size = 12
|
||||||
|
num_generations = 8
|
||||||
|
mutation_rate = 0.3
|
||||||
|
NUM = len(distances)
|
||||||
|
|
||||||
|
def initialize_population():
|
||||||
|
population = []
|
||||||
|
for _ in range(population_size):
|
||||||
|
chromosome = ['0']
|
||||||
|
while True:
|
||||||
|
if len(chromosome) == NUM:
|
||||||
|
chromosome.append('0')
|
||||||
|
break
|
||||||
|
|
||||||
|
temp = random.randint(1, NUM-1)
|
||||||
|
temp_str = str(temp)
|
||||||
|
if temp_str not in chromosome:
|
||||||
|
chromosome.append(temp_str)
|
||||||
|
population.append(chromosome)
|
||||||
|
return population
|
||||||
|
|
||||||
|
def calculate_route_length(route):
|
||||||
|
length = 0
|
||||||
|
for i in range(len(route)-1):
|
||||||
|
p = int(route[i])
|
||||||
|
q = int(route[i + 1])
|
||||||
|
length += distances[int(min(p,q))][int(max(p,q))]
|
||||||
|
return length
|
||||||
|
|
||||||
|
def calculate_fitness(population):
|
||||||
|
fitness_scores = []
|
||||||
|
for chromosome in population:
|
||||||
|
fitness_scores.append(1 / calculate_route_length(chromosome))
|
||||||
|
return fitness_scores
|
||||||
|
|
||||||
|
def parents_selection(population, fitness_scores):
|
||||||
|
selected_parents = []
|
||||||
|
for _ in range(len(population)):
|
||||||
|
candidates = random.sample(range(len(population)), 2)
|
||||||
|
fitness1 = fitness_scores[candidates[0]]
|
||||||
|
fitness2 = fitness_scores[candidates[1]]
|
||||||
|
selected_parent = population[candidates[0]] if fitness1 > fitness2 else population[candidates[1]]
|
||||||
|
selected_parents.append(selected_parent)
|
||||||
|
return selected_parents
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def david_crossover(parent1, parent2):
|
||||||
|
start_index = random.randint(1, len(parent1)-3)
|
||||||
|
end_index = random.randint(start_index+1, len(parent1)-2)
|
||||||
|
parent1_chain = parent1[start_index:end_index+1]
|
||||||
|
parent2_letters = []
|
||||||
|
for trash in parent2[1:-1]:
|
||||||
|
if trash not in parent1_chain:
|
||||||
|
parent2_letters.append(trash)
|
||||||
|
child = [parent2[0]]+parent2_letters[0:start_index] + parent1_chain + parent2_letters[start_index:]+[parent2[-1]]
|
||||||
|
""" print('PARENTS: ')
|
||||||
|
print(parent1)
|
||||||
|
print(parent2)
|
||||||
|
print('CHILDS:')
|
||||||
|
print(child) """
|
||||||
|
return child
|
||||||
|
|
||||||
|
def mutation(chromosome):
|
||||||
|
index1 = random.randint(1, len(chromosome)-2)
|
||||||
|
index2 = random.randint(1, len(chromosome)-2)
|
||||||
|
chromosome[index1], chromosome[index2] = chromosome[index2], chromosome[index1]
|
||||||
|
return chromosome
|
||||||
|
|
||||||
|
def genetic_algorithm():
|
||||||
|
population = initialize_population()
|
||||||
|
for _ in range(num_generations):
|
||||||
|
fitness_scores = calculate_fitness(population)
|
||||||
|
parents = parents_selection(population, fitness_scores)
|
||||||
|
offspring = []
|
||||||
|
for i in range(0, len(parents), 2):
|
||||||
|
parent1 = parents[i]
|
||||||
|
parent2 = parents[i+1]
|
||||||
|
child1 = david_crossover(parent1, parent2)
|
||||||
|
child2 = david_crossover(parent2, parent1)
|
||||||
|
offspring.extend([child1, child2])
|
||||||
|
population = offspring
|
||||||
|
for i in range(len(population)):
|
||||||
|
if random.random() < mutation_rate:
|
||||||
|
population[i] = mutation(population[i])
|
||||||
|
return population
|
||||||
|
|
||||||
|
|
||||||
|
best_route = None
|
||||||
|
best_length = float('inf')
|
||||||
|
population = genetic_algorithm()
|
||||||
|
for chromosome in population:
|
||||||
|
length = calculate_route_length(chromosome)
|
||||||
|
if length < best_length:
|
||||||
|
best_length = length
|
||||||
|
best_route = chromosome
|
||||||
|
|
||||||
|
print("Permutation chosen: ", best_route)
|
||||||
|
print("Its length:", best_length)
|
||||||
|
permuted_list = perform_permutation(request_list, best_route[1:-1])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,3 +1,2 @@
|
|||||||
def heuristicfn(startx, starty, goalx, goaly):
|
def heuristicfn(startx, starty, goalx, goaly):
|
||||||
return abs(startx - goalx) + abs(starty - goaly)
|
return abs(startx - goalx) + abs(starty - goaly)
|
||||||
# return pow(((startx//50)-(starty//50)),2) + pow(((goalx//50)-(goaly//50)),2)
|
|
44
loadmodel.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import PIL.Image as Image
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def classify(image_path):
|
||||||
|
model = torch.load('./model_training/garbage_model.pth')
|
||||||
|
mean = [0.6908, 0.6612, 0.6218]
|
||||||
|
std = [0.1947, 0.1926, 0.2086]
|
||||||
|
classes = [
|
||||||
|
"glass",
|
||||||
|
"mixed",
|
||||||
|
"paper",
|
||||||
|
"plastic",
|
||||||
|
]
|
||||||
|
image_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((128, 128)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
|
||||||
|
])
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
image = Image.open(image_path)
|
||||||
|
image = image_transforms(image).float()
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
|
||||||
|
output = model(image)
|
||||||
|
_, predicted = torch.max(output.data, 1)
|
||||||
|
|
||||||
|
label = os.path.basename(os.path.dirname(image_path))
|
||||||
|
prediction = classes[predicted.item()]
|
||||||
|
print(f"predicted: {prediction}")
|
||||||
|
if label == prediction:
|
||||||
|
print("predicted correctly.")
|
||||||
|
else:
|
||||||
|
print("predicted incorrectly.")
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
|
||||||
|
# classify("./model_training/test.jpg")
|
||||||
|
|
||||||
|
|
65
main.py
@ -1,13 +1,15 @@
|
|||||||
import pygame
|
import pygame
|
||||||
from treelearn import treelearn
|
from treelearn import treelearn
|
||||||
|
import loadmodel
|
||||||
|
|
||||||
from astar import astar
|
from astar import astar
|
||||||
from state import State
|
from state import State
|
||||||
import time
|
import time
|
||||||
from garbage_truck import GarbageTruck
|
from garbage_truck import GarbageTruck
|
||||||
from heuristicfn import heuristicfn
|
from heuristicfn import heuristicfn
|
||||||
from map import randomize_map
|
from map import randomize_map
|
||||||
|
from tree import apply_tree
|
||||||
|
from genetic import apply_genetic
|
||||||
|
|
||||||
|
|
||||||
pygame.init()
|
pygame.init()
|
||||||
WIDTH, HEIGHT = 800, 800
|
WIDTH, HEIGHT = 800, 800
|
||||||
@ -18,14 +20,18 @@ AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
|
|||||||
FPS = 10
|
FPS = 10
|
||||||
FIELDCOUNT = 16
|
FIELDCOUNT = 16
|
||||||
FIELDWIDTH = 50
|
FIELDWIDTH = 50
|
||||||
|
BASE_IMG = pygame.image.load("Tiles/Base.jpg")
|
||||||
|
BASE = pygame.transform.scale(BASE_IMG, (50, 50))
|
||||||
|
|
||||||
GRASS_IMG = pygame.image.load("grass.png")
|
def draw_window(agent, fields, flip, turn):
|
||||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
|
||||||
def draw_window(agent, fields, flip):
|
|
||||||
if flip:
|
if flip:
|
||||||
direction = pygame.transform.flip(AGENT, True, False)
|
direction = pygame.transform.flip(AGENT, True, False)
|
||||||
|
if turn:
|
||||||
|
direction = pygame.transform.rotate(AGENT, -90)
|
||||||
else:
|
else:
|
||||||
direction = pygame.transform.flip(AGENT, False, False)
|
direction = pygame.transform.flip(AGENT, False, False)
|
||||||
|
if turn:
|
||||||
|
direction = pygame.transform.rotate(AGENT, 90)
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
for j in range(16):
|
for j in range(16):
|
||||||
window.blit(fields[i][j], (i * 50, j * 50))
|
window.blit(fields[i][j], (i * 50, j * 50))
|
||||||
@ -37,40 +43,65 @@ def main():
|
|||||||
clf = treelearn()
|
clf = treelearn()
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
run = True
|
run = True
|
||||||
fields, priority_array, request_list = randomize_map()
|
fields, priority_array, request_list, imgpath_array = randomize_map()
|
||||||
|
apply_tree(request_list)
|
||||||
|
apply_genetic(request_list)
|
||||||
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
||||||
|
low_space = 0
|
||||||
while run:
|
while run:
|
||||||
clock.tick(FPS)
|
clock.tick(FPS)
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
if event.type == pygame.QUIT:
|
if event.type == pygame.QUIT:
|
||||||
run = False
|
run = False
|
||||||
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
|
draw_window(agent, fields, False, False) # false = kierunek east (domyslny), true = west
|
||||||
x, y = agent.next_destination()
|
x, y = agent.next_destination()
|
||||||
if x == agent.rect.x and y == agent.rect.y:
|
if x == agent.rect.x and y == agent.rect.y:
|
||||||
print('out of jobs')
|
print('out of jobs')
|
||||||
break
|
break
|
||||||
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[agent.rect.x//50][agent.rect.y//50], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
if low_space == 1:
|
||||||
|
x, y = 0, 0
|
||||||
|
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation,
|
||||||
|
priority_array[agent.rect.x//50][agent.rect.y//50],
|
||||||
|
heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
||||||
for interm in steps:
|
for interm in steps:
|
||||||
if interm.action == 'LEFT':
|
if interm.action == 'LEFT':
|
||||||
agent.turn_left()
|
agent.turn_left()
|
||||||
draw_window(agent, fields, True)
|
if agent.orientation == 0:
|
||||||
|
draw_window(agent, fields, False, False)
|
||||||
|
elif agent.orientation == 2:
|
||||||
|
draw_window(agent, fields, True, False)
|
||||||
|
elif agent.orientation == 1:
|
||||||
|
draw_window(agent, fields, True, True)
|
||||||
|
else:
|
||||||
|
draw_window(agent, fields, False, True)
|
||||||
elif interm.action == 'RIGHT':
|
elif interm.action == 'RIGHT':
|
||||||
agent.turn_right()
|
agent.turn_right()
|
||||||
draw_window(agent, fields, False)
|
if agent.orientation == 0:
|
||||||
|
draw_window(agent, fields, False, False)
|
||||||
|
elif agent.orientation == 2:
|
||||||
|
draw_window(agent, fields, True, False)
|
||||||
|
elif agent.orientation == 1:
|
||||||
|
draw_window(agent, fields, True, True)
|
||||||
|
else:
|
||||||
|
draw_window(agent, fields, False, True)
|
||||||
elif interm.action == 'FORWARD':
|
elif interm.action == 'FORWARD':
|
||||||
agent.forward()
|
agent.forward()
|
||||||
if agent.orientation == 0:
|
if agent.orientation == 0:
|
||||||
draw_window(agent, fields, False)
|
draw_window(agent, fields, False, False)
|
||||||
elif agent.orientation == 2:
|
elif agent.orientation == 2:
|
||||||
draw_window(agent, fields, True)
|
draw_window(agent, fields, True, False)
|
||||||
|
elif agent.orientation == 1:
|
||||||
|
draw_window(agent, fields, True, True)
|
||||||
else:
|
else:
|
||||||
draw_window(agent, fields, False)
|
draw_window(agent, fields, False, True)
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
agent.collect()
|
if (agent.rect.x // 50 != 0) or (agent.rect.y // 50 != 0):
|
||||||
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
|
garbage_type = loadmodel.classify(imgpath_array[agent.rect.x // 50][agent.rect.y // 50])
|
||||||
priority_array[agent.rect.x//50][agent.rect.y//50] = 1
|
low_space = agent.collect(garbage_type)
|
||||||
|
|
||||||
|
fields[agent.rect.x//50][agent.rect.y//50] = BASE
|
||||||
|
priority_array[agent.rect.x//50][agent.rect.y//50] = 100
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
|
131
map.py
@ -1,44 +1,127 @@
|
|||||||
import pygame, random
|
import pygame as pg
|
||||||
|
import random
|
||||||
from request import Request
|
from request import Request
|
||||||
|
|
||||||
DIRT_IMG = pygame.image.load("dirt.jpg")
|
|
||||||
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
|
|
||||||
GRASS_IMG = pygame.image.load("grass.png")
|
|
||||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
|
||||||
SAND_IMG = pygame.image.load("sand.jpeg")
|
|
||||||
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
|
|
||||||
COBBLE_IMG = pygame.image.load("cobble.jpeg")
|
|
||||||
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
|
|
||||||
|
|
||||||
def randomize_map(): # tworzenie mapy z losowymi polami
|
STRAIGHT_IMG = pg.image.load("Tiles/Straight.jpg")
|
||||||
|
STRAIGHT_VERTICAL = pg.transform.scale(STRAIGHT_IMG, (50, 50))
|
||||||
|
STRAIGHT_HORIZONTAL = pg.transform.rotate(STRAIGHT_VERTICAL, 270)
|
||||||
|
BASE_IMG = pg.image.load("Tiles/Base.jpg")
|
||||||
|
BASE = pg.transform.scale(BASE_IMG, (50, 50))
|
||||||
|
BEND_IMG = pg.image.load("Tiles/Bend.jpg")
|
||||||
|
BEND1 = pg.transform.scale(BEND_IMG, (50, 50))
|
||||||
|
BEND2 = pg.transform.rotate(BEND1, 90)
|
||||||
|
BEND3 = pg.transform.rotate(pg.transform.flip(pg.transform.rotate(BEND1, 180), True, True), 180)
|
||||||
|
BEND4 = pg.transform.rotate(BEND1, -90)
|
||||||
|
INTERSECTION_IMG = pg.image.load("Tiles/Intersection.jpg")
|
||||||
|
INTERSECTION = pg.transform.scale(INTERSECTION_IMG, (50, 50))
|
||||||
|
JUNCTION_IMG = pg.image.load("Tiles/Junction.jpg")
|
||||||
|
JUNCTION_SOUTH = pg.transform.scale(JUNCTION_IMG, (50, 50))
|
||||||
|
JUNCTION_NORTH = pg.transform.rotate(pg.transform.flip(JUNCTION_SOUTH, True, False), 180)
|
||||||
|
JUNCTION_EAST = pg.transform.rotate(JUNCTION_SOUTH, -90)
|
||||||
|
JUNCTION_WEST = pg.transform.rotate(JUNCTION_SOUTH, 90)
|
||||||
|
END_IMG = pg.image.load("Tiles/End.jpg")
|
||||||
|
END1 = pg.transform.flip(pg.transform.rotate(pg.transform.scale(END_IMG, (50, 50)), 180), False, True)
|
||||||
|
END2 = pg.transform.rotate(END1, 90)
|
||||||
|
DIRT_IMG = pg.image.load("Tiles/dirt.jpg")
|
||||||
|
DIRT = pg.transform.scale(DIRT_IMG, (50, 50))
|
||||||
|
GRASS_IMG = pg.image.load("Tiles/grass.png")
|
||||||
|
GRASS = pg.transform.scale(GRASS_IMG, (50, 50))
|
||||||
|
SAND_IMG = pg.image.load("Tiles/sand.jpeg")
|
||||||
|
SAND = pg.transform.scale(SAND_IMG, (50, 50))
|
||||||
|
COBBLE_IMG = pg.image.load("Tiles/cobble.jpeg")
|
||||||
|
COBBLE = pg.transform.scale(COBBLE_IMG, (50, 50))
|
||||||
|
|
||||||
|
|
||||||
|
def randomize_map(): # tworzenie mapy z losowymi polami
|
||||||
request_list = []
|
request_list = []
|
||||||
field_array_1 = []
|
field_array_1 = []
|
||||||
field_array_2 = []
|
field_array_2 = []
|
||||||
|
imgpath_array = [[0 for x in range(16)] for x in range(16)]
|
||||||
field_priority = []
|
field_priority = []
|
||||||
|
map_array = [['b', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'b3', 'g'],
|
||||||
|
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'jn', 'g'],
|
||||||
|
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['b1', 'sh', 'jw', 'sh', 'sh', 'jn', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['g', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'jn', 'g'],
|
||||||
|
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['gr', 'g', 'js', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', ' g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
|
||||||
|
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
|
||||||
|
['gr', 'g', 'b1', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'b4', 'g'],
|
||||||
|
]
|
||||||
|
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
temp_priority = []
|
temp_priority = []
|
||||||
for j in range(16):
|
for j in range(16):
|
||||||
if i in (0, 1) and j in (0, 1):
|
if map_array[i][j] == 'b':
|
||||||
field_array_2.append(GRASS)
|
field_array_2.append(BASE)
|
||||||
temp_priority.append(1)
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'b3':
|
||||||
|
field_array_2.append(BEND3)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'b4':
|
||||||
|
field_array_2.append(BEND4)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'b1':
|
||||||
|
field_array_2.append(BEND1)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'sh':
|
||||||
|
field_array_2.append(STRAIGHT_VERTICAL)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'sv':
|
||||||
|
field_array_2.append(STRAIGHT_HORIZONTAL)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'i':
|
||||||
|
field_array_2.append(INTERSECTION)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'je':
|
||||||
|
field_array_2.append(JUNCTION_EAST)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'jw':
|
||||||
|
field_array_2.append(JUNCTION_WEST)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'js':
|
||||||
|
field_array_2.append(JUNCTION_SOUTH)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'jn':
|
||||||
|
field_array_2.append(JUNCTION_NORTH)
|
||||||
|
temp_priority.append(1)
|
||||||
|
elif map_array[i][j] == 'gr':
|
||||||
|
field_array_2.append(BASE)
|
||||||
|
temp_priority.append(1000)
|
||||||
else:
|
else:
|
||||||
prob = random.uniform(0, 100)
|
prob = random.uniform(0, 100)
|
||||||
if 0 <= prob <= 12:
|
if 0 <= prob <= 20:
|
||||||
field_array_2.append(COBBLE)
|
garbage_type = random.choice(['glass', 'mixed', 'paper', 'plastic'])
|
||||||
|
garbage_image_number = random.randrange(1, 100)
|
||||||
|
GARBAGE_IMG = pg.image.load(
|
||||||
|
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
||||||
|
GARBAGE = pg.transform.scale(GARBAGE_IMG, (50, 50))
|
||||||
|
field_array_2.append(GARBAGE)
|
||||||
|
imgpath_array[i][j] = (
|
||||||
|
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
|
||||||
|
|
||||||
temp_priority.append(100)
|
temp_priority.append(100)
|
||||||
request_list.append(Request(
|
request_list.append(Request(
|
||||||
i*50,j*50, #lokacja
|
i * 50, j * 50, # lokacja
|
||||||
random.randint(0,3), #typ śmieci
|
random.randint(0, 3), # typ śmieci
|
||||||
random.random(), #objętość śmieci
|
random.random(), # objętość śmieci
|
||||||
random.randint(0,30), #ostatni odbiór
|
random.randint(0, 30), # ostatni odbiór
|
||||||
random.randint(0,1), #czy opłacone w terminie
|
random.randint(0, 1), # czy opłacone w terminie
|
||||||
random.random() * 10, #intensywność odoru
|
random.random() * 10, # intensywność odoru
|
||||||
random.random() * 50 #waga śmieci
|
random.random() * 50 # waga śmieci
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
field_array_2.append(GRASS)
|
field_array_2.append(BASE)
|
||||||
temp_priority.append(1)
|
temp_priority.append(1000)
|
||||||
field_array_1.append(field_array_2)
|
field_array_1.append(field_array_2)
|
||||||
field_array_2 = []
|
field_array_2 = []
|
||||||
field_priority.append(temp_priority)
|
field_priority.append(temp_priority)
|
||||||
return field_array_1, field_priority, request_list
|
return field_array_1, field_priority, request_list, imgpath_array
|
||||||
|
BIN
model_training/garbage_model.pth
Normal file
177
model_training/main.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from torch.utils.data import Dataset, random_split, DataLoader
|
||||||
|
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torchvision.models as models
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.manual_seed(42)
|
||||||
|
# input_size = 49152
|
||||||
|
# hidden_sizes = [64, 128]
|
||||||
|
# output_size = 10
|
||||||
|
|
||||||
|
classes = os.listdir('./train_dataset')
|
||||||
|
print(classes)
|
||||||
|
mean = [0.6908, 0.6612, 0.6218]
|
||||||
|
std = [0.1947, 0.1926, 0.2086]
|
||||||
|
|
||||||
|
training_dataset_path = './train_dataset'
|
||||||
|
training_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
||||||
|
train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
|
||||||
|
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
|
||||||
|
|
||||||
|
testing_dataset_path = './test_dataset'
|
||||||
|
testing_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
|
||||||
|
test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
|
||||||
|
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
|
||||||
|
|
||||||
|
# Mean and Standard Deviation approximations
|
||||||
|
def get_mean_and_std(loader):
|
||||||
|
mean = 0.
|
||||||
|
std = 0.
|
||||||
|
total_images_count = 0
|
||||||
|
for images, _ in loader:
|
||||||
|
image_count_in_a_batch = images.size(0)
|
||||||
|
#print(images.shape)
|
||||||
|
images = images.view(image_count_in_a_batch, images.size(1), -1)
|
||||||
|
#print(images.shape)
|
||||||
|
mean += images.mean(2).sum(0)
|
||||||
|
std += images.std(2).sum(0)
|
||||||
|
total_images_count += image_count_in_a_batch
|
||||||
|
mean /= total_images_count
|
||||||
|
std /= total_images_count
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
print(get_mean_and_std(train_loader))
|
||||||
|
|
||||||
|
# Show images with applied transformations
|
||||||
|
def show_transformed_images(dataset):
|
||||||
|
loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
|
||||||
|
batch = next(iter(loader))
|
||||||
|
images, labels = batch
|
||||||
|
|
||||||
|
grid = torchvision.utils.make_grid(images, nrow=3)
|
||||||
|
plt.figure(figsize=(11,11))
|
||||||
|
plt.imshow(np.transpose(grid, (1,2,0)))
|
||||||
|
print('labels: ', labels)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
show_transformed_images(train_dataset)
|
||||||
|
|
||||||
|
# Neural network training:
|
||||||
|
def set_device():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
dev = "cuda:0"
|
||||||
|
else:
|
||||||
|
dev = "cpu"
|
||||||
|
return torch.device(dev)
|
||||||
|
|
||||||
|
|
||||||
|
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epochs):
|
||||||
|
device = set_device()
|
||||||
|
best_acc = 0
|
||||||
|
|
||||||
|
for epoch in range(n_epochs):
|
||||||
|
print("Epoch number %d " % (epoch+1))
|
||||||
|
model.train()
|
||||||
|
running_loss = 0.0
|
||||||
|
running_correct = 0.0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for data in train_loader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
# Back propagation
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(images)
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_correct += (labels==predicted).sum().item()
|
||||||
|
|
||||||
|
epoch_loss = running_loss/len(train_loader)
|
||||||
|
epoch_acc = 100.00 * running_correct / total
|
||||||
|
|
||||||
|
print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
|
||||||
|
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
|
||||||
|
|
||||||
|
if(test_dataset_acc > best_acc):
|
||||||
|
best_acc = test_dataset_acc
|
||||||
|
save_checkpoint(model, epoch, optimizer, best_acc)
|
||||||
|
|
||||||
|
print("Finished")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model_on_test_set(model, test_loader):
|
||||||
|
model.eval()
|
||||||
|
predicted_correctly_on_epoch = 0
|
||||||
|
total = 0
|
||||||
|
device = set_device()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in test_loader:
|
||||||
|
images, labels = data
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
total += labels.size(0)
|
||||||
|
|
||||||
|
outputs = model(images)
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
predicted_correctly_on_epoch += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
epoch_acc = 100.0 * predicted_correctly_on_epoch / total
|
||||||
|
print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
|
||||||
|
|
||||||
|
return epoch_acc
|
||||||
|
|
||||||
|
|
||||||
|
# Saving the checkpoint:
|
||||||
|
def save_checkpoint(model, epoch, optimizer, best_acc):
|
||||||
|
state = {
|
||||||
|
'epoch': epoch+1,
|
||||||
|
'model': model.state_dict(),
|
||||||
|
'best_accuracy': best_acc,
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
}
|
||||||
|
torch.save(state, 'model_best_checkpoint.zip')
|
||||||
|
|
||||||
|
|
||||||
|
resnet18_model = models.resnet18(pretrained=True) #Increase n_epochs if False
|
||||||
|
num_features = resnet18_model.fc.in_features
|
||||||
|
number_of_classes = 4
|
||||||
|
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||||
|
device = set_device()
|
||||||
|
resnet_18_model = resnet18_model.to(device)
|
||||||
|
loss_fn = nn.CrossEntropyLoss() #criterion
|
||||||
|
|
||||||
|
optimizer = optim.SGD(resnet_18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
|
||||||
|
train_nn(resnet_18_model, train_loader, test_loader, loss_fn, optimizer, 5)
|
||||||
|
|
||||||
|
|
||||||
|
# Saving the model:
|
||||||
|
checkpoint = torch.load('model_best_checkpoint.pth.zip')
|
||||||
|
|
||||||
|
resnet18_model = models.resnet18()
|
||||||
|
num_features = resnet18_model.fc.in_features
|
||||||
|
number_of_classes = 4
|
||||||
|
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
|
||||||
|
resnet18_model.load_state_dict(checkpoint['model'])
|
||||||
|
|
||||||
|
torch.save(resnet18_model, 'garbage_model.pth')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
BIN
model_training/test.jpg
Normal file
After Width: | Height: | Size: 162 KiB |
BIN
model_training/test_dataset/glass/glass (1).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (10).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (100).jpg
Normal file
After Width: | Height: | Size: 2.0 KiB |
BIN
model_training/test_dataset/glass/glass (11).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (12).jpg
Normal file
After Width: | Height: | Size: 2.4 KiB |
BIN
model_training/test_dataset/glass/glass (13).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (14).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (15).jpg
Normal file
After Width: | Height: | Size: 2.8 KiB |
BIN
model_training/test_dataset/glass/glass (16).jpg
Normal file
After Width: | Height: | Size: 1.8 KiB |
BIN
model_training/test_dataset/glass/glass (17).jpg
Normal file
After Width: | Height: | Size: 3.8 KiB |
BIN
model_training/test_dataset/glass/glass (18).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (19).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (2).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (20).jpg
Normal file
After Width: | Height: | Size: 4.9 KiB |
BIN
model_training/test_dataset/glass/glass (21).jpg
Normal file
After Width: | Height: | Size: 3.9 KiB |
BIN
model_training/test_dataset/glass/glass (22).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (23).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (24).jpg
Normal file
After Width: | Height: | Size: 2.2 KiB |
BIN
model_training/test_dataset/glass/glass (25).jpg
Normal file
After Width: | Height: | Size: 2.4 KiB |
BIN
model_training/test_dataset/glass/glass (26).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (27).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (28).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (29).jpg
Normal file
After Width: | Height: | Size: 4.8 KiB |
BIN
model_training/test_dataset/glass/glass (3).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (30).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (31).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (32).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (33).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (34).jpg
Normal file
After Width: | Height: | Size: 5.4 KiB |
BIN
model_training/test_dataset/glass/glass (35).jpg
Normal file
After Width: | Height: | Size: 2.9 KiB |
BIN
model_training/test_dataset/glass/glass (36).jpg
Normal file
After Width: | Height: | Size: 2.5 KiB |
BIN
model_training/test_dataset/glass/glass (37).jpg
Normal file
After Width: | Height: | Size: 4.4 KiB |
BIN
model_training/test_dataset/glass/glass (38).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (39).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (4).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (40).jpg
Normal file
After Width: | Height: | Size: 3.6 KiB |
BIN
model_training/test_dataset/glass/glass (41).jpg
Normal file
After Width: | Height: | Size: 4.7 KiB |
BIN
model_training/test_dataset/glass/glass (42).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (43).jpg
Normal file
After Width: | Height: | Size: 3.4 KiB |
BIN
model_training/test_dataset/glass/glass (44).jpg
Normal file
After Width: | Height: | Size: 7.2 KiB |
BIN
model_training/test_dataset/glass/glass (45).jpg
Normal file
After Width: | Height: | Size: 1.8 KiB |
BIN
model_training/test_dataset/glass/glass (46).jpg
Normal file
After Width: | Height: | Size: 4.0 KiB |
BIN
model_training/test_dataset/glass/glass (47).jpg
Normal file
After Width: | Height: | Size: 5.2 KiB |
BIN
model_training/test_dataset/glass/glass (48).jpg
Normal file
After Width: | Height: | Size: 4.5 KiB |
BIN
model_training/test_dataset/glass/glass (49).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (5).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (50).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (51).jpg
Normal file
After Width: | Height: | Size: 5.3 KiB |
BIN
model_training/test_dataset/glass/glass (52).jpg
Normal file
After Width: | Height: | Size: 2.8 KiB |
BIN
model_training/test_dataset/glass/glass (53).jpg
Normal file
After Width: | Height: | Size: 6.3 KiB |
BIN
model_training/test_dataset/glass/glass (54).jpg
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
model_training/test_dataset/glass/glass (55).jpg
Normal file
After Width: | Height: | Size: 3.5 KiB |
BIN
model_training/test_dataset/glass/glass (56).jpg
Normal file
After Width: | Height: | Size: 5.9 KiB |
BIN
model_training/test_dataset/glass/glass (57).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (58).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (59).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (6).jpg
Normal file
After Width: | Height: | Size: 2.3 KiB |
BIN
model_training/test_dataset/glass/glass (60).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (61).jpg
Normal file
After Width: | Height: | Size: 3.8 KiB |
BIN
model_training/test_dataset/glass/glass (62).jpg
Normal file
After Width: | Height: | Size: 4.4 KiB |
BIN
model_training/test_dataset/glass/glass (63).jpg
Normal file
After Width: | Height: | Size: 5.0 KiB |
BIN
model_training/test_dataset/glass/glass (64).jpg
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
model_training/test_dataset/glass/glass (65).jpg
Normal file
After Width: | Height: | Size: 4.3 KiB |
BIN
model_training/test_dataset/glass/glass (66).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (67).jpg
Normal file
After Width: | Height: | Size: 4.9 KiB |
BIN
model_training/test_dataset/glass/glass (68).jpg
Normal file
After Width: | Height: | Size: 3.0 KiB |
BIN
model_training/test_dataset/glass/glass (69).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (7).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (70).jpg
Normal file
After Width: | Height: | Size: 1.9 KiB |
BIN
model_training/test_dataset/glass/glass (71).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (72).jpg
Normal file
After Width: | Height: | Size: 3.6 KiB |
BIN
model_training/test_dataset/glass/glass (73).jpg
Normal file
After Width: | Height: | Size: 3.2 KiB |
BIN
model_training/test_dataset/glass/glass (74).jpg
Normal file
After Width: | Height: | Size: 2.6 KiB |
BIN
model_training/test_dataset/glass/glass (75).jpg
Normal file
After Width: | Height: | Size: 2.9 KiB |
BIN
model_training/test_dataset/glass/glass (76).jpg
Normal file
After Width: | Height: | Size: 2.2 KiB |
BIN
model_training/test_dataset/glass/glass (77).jpg
Normal file
After Width: | Height: | Size: 3.7 KiB |
BIN
model_training/test_dataset/glass/glass (78).jpg
Normal file
After Width: | Height: | Size: 2.7 KiB |
BIN
model_training/test_dataset/glass/glass (79).jpg
Normal file
After Width: | Height: | Size: 3.1 KiB |
BIN
model_training/test_dataset/glass/glass (8).jpg
Normal file
After Width: | Height: | Size: 2.0 KiB |