recognize garbage by image

This commit is contained in:
Pawel Felcyn 2023-06-05 10:17:28 +02:00
parent e4cd4e727c
commit 10f62ce529
2 changed files with 73 additions and 11 deletions

View File

@ -2,8 +2,9 @@ import joblib
from sklearn.calibration import LabelEncoder from sklearn.calibration import LabelEncoder
from agentActionType import AgentActionType from agentActionType import AgentActionType
import time import time
from garbage import GarbageType, RecognizedGarbage from garbage import Garbage, GarbageType, RecognizedGarbage
from garbageCan import GarbageCan from garbageCan import GarbageCan
from machine_learning.neuron_network import Net
from turnCar import turn_left_orientation, turn_right_orientation from turnCar import turn_left_orientation, turn_right_orientation
from garbageTruck import GarbageTruck from garbageTruck import GarbageTruck
from typing import Tuple, Dict from typing import Tuple, Dict
@ -13,6 +14,9 @@ from agentOrientation import AgentOrientation
import pygame import pygame
from bfs import find_path_to_nearest_can from bfs import find_path_to_nearest_can
from agentState import AgentState from agentState import AgentState
import torch
import torchvision.transforms as transforms
from PIL import Image
def collect_garbage(game_context: GameContext) -> None: def collect_garbage(game_context: GameContext) -> None:
@ -31,10 +35,18 @@ def collect_garbage(game_context: GameContext) -> None:
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None: def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
loaded_model = joblib.load('machine_learning/model.pkl') loaded_model = joblib.load('machine_learning/model.pkl')
checkpoint = torch.load('machine_learning/model.pt')
if 'module' in list(checkpoint['model_state_dict'].keys())[0]:
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
else:
checkpoint = checkpoint['model_state_dict']
neuron_model = Net()
neuron_model.load_state_dict(checkpoint)
neuron_model.eval()
for garbage in can.garbage: for garbage in can.garbage:
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din] predicted_class = _recognize_by_image(garbage, neuron_model) if garbage.img is not None else _recognize_by_attributes(garbage, loaded_model)
encoded = attributes_to_floats(attributes)
predicted_class = loaded_model.predict([encoded])[0]
garbage_type: GarbageType = None garbage_type: GarbageType = None
if predicted_class == 'PAPER': if predicted_class == 'PAPER':
garbage_type = GarbageType.PAPER garbage_type = GarbageType.PAPER
@ -50,6 +62,40 @@ def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
recognized_garbage = RecognizedGarbage(garbage, garbage_type) recognized_garbage = RecognizedGarbage(garbage, garbage_type)
dust_car.sort_garbage(recognized_garbage) dust_car.sort_garbage(recognized_garbage)
def _recognize_by_image(garbage: Garbage, model: Net) -> str:
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
image = Image.open(garbage.img)
image = transform(image)
with torch.no_grad():
output = model(image.unsqueeze(0))
_, predicted = torch.max(output.data, 1)
return _convert_image_prediction(predicted)
def _convert_image_prediction(prediction: torch.Tensor) -> str:
item = prediction.item()
if item == 0:
return 'BIO'
if item == 1:
return 'GLASS'
if item == 2:
return 'MIXED'
if item == 3:
return 'PAPER'
if item == 4:
return "PLASTIC_AND_METAL"
print(type(prediction))
return None
def _recognize_by_attributes(garbage: Garbage, model) -> str:
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
encoded = attributes_to_floats(attributes)
return model.predict([encoded])[0]
def attributes_to_floats(attributes: list[str]) -> list[float]: def attributes_to_floats(attributes: list[str]) -> list[float]:
output: list[float] = [] output: list[float] = []
if attributes[0] == 'Longitiudonal': if attributes[0] == 'Longitiudonal':

View File

@ -36,30 +36,46 @@ def create_city() -> City:
streets = create_streets() streets = create_streets()
trashcans = create_trashcans() trashcans = create_trashcans()
bumps = create_speed_bumps() bumps = create_speed_bumps()
garbage_pieces = create_garbage_pieces() garbage_pieces = _craete_garbage_with_attributes()
garbage_pieces_counter = 0 garbage_pieces_counter = 0
for s in streets: for s in streets:
city.add_street(s) city.add_street(s)
for t in trashcans: for t in trashcans:
for i in range(4): for _ in range(4):
t.add_garbage(garbage_pieces[garbage_pieces_counter]) t.add_garbage(garbage_pieces[garbage_pieces_counter])
garbage_pieces_counter = garbage_pieces_counter + 1 garbage_pieces_counter = garbage_pieces_counter + 1
city.add_can(t) city.add_can(t)
garbage_pieces = _create_garbage_with_images()
garbage_pieces_counter = 0
for t in trashcans:
for _ in range(4):
t.add_garbage(garbage_pieces[garbage_pieces_counter])
garbage_pieces_counter = garbage_pieces_counter + 1
for b in bumps: for b in bumps:
city.add_bump(b) city.add_bump(b)
return city return city
def _craete_garbage_with_attributes() -> list[Garbage]:
def create_garbage_pieces() -> List[Garbage]:
garbage_pieces = [] garbage_pieces = []
with open('machine_learning/garbage_infill.csv', 'r') as file: with open('machine_learning/garbage_infill.csv', 'r') as file:
lines = file.readlines() lines = file.readlines()
for line in lines[1:]: for line in lines[1:]:
param = line.strip().split(',') param = line.strip().split(',')
garbage_pieces.append( garbage_pieces.append(
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip())) Garbage(None, param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
return garbage_pieces return garbage_pieces
def _create_garbage_with_images() -> list[Garbage]:
garbage_pieces = []
current_path_number = 3014
for _ in range(0, 28):
path = 'machine_learning/garbage_photos/photos_not_from_train_set/IMG_' + str(current_path_number) + '.jpg'
new_garbage = Garbage(path, None, None, None, None, None, None, None, None)
garbage_pieces.append(new_garbage)
current_path_number = current_path_number + 1
if current_path_number == 3025:
current_path_number = current_path_number + 1
return garbage_pieces
def create_streets() -> List[Street]: def create_streets() -> List[Street]:
streets = [] streets = []