Compare commits

..

No commits in common. "neural_network" and "master" have entirely different histories.

4420 changed files with 58 additions and 416 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 192 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 187 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 186 KiB

View File

Before

Width:  |  Height:  |  Size: 9.3 KiB

After

Width:  |  Height:  |  Size: 9.3 KiB

22
collect
View File

@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ;
12 [label="garbage_type <= 2.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ;
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ;
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ;
27 [label="days_since_last_collection <= 22.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
@ -88,14 +88,18 @@ edge [fontname="helvetica"] ;
40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
46 -> 47 ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
46 -> 48 ;
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
48 -> 49 ;
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
48 -> 50 ;
}

Binary file not shown.

View File

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

@ -1,16 +1,11 @@
from heuristicfn import heuristicfn
FIELDWIDTH = 50
TURN_FUEL_COST = 10
MOVE_FUEL_COST = 200
MAX_FUEL = 20000
MAX_SPACE = 5
MAX_WEIGHT = 400
MAX_WEIGHT_GLASS = 100
MAX_WEIGHT_MIXED = 100
MAX_WEIGHT_PAPER = 100
MAX_WEIGHT_PLASTIC = 100
MAX_WEIGHT = 200
class GarbageTruck:
@ -23,10 +18,6 @@ class GarbageTruck:
self.fuel = MAX_FUEL
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
self.weight_capacity_glass = MAX_WEIGHT_GLASS
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
self.weight_capacity_paper = MAX_WEIGHT_PAPER
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
self.rect = rect
self.orientation = orientation
self.request_list = request_list #lista domów do odwiedzenia
@ -87,33 +78,10 @@ class GarbageTruck:
def collect(self, garbage_type):
def collect(self):
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
self.fuel = MAX_FUEL
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
self.weight_capacity_plastic = MAX_WEIGHT_PLASTIC
self.weight_capacity_mixed = MAX_WEIGHT_MIXED
self.weight_capacity_glass = MAX_WEIGHT_GLASS
self.weight_capacity_paper = MAX_WEIGHT_PAPER
request = self.request_list[0]
if garbage_type == "glass":
if request.weight > self.weight_capacity_glass:
return 1
self.weight_capacity_glass -= request.weight
elif garbage_type == "mixed":
if request.weight > self.weight_capacity_mixed:
return 1
self.weight_capacity_mixed -= request.weight
elif garbage_type == "paper":
if request.weight > self.weight_capacity_paper:
return 1
self.weight_capacity_paper -= request.weight
elif garbage_type == "plastic":
if request.weight > self.weight_capacity_plastic:
return 1
self.weight_capacity_plastic -= request.weight
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}, glass_capacity: {self.weight_capacity_glass}, mixed_capacity: {self.weight_capacity_mixed}, paper_capacity: {self.weight_capacity_paper}, plastic_capacity: {self.weight_capacity_plastic}')
return 0
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
pass

View File

Before

Width:  |  Height:  |  Size: 26 KiB

After

Width:  |  Height:  |  Size: 26 KiB

View File

@ -1,2 +1,3 @@
def heuristicfn(startx, starty, goalx, goaly):
return abs(startx - goalx) + abs(starty - goaly)
# return pow(((startx//50)-(starty//50)),2) + pow(((goalx//50)-(goaly//50)),2)

View File

@ -1,44 +0,0 @@
import torch
import torchvision
import torchvision.transforms as transforms
import PIL.Image as Image
import os
def classify(image_path):
model = torch.load('./model_training/garbage_model.pth')
mean = [0.6908, 0.6612, 0.6218]
std = [0.1947, 0.1926, 0.2086]
classes = [
"glass",
"mixed",
"paper",
"plastic",
]
image_transforms = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
])
model = model.eval()
image = Image.open(image_path)
image = image_transforms(image).float()
image = image.unsqueeze(0)
output = model(image)
_, predicted = torch.max(output.data, 1)
label = os.path.basename(os.path.dirname(image_path))
prediction = classes[predicted.item()]
print(f"predicted: {prediction}")
if label == prediction:
print("predicted correctly.")
else:
print("predicted incorrectly.")
return prediction
# classify("./model_training/test.jpg")

61
main.py
View File

