Compare commits

...

3 Commits

4420 changed files with 416 additions and 58 deletions

BIN
Tiles/Base.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 209 KiB

BIN
Tiles/Bend.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 192 KiB

BIN
Tiles/End.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

BIN
Tiles/Intersection.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

BIN
Tiles/Junction.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

BIN
Tiles/Straight.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

View File

Before

Width:  |  Height:  |  Size: 9.3 KiB

After

Width:  |  Height:  |  Size: 9.3 KiB

View File

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

Before

Width:  |  Height:  |  Size: 26 KiB

After

Width:  |  Height:  |  Size: 26 KiB

View File

Before

Width:  |  Height:  |  Size: 9.8 KiB

After

Width:  |  Height:  |  Size: 9.8 KiB

22
collect
View File

@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ;
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
12 [label="garbage_type <= 2.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ;
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ;
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ;
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ;
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
27 [label="days_since_last_collection <= 22.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
@ -88,18 +88,14 @@ edge [fontname="helvetica"] ;
40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ;
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ;
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ;
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
46 -> 47 ;
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
46 -> 48 ;
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
48 -> 49 ;
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
48 -> 50 ;
}

Binary file not shown.

View File

@ -1,11 +1,16 @@
from heuristicfn import heuristicfn
FIELDWIDTH = 50
TURN_FUEL_COST = 10
MOVE_FUEL_COST = 200
MAX_FUEL = 20000
MAX_SPACE = 5
MAX_WEIGHT = 200
MAX_WEIGHT = 400
MAX_WEIGHT_GLASS = 100
MAX_WEIGHT_MIXED = 100
MAX_WEIGHT_PAPER = 100
MAX_WEIGHT_PLASTIC = 100
class GarbageTruck:
@ -18,6 +23,10 @@ class GarbageTruck:
self.fuel = MAX_FUEL
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
self.weight_capacity_glass = MAX_WEIGHT_GLASS
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
self.weight_capacity_paper = MAX_WEIGHT_PAPER
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
self.rect = rect
self.orientation = orientation
self.request_list = request_list #lista domów do odwiedzenia
@ -78,10 +87,33 @@ class GarbageTruck:
def collect(self):
def collect(self, garbage_type):
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
self.fuel = MAX_FUEL
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
self.weight_capacity_glass = MAX_WEIGHT_GLASS
self.weight_capacity_paper = MAX_WEIGHT_PAPER
request = self.request_list[0]
if garbage_type == "glass":
if request.weight > self.weight_capacity_glass:
return 1
self.weight_capacity_glass -= request.weight
elif garbage_type == "mixed":
if request.weight > self.weight_capacity_mixed:
return 1
self.weight_capacity_mixed -= request.weight
elif garbage_type == "paper":
if request.weight > self.weight_capacity_paper:
return 1
self.weight_capacity_paper -= request.weight
elif garbage_type == "plastic":
if request.weight > self.weight_capacity_plastic:
return 1
self.weight_capacity_plastic -= request.weight
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}, glass_capacity: {self.weight_capacity_glass}, mixed_capacity: {self.weight_capacity_mixed}, paper_capacity: {self.weight_capacity_paper}, plastic_capacity: {self.weight_capacity_plastic}')
return 0
pass

View File

@ -1,3 +1,2 @@
def heuristicfn(startx, starty, goalx, goaly):
return abs(startx - goalx) + abs(starty - goaly)
# return pow(((startx//50)-(starty//50)),2) + pow(((goalx//50)-(goaly//50)),2)

44
loadmodel.py Normal file
View File

@ -0,0 +1,44 @@
import torch
import torchvision
import torchvision.transforms as transforms
import PIL.Image as Image
import os
def classify(image_path):
model = torch.load('./model_training/garbage_model.pth')
mean = [0.6908, 0.6612, 0.6218]
std = [0.1947, 0.1926, 0.2086]
classes = [
"glass",
"mixed",
"paper",
"plastic",
]
image_transforms = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
])
model = model.eval()
image = Image.open(image_path)
image = image_transforms(image).float()
image = image.unsqueeze(0)
output = model(image)
_, predicted = torch.max(output.data, 1)
label = os.path.basename(os.path.dirname(image_path))
prediction = classes[predicted.item()]
print(f"predicted: {prediction}")
if label == prediction:
print("predicted correctly.")
else:
print("predicted incorrectly.")
return prediction
# classify("./model_training/test.jpg")

