recognize garbage by image

This commit is contained in:
Pawel Felcyn 2023-06-05 10:17:28 +02:00
parent e4cd4e727c
commit 10f62ce529
2 changed files with 73 additions and 11 deletions

View File

@ -2,8 +2,9 @@ import joblib
from sklearn.calibration import LabelEncoder
from agentActionType import AgentActionType
import time
from garbage import GarbageType, RecognizedGarbage
from garbage import Garbage, GarbageType, RecognizedGarbage
from garbageCan import GarbageCan
from machine_learning.neuron_network import Net
from turnCar import turn_left_orientation, turn_right_orientation
from garbageTruck import GarbageTruck
from typing import Tuple, Dict
@ -13,6 +14,9 @@ from agentOrientation import AgentOrientation
import pygame
from bfs import find_path_to_nearest_can
from agentState import AgentState
import torch
import torchvision.transforms as transforms
from PIL import Image
def collect_garbage(game_context: GameContext) -> None:
@ -31,10 +35,18 @@ def collect_garbage(game_context: GameContext) -> None:
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
loaded_model = joblib.load('machine_learning/model.pkl')
checkpoint = torch.load('machine_learning/model.pt')
if 'module' in list(checkpoint['model_state_dict'].keys())[0]:
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
else:
checkpoint = checkpoint['model_state_dict']
neuron_model = Net()
neuron_model.load_state_dict(checkpoint)
neuron_model.eval()
for garbage in can.garbage:
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
encoded = attributes_to_floats(attributes)
predicted_class = loaded_model.predict([encoded])[0]
predicted_class = _recognize_by_image(garbage, neuron_model) if garbage.img is not None else _recognize_by_attributes(garbage, loaded_model)
garbage_type: GarbageType = None
if predicted_class == 'PAPER':
garbage_type = GarbageType.PAPER
@ -50,6 +62,40 @@ def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
recognized_garbage = RecognizedGarbage(garbage, garbage_type)
dust_car.sort_garbage(recognized_garbage)
def _recognize_by_image(garbage: Garbage, model: Net) -> str:
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
image = Image.open(garbage.img)
image = transform(image)
with torch.no_grad():
output = model(image.unsqueeze(0))
_, predicted = torch.max(output.data, 1)
return _convert_image_prediction(predicted)
def _convert_image_prediction(prediction: torch.Tensor) -> str:
item = prediction.item()
if item == 0:
return 'BIO'
if item == 1:
return 'GLASS'
if item == 2:
return 'MIXED'
if item == 3:
return 'PAPER'
if item == 4:
return "PLASTIC_AND_METAL"
print(type(prediction))
return None
def _recognize_by_attributes(garbage: Garbage, model) -> str:
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
encoded = attributes_to_floats(attributes)
return model.predict([encoded])[0]
def attributes_to_floats(attributes: list[str]) -> list[float]:
output: list[float] = []
if attributes[0] == 'Longitiudonal':

View File

@ -36,30 +36,46 @@ def create_city() -> City:
streets = create_streets()
trashcans = create_trashcans()
bumps = create_speed_bumps()
garbage_pieces = create_garbage_pieces()
garbage_pieces = _craete_garbage_with_attributes()
garbage_pieces_counter = 0
for s in streets:
city.add_street(s)
for t in trashcans:
for i in range(4):
for _ in range(4):
t.add_garbage(garbage_pieces[garbage_pieces_counter])
garbage_pieces_counter = garbage_pieces_counter + 1
city.add_can(t)
garbage_pieces = _create_garbage_with_images()
garbage_pieces_counter = 0
for t in trashcans:
for _ in range(4):
t.add_garbage(garbage_pieces[garbage_pieces_counter])
garbage_pieces_counter = garbage_pieces_counter + 1
for b in bumps:
city.add_bump(b)
return city
def create_garbage_pieces() -> List[Garbage]:
def _craete_garbage_with_attributes() -> list[Garbage]:
garbage_pieces = []
with open('machine_learning/garbage_infill.csv', 'r') as file:
lines = file.readlines()
for line in lines[1:]:
param = line.strip().split(',')
garbage_pieces.append(
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
return garbage_pieces
Garbage(None, param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
return garbage_pieces
def _create_garbage_with_images() -> list[Garbage]:
garbage_pieces = []
current_path_number = 3014
for _ in range(0, 28):
path = 'machine_learning/garbage_photos/photos_not_from_train_set/IMG_' + str(current_path_number) + '.jpg'
new_garbage = Garbage(path, None, None, None, None, None, None, None, None)
garbage_pieces.append(new_garbage)
current_path_number = current_path_number + 1
if current_path_number == 3025:
current_path_number = current_path_number + 1
return garbage_pieces
def create_streets() -> List[Street]:
streets = []