@ -1,6 +1,7 @@
import pygame
from treelearn import treelearn
import loadmodel
from astar import astar
from state import State
import time
@ -8,7 +9,6 @@ from garbage_truck import GarbageTruck
from heuristicfn import heuristicfn
from map import randomize_map
pygame.init()
WIDTH, HEIGHT = 800, 800
window = pygame.display.set_mode((WIDTH, HEIGHT))
@ -18,18 +18,14 @@ AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
FPS = 10
FIELDCOUNT = 16
FIELDWIDTH = 50
BASE_IMG = pygame.image.load("Tiles/Base.jpg")
BASE = pygame.transform.scale(BASE_IMG, (50, 50))
def draw_window(agent, fields, flip, turn):
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
def draw_window(agent, fields, flip):
if flip:
direction = pygame.transform.flip(AGENT, True, False)
if turn:
direction = pygame.transform.rotate(AGENT, -90)
else:
direction = pygame.transform.flip(AGENT, False, False)
if turn:
direction = pygame.transform.rotate(AGENT, 90)
for i in range(16):
for j in range(16):
window.blit(fields[i][j], (i * 50, j * 50))
@ -41,63 +37,40 @@ def main():
clf = treelearn()
clock = pygame.time.Clock()
run = True
fields, priority_array, request_list, imgpath_array = randomize_map()
fields, priority_array, request_list = randomize_map()
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
low_space = 0
while run:
clock.tick(FPS)
for event in pygame.event.get():
if event.type == pygame.QUIT:
run = False
draw_window(agent, fields, False, False) # false = kierunek east (domyslny), true = west
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
x, y = agent.next_destination()
if x == agent.rect.x and y == agent.rect.y:
print('out of jobs')
break
if low_space == 1:
x, y = 0, 0
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation,
priority_array[agent.rect.x//50][agent.rect.y//50],
heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[agent.rect.x//50][agent.rect.y//50], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
for interm in steps:
if interm.action == 'LEFT':
agent.turn_left()
if agent.orientation == 0:
draw_window(agent, fields, False, False)
elif agent.orientation == 2:
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
else:
draw_window(agent, fields, False, True)
draw_window(agent, fields, True)
elif interm.action == 'RIGHT':
agent.turn_right()
if agent.orientation == 0:
draw_window(agent, fields, False, False)
elif agent.orientation == 2:
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
else:
draw_window(agent, fields, False, True)
draw_window(agent, fields, False)
elif interm.action == 'FORWARD':
agent.forward()
if agent.orientation == 0:
draw_window(agent, fields, False, False)
draw_window(agent, fields, False)
elif agent.orientation == 2:
draw_window(agent, fields, True, False)
elif agent.orientation == 1:
draw_window(agent, fields, True, True)
draw_window(agent, fields, True)
else:
draw_window(agent, fields, False, True)
draw_window(agent, fields, False)
time.sleep(0.3)
if (agent.rect.x // 50 != 0) or (agent.rect.y // 50 != 0):
garbage_type = loadmodel.classify(imgpath_array[agent.rect.x // 50][agent.rect.y // 50])
low_space = agent.collect(garbage_type)
fields[agent.rect.x//50][agent.rect.y//50] = BASE
priority_array[agent.rect.x//50][agent.rect.y//50] = 100
agent.collect()
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
priority_array[agent.rect.x//50][agent.rect.y//50] = 1
time.sleep(0.5)
pygame.quit()

131
map.py
View File

@ -1,127 +1,44 @@
import pygame as pg
import random
import pygame, random
from request import Request
DIRT_IMG = pygame.image.load("dirt.jpg")
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pygame.image.load("sand.jpeg")
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pygame.image.load("cobble.jpeg")
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
STRAIGHT_IMG = pg.image.load("Tiles/Straight.jpg")
STRAIGHT_VERTICAL = pg.transform.scale(STRAIGHT_IMG, (50, 50))
STRAIGHT_HORIZONTAL = pg.transform.rotate(STRAIGHT_VERTICAL, 270)
BASE_IMG = pg.image.load("Tiles/Base.jpg")
BASE = pg.transform.scale(BASE_IMG, (50, 50))
BEND_IMG = pg.image.load("Tiles/Bend.jpg")
BEND1 = pg.transform.scale(BEND_IMG, (50, 50))
BEND2 = pg.transform.rotate(BEND1, 90)
BEND3 = pg.transform.rotate(pg.transform.flip(pg.transform.rotate(BEND1, 180), True, True), 180)
BEND4 = pg.transform.rotate(BEND1, -90)
INTERSECTION_IMG = pg.image.load("Tiles/Intersection.jpg")
INTERSECTION = pg.transform.scale(INTERSECTION_IMG, (50, 50))
JUNCTION_IMG = pg.image.load("Tiles/Junction.jpg")
JUNCTION_SOUTH = pg.transform.scale(JUNCTION_IMG, (50, 50))
JUNCTION_NORTH = pg.transform.rotate(pg.transform.flip(JUNCTION_SOUTH, True, False), 180)
JUNCTION_EAST = pg.transform.rotate(JUNCTION_SOUTH, -90)
JUNCTION_WEST = pg.transform.rotate(JUNCTION_SOUTH, 90)
END_IMG = pg.image.load("Tiles/End.jpg")
END1 = pg.transform.flip(pg.transform.rotate(pg.transform.scale(END_IMG, (50, 50)), 180), False, True)
END2 = pg.transform.rotate(END1, 90)
DIRT_IMG = pg.image.load("Tiles/dirt.jpg")
DIRT = pg.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pg.image.load("Tiles/grass.png")
GRASS = pg.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pg.image.load("Tiles/sand.jpeg")
SAND = pg.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pg.image.load("Tiles/cobble.jpeg")
COBBLE = pg.transform.scale(COBBLE_IMG, (50, 50))
def randomize_map(): # tworzenie mapy z losowymi polami
def randomize_map(): # tworzenie mapy z losowymi polami
request_list = []
field_array_1 = []
field_array_2 = []
imgpath_array = [[0 for x in range(16)] for x in range(16)]
field_priority = []
map_array = [['b', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'sh', 'jw', 'sh', 'sh', 'sh', 'b3', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'jn', 'g'],
['sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['b1', 'sh', 'jw', 'sh', 'sh', 'jn', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['g', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'gr', 'gr', 'g', 'js', 'sh', 'sh', 'sh', 'jn', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'js', 'sh', 'sh', 'i', 'sh', 'sh', 'sh', 'sh', 'jn', 'g', 'gr', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', ' g', 'g', 'g', 'sv', 'g', 'gr', 'g', 'sv', 'g'],
['gr', 'g', 'sv', 'g', 'g', 'sv', 'g', 'g', 'g', 'g', 'sv', 'g', 'g', 'g', 'sv', 'g'],
['gr', 'g', 'b1', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'sh', 'je', 'sh', 'sh', 'sh', 'b4', 'g'],
]
for i in range(16):
temp_priority = []
for j in range(16):
if map_array[i][j] == 'b':
field_array_2.append(BASE)
if i in (0, 1) and j in (0, 1):
field_array_2.append(GRASS)
temp_priority.append(1)
elif map_array[i][j] == 'b3':
field_array_2.append(BEND3)
temp_priority.append(1)
elif map_array[i][j] == 'b4':
field_array_2.append(BEND4)
temp_priority.append(1)
elif map_array[i][j] == 'b1':
field_array_2.append(BEND1)
temp_priority.append(1)
elif map_array[i][j] == 'sh':
field_array_2.append(STRAIGHT_VERTICAL)
temp_priority.append(1)
elif map_array[i][j] == 'sv':
field_array_2.append(STRAIGHT_HORIZONTAL)
temp_priority.append(1)
elif map_array[i][j] == 'i':
field_array_2.append(INTERSECTION)
temp_priority.append(1)
elif map_array[i][j] == 'je':
field_array_2.append(JUNCTION_EAST)
temp_priority.append(1)
elif map_array[i][j] == 'jw':
field_array_2.append(JUNCTION_WEST)
temp_priority.append(1)
elif map_array[i][j] == 'js':
field_array_2.append(JUNCTION_SOUTH)
temp_priority.append(1)
elif map_array[i][j] == 'jn':
field_array_2.append(JUNCTION_NORTH)
temp_priority.append(1)
elif map_array[i][j] == 'gr':
field_array_2.append(BASE)
temp_priority.append(1000)
else:
prob = random.uniform(0, 100)
if 0 <= prob <= 20:
garbage_type = random.choice(['glass', 'mixed', 'paper', 'plastic'])
garbage_image_number = random.randrange(1, 100)
GARBAGE_IMG = pg.image.load(
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
GARBAGE = pg.transform.scale(GARBAGE_IMG, (50, 50))
field_array_2.append(GARBAGE)
imgpath_array[i][j] = (
f"./model_training/test_dataset/{garbage_type}/{garbage_type} ({str(garbage_image_number)}).jpg")
if 0 <= prob <= 12:
field_array_2.append(COBBLE)
temp_priority.append(100)
request_list.append(Request(
i * 50, j * 50, # lokacja
random.randint(0, 3), # typ śmieci
random.random(), # objętość śmieci
random.randint(0, 30), # ostatni odbiór
random.randint(0, 1), # czy opłacone w terminie
random.random() * 10, # intensywność odoru
random.random() * 50 # waga śmieci
i*50,j*50, #lokacja
random.randint(0,3), #typ śmieci
random.random(), #objętość śmieci
random.randint(0,30), #ostatni odbiór
random.randint(0,1), #czy opłacone w terminie
random.random() * 10, #intensywność odoru
random.random() * 50 #waga śmieci
))
else:
field_array_2.append(BASE)
temp_priority.append(1000)
field_array_2.append(GRASS)
temp_priority.append(1)
field_array_1.append(field_array_2)
field_array_2 = []
field_priority.append(temp_priority)
return field_array_1, field_priority, request_list, imgpath_array
return field_array_1, field_priority, request_list

Binary file not shown.

View File

@ -1,177 +0,0 @@
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor, Resize, CenterCrop, Normalize
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
def main():
torch.manual_seed(42)
# input_size = 49152
# hidden_sizes = [64, 128]
# output_size = 10
classes = os.listdir('./train_dataset')
print(classes)
mean = [0.6908, 0.6612, 0.6218]
std = [0.1947, 0.1926, 0.2086]
training_dataset_path = './train_dataset'
training_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
train_dataset = torchvision.datasets.ImageFolder(root=training_dataset_path, transform=training_transforms)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
testing_dataset_path = './test_dataset'
testing_transforms = transforms.Compose([Resize((128,128)), ToTensor(), Normalize(torch.Tensor(mean), torch.Tensor(std))])
test_dataset = torchvision.datasets.ImageFolder(root=testing_dataset_path, transform=testing_transforms)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
# Mean and Standard Deviation approximations
def get_mean_and_std(loader):
mean = 0.
std = 0.
total_images_count = 0
for images, _ in loader:
image_count_in_a_batch = images.size(0)
#print(images.shape)
images = images.view(image_count_in_a_batch, images.size(1), -1)
#print(images.shape)
mean += images.mean(2).sum(0)
std += images.std(2).sum(0)
total_images_count += image_count_in_a_batch
mean /= total_images_count
std /= total_images_count
return mean, std
print(get_mean_and_std(train_loader))
# Show images with applied transformations
def show_transformed_images(dataset):
loader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True)
batch = next(iter(loader))
images, labels = batch
grid = torchvision.utils.make_grid(images, nrow=3)
plt.figure(figsize=(11,11))
plt.imshow(np.transpose(grid, (1,2,0)))
print('labels: ', labels)
plt.show()
show_transformed_images(train_dataset)
# Neural network training:
def set_device():
if torch.cuda.is_available():
dev = "cuda:0"
else:
dev = "cpu"
return torch.device(dev)
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epochs):
device = set_device()
best_acc = 0
for epoch in range(n_epochs):
print("Epoch number %d " % (epoch+1))
model.train()
running_loss = 0.0
running_correct = 0.0
total = 0
for data in train_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
# Back propagation
optimizer.zero_grad()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_correct += (labels==predicted).sum().item()
epoch_loss = running_loss/len(train_loader)
epoch_acc = 100.00 * running_correct / total
print(" - Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))
test_dataset_acc = evaluate_model_on_test_set(model, test_loader)
if(test_dataset_acc > best_acc):
best_acc = test_dataset_acc
save_checkpoint(model, epoch, optimizer, best_acc)
print("Finished")
return model
def evaluate_model_on_test_set(model, test_loader):
model.eval()
predicted_correctly_on_epoch = 0
total = 0
device = set_device()
with torch.no_grad():
for data in test_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
predicted_correctly_on_epoch += (predicted == labels).sum().item()
epoch_acc = 100.0 * predicted_correctly_on_epoch / total
print(" - Testing dataset. Got %d out of %d images correctly (%.3f%%)" % (predicted_correctly_on_epoch, total, epoch_acc))
return epoch_acc
# Saving the checkpoint:
def save_checkpoint(model, epoch, optimizer, best_acc):
state = {
'epoch': epoch+1,
'model': model.state_dict(),
'best_accuracy': best_acc,
'optimizer': optimizer.state_dict(),
}
torch.save(state, 'model_best_checkpoint.zip')
resnet18_model = models.resnet18(pretrained=True) #Increase n_epochs if False
num_features = resnet18_model.fc.in_features
number_of_classes = 4
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
device = set_device()
resnet_18_model = resnet18_model.to(device)
loss_fn = nn.CrossEntropyLoss() #criterion
optimizer = optim.SGD(resnet_18_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)
train_nn(resnet_18_model, train_loader, test_loader, loss_fn, optimizer, 5)
# Saving the model:
checkpoint = torch.load('model_best_checkpoint.pth.zip')
resnet18_model = models.resnet18()
num_features = resnet18_model.fc.in_features
number_of_classes = 4
resnet18_model.fc = nn.Linear(num_features, number_of_classes)
resnet18_model.load_state_dict(checkpoint['model'])
torch.save(resnet18_model, 'garbage_model.pth')
if __name__ == "__main__":
main()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 162 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

Some files were not shown because too many files have changed in this diff Show More