61
main.py
View File

@ -1,7 +1,6 @@
import pygame
from treelearn import treelearn
import loadmodel
from astar import astar
from state import State
import time
@ -9,6 +8,7 @@ from garbage_truck import GarbageTruck
from heuristicfn import heuristicfn
from map import randomize_map
pygame.init()
WIDTH, HEIGHT = 800, 800
window = pygame.display.set_mode((WIDTH, HEIGHT))
@ -18,14 +18,18 @@ AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
FPS = 10
FIELDCOUNT = 16
FIELDWIDTH = 50
BASE_IMG = pygame.image.load("Tiles/Base.jpg")
BASE = pygame.transform.scale(BASE_IMG, (50, 50))
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
def draw_window(agent, fields, flip):
def draw_window(agent, fields, flip, turn):
if flip:
direction = pygame.transform.flip(AGENT, True, False)
if turn:
direction = pygame.transform.rotate(AGENT, -90)
else:
direction = pygame.transform.flip(AGENT, False, False)
if turn:
direction = pygame.transform.rotate(AGENT, 90)
for i in range(16):
for j in range(16):
window.blit(fields[i][j], (i * 50, j * 50))
@ -37,40 +41,63 @@ def main():
clf = treelearn()
clock = pygame.time.Clock()
run = True
fields, priority_array, request_list = randomize_map()
fields, priority_array, request_list, imgpath_array = randomize_map()
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
low_space = 0
while run:
clock.tick(FPS)
for event in pygame.event.get():
if event.type == pygame.QUIT:
run = False
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
draw_window(agent, fields, False, False) # false = kierunek east (domyslny), true = west
x, y = agent.next_destination()
if x == agent.rect.x and y == agent.rect.y:
print('out of jobs')
break
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[agent.rect.x//50][agent.rect.y//50], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
if low_space == 1:
x, y = 0, 0
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation,
priority_array[agent.rect.x//50][agent.rect.y//50],
heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
for interm in steps:
if interm.action == 'LEFT':
agent.turn_left()
draw_window(agent, fields, True)
if agent.orientation == 0:
draw_window(agent, fields, False, False)
elif agent.orientation == 2:
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
else:
draw_window(agent, fields, False, True)
elif interm.action == 'RIGHT':
agent.turn_right()
draw_window(agent, fields, False)
if agent.orientation == 0:
draw_window(agent, fields, False, False)
elif agent.orientation == 2:
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
else:
draw_window(agent, fields, False, True)
elif interm.action == 'FORWARD':
agent.forward()
if agent.orientation == 0:
draw_window(agent, fields, False)
draw_window(agent, fields, False, False)
elif agent.orientation == 2:
draw_window(agent, fields, True)
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
else:
draw_window(agent, fields, False)
draw_window(agent, fields, False, True)
time.sleep(0.3)
agent.collect()
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
priority_array[agent.rect.x//50][agent.rect.y//50] = 1
time.sleep(0.5)
if (agent.rect.x // 50 != 0) or (agent.rect.y // 50 != 0):
garbage_type = loadmodel.classify(imgpath_array[agent.rect.x // 50][agent.rect.y // 50])
low_space = agent.collect(garbage_type)
fields[agent.rect.x//50][agent.rect.y//50] = BASE
priority_array[agent.rect.x//50][agent.rect.y//50] = 100
time.sleep(0.5)
pygame.quit()

115
map.py
View File

@ -1,30 +1,113 @@
import pygame, random
import pygame as pg
import random
from request import Request
DIRT_IMG = pygame.image.load("dirt.jpg")
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pygame.image.load("sand.jpeg")
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pygame.image.load("cobble.jpeg")
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
STRAIGHT_IMG = pg.image.load("Tiles/Straight.jpg")
STRAIGHT_VERTICAL = pg.transform.scale(STRAIGHT_IMG, (50, 50))
STRAIGHT_HORIZONTAL = pg.transform.rotate(STRAIGHT_VERTICAL, 270)
BASE_IMG = pg.image.load("Tiles/Base.jpg")
BASE = pg.transform.scale(BASE_IMG, (50, 50))
BEND_IMG = pg.image.load("Tiles/Bend.jpg")
BEND1 = pg.transform.scale(BEND_IMG, (50, 50))
BEND2 = pg.transform.rotate(BEND1, 90)
BEND3 = pg.transform.rotate(pg.transform.flip(pg.transform.rotate(BEND1, 180), True, True), 180)
BEND4 = pg.transform.rotate(BEND1, -90)
INTERSECTION_IMG = pg.image.load("Tiles/Intersection.jpg")
INTERSECTION = pg.transform.scale(INTERSECTION_IMG, (50, 50))
JUNCTION_IMG = pg.image.load("Tiles/Junction.jpg")
JUNCTION_SOUTH = pg.transform.scale(JUNCTION_IMG, (50, 50))
JUNCTION_NORTH = pg.transform.rotate(pg.transform.flip(JUNCTION_SOUTH, True, False), 180)
JUNCTION_EAST = pg.transform.rotate(JUNCTION_SOUTH, -90)
JUNCTION_WEST = pg.transform.rotate(JUNCTION_SOUTH, 90)
END_IMG = pg.image.load("Tiles/End.jpg")
END1 = pg.transform.flip(pg.transform.rotate(pg.transform.scale(END_IMG, (50, 50)), 180), False, True)
END2 = pg.transform.rotate(END1, 90)
DIRT_IMG = pg.image.load("Tiles/dirt.jpg")
DIRT = pg.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pg.image.load("Tiles/grass.png")
GRASS = pg.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pg.image.load("Tiles/sand.jpeg")
SAND = pg.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pg.image.load("Tiles/cobble.jpeg")
COBBLE = pg.transform.scale(COBBLE_IMG, (50, 50))
def randomize_map(): # tworzenie mapy z losowymi polami
request_list = []
field_array_1 = []
field_array_2 = []
imgpath_array = [[0 for x in range(16)] for x in range(16)]
field_priority = []
map_array = [['b', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'b3', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'jn', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['b1', 'sh', 'jw', 'sh', 'sh', 'jn', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['g', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'jn', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'js', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', ' g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'b1', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'b4', 'g'],
]
for i in range(16):
temp_priority = []
for j in range(16):
if i in (0, 1) and j in (0, 1):
field_array_2.append(GRASS)
if map_array[i][j] == 'b':
field_array_2.append(BASE)
temp_priority.append(1)
elif map_array[i][j] == 'b3':
field_array_2.append(BEND3)
temp_priority.append(1)
elif map_array[i][j] == 'b4':
field_array_2.append(BEND4)
temp_priority.append(1)
elif map_array[i][j] == 'b1':
field_array_2.append(BEND1)
temp_priority.append(1)
elif map_array[i][j] == 'sh':
field_array_2.append(STRAIGHT_VERTICAL)
temp_priority.append(1)
elif map_array[i][j] == 'sv':
field_array_2.append(STRAIGHT_HORIZONTAL)
temp_priority.append(1)
elif map_array[i][j] == 'i':
field_array_2.append(INTERSECTION)
temp_priority.append(1)
elif map_array[i][j] == 'je':
field_array_2.append(JUNCTION_EAST)
temp_priority.append(1)
elif map_array[i][j] == 'jw':
field_array_2.append(JUNCTION_WEST)
temp_priority.append(1)
elif map_array[i][j] == 'js':
field_array_2.append(JUNCTION_SOUTH)
temp_priority.append(1)
elif map_array[i][j] == 'jn':
field_array_2.append(JUNCTION_NORTH)
temp_priority.append(1)
elif map_array[i][j] == 'gr':
field_array_2.append(BASE)
temp_priority.append(1000)
else:
prob = random.uniform(0, 100)
if 0 <= prob <= 12:
field_array_2.append(COBBLE)
if 0 <= prob <= 20:
garbage_type = random.choice(['glass', 'mixed', 'paper', 'plastic'])
garbage_image_number = random.randrange(1, 100)
GARBAGE_IMG = pg.image.load(
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
GARBAGE = pg.transform.scale(GARBAGE_IMG, (50, 50))
field_array_2.append(GARBAGE)
imgpath_array[i][j] = (
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
temp_priority.append(100)
request_list.append(Request(
i * 50, j * 50, # lokacja
@ -36,9 +119,9 @@ def randomize_map(): # tworzenie mapy z losowymi polami
random.random() * 50 # waga śmieci
))
else:
field_array_2.append(GRASS)
temp_priority.append(1)
field_array_2.append(BASE)
temp_priority.append(1000)
field_array_1.append(field_array_2)
field_array_2 = []
field_priority.append(temp_priority)
return field_array_1, field_priority, request_list
return field_array_1, field_priority, request_list, imgpath_array

Binary file not shown.

177
model_training/main.py Normal file
View File

@ -0,0 +1,177 @@
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
def main():
torch.manual_seed(42)
# input_size = 49152
# hidden_sizes = [64, 128]
# output_size = 10
classes = os.listdir('./train_dataset')
print(classes)
mean = [0.6908, 0.6612, 0.6218]
std = [0.1947, 0.1926, 0.2086]
training_dataset_path = './train_dataset'
training_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
testing_dataset_path = './test_dataset'
testing_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
# Mean and Standard Deviation approximations
def get_mean_and_std(loader):
mean = 0.
std = 0.
total_images_count = 0
for images, _ in loader:
image_count_in_a_batch = images.size(0)
#print(images.shape)
images = images.view(image_count_in_a_batch, images.size(1), -1)
#print(images.shape)
mean += images.mean(2).sum(0)
std += images.std(2).sum(0)
total_images_count += image_count_in_a_batch
mean /= total_images_count
std /= total_images_count
return mean, std
print(get_mean_and_std(train_loader))
# Show images with applied transformations
def show_transformed_images(dataset):
loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
batch = next(iter(loader))
images, labels = batch
grid = torchvision.utils.make_grid(images, nrow=3)
plt.figure(figsize=(11,11))
plt.imshow(np.transpose(grid, (1,2,0)))
print('labels: ', labels)
plt.show()
show_transformed_images(train_dataset)
# Neural network training:
def set_device():
if torch.cuda.is_available():
dev = "cuda:0"
else:
dev = "cpu"
return torch.device(dev)
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epochs):
device = set_device()
best_acc = 0
for epoch in range(n_epochs):
print("Epoch number %d " % (epoch+1))
model.train()
running_loss = 0.0
running_correct = 0.0
total = 0
for data in train_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
# Back propagation
optimizer.zero_grad()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_correct += (labels==predicted).sum().item()
epoch_loss = running_loss/len(train_loader)
epoch_acc = 100.00 * running_correct / total
print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
if(test_dataset_acc > best_acc):
best_acc = test_dataset_acc
save_checkpoint(model, epoch, optimizer, best_acc)
print("Finished")
return model
def evaluate_model_on_test_set(model, test_loader):
model.eval()
predicted_correctly_on_epoch = 0
total = 0
device = set_device()
with torch.no_grad():
for data in test_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
predicted_correctly_on_epoch += (predicted == labels).sum().item()
epoch_acc = 100.0 * predicted_correctly_on_epoch / total
print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
return epoch_acc
# Saving the checkpoint:
def save_checkpoint(model, epoch, optimizer, best_acc):
state = {
'epoch': epoch+1,
'model': model.state_dict(),
'best_accuracy': best_acc,
'optimizer': optimizer.state_dict(),
}
torch.save(state, 'model_best_checkpoint.zip')
resnet18_model = models.resnet18(pretrained=True) #Increase n_epochs if False
num_features = resnet18_model.fc.in_features
number_of_classes = 4
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
device = set_device()
resnet_18_model = resnet18_model.to(device)
loss_fn = nn.CrossEntropyLoss() #criterion
optimizer = optim.SGD(resnet_18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
train_nn(resnet_18_model, train_loader, test_loader, loss_fn, optimizer, 5)
# Saving the model:
checkpoint = torch.load('model_best_checkpoint.pth.zip')
resnet18_model = models.resnet18()
num_features = resnet18_model.fc.in_features
number_of_classes = 4
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
resnet18_model.load_state_dict(checkpoint['model'])
torch.save(resnet18_model, 'garbage_model.pth')
if __name__ == "__main__":
main()

BIN
model_training/test.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 162 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Some files were not shown because too many files have changed in this diff Show More