recognize garbage by image
This commit is contained in:
parent
e4cd4e727c
commit
10f62ce529
54
movement.py
54
movement.py
@ -2,8 +2,9 @@ import joblib
|
||||
from sklearn.calibration import LabelEncoder
|
||||
from agentActionType import AgentActionType
|
||||
import time
|
||||
from garbage import GarbageType, RecognizedGarbage
|
||||
from garbage import Garbage, GarbageType, RecognizedGarbage
|
||||
from garbageCan import GarbageCan
|
||||
from machine_learning.neuron_network import Net
|
||||
from turnCar import turn_left_orientation, turn_right_orientation
|
||||
from garbageTruck import GarbageTruck
|
||||
from typing import Tuple, Dict
|
||||
@ -13,6 +14,9 @@ from agentOrientation import AgentOrientation
|
||||
import pygame
|
||||
from bfs import find_path_to_nearest_can
|
||||
from agentState import AgentState
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def collect_garbage(game_context: GameContext) -> None:
|
||||
@ -31,10 +35,18 @@ def collect_garbage(game_context: GameContext) -> None:
|
||||
|
||||
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
|
||||
loaded_model = joblib.load('machine_learning/model.pkl')
|
||||
|
||||
checkpoint = torch.load('machine_learning/model.pt')
|
||||
if 'module' in list(checkpoint['model_state_dict'].keys())[0]:
|
||||
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
|
||||
else:
|
||||
checkpoint = checkpoint['model_state_dict']
|
||||
neuron_model = Net()
|
||||
neuron_model.load_state_dict(checkpoint)
|
||||
neuron_model.eval()
|
||||
|
||||
for garbage in can.garbage:
|
||||
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
|
||||
encoded = attributes_to_floats(attributes)
|
||||
predicted_class = loaded_model.predict([encoded])[0]
|
||||
predicted_class = _recognize_by_image(garbage, neuron_model) if garbage.img is not None else _recognize_by_attributes(garbage, loaded_model)
|
||||
garbage_type: GarbageType = None
|
||||
if predicted_class == 'PAPER':
|
||||
garbage_type = GarbageType.PAPER
|
||||
@ -50,6 +62,40 @@ def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
|
||||
recognized_garbage = RecognizedGarbage(garbage, garbage_type)
|
||||
dust_car.sort_garbage(recognized_garbage)
|
||||
|
||||
def _recognize_by_image(garbage: Garbage, model: Net) -> str:
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
image = Image.open(garbage.img)
|
||||
image = transform(image)
|
||||
with torch.no_grad():
|
||||
output = model(image.unsqueeze(0))
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
return _convert_image_prediction(predicted)
|
||||
|
||||
def _convert_image_prediction(prediction: torch.Tensor) -> str:
|
||||
item = prediction.item()
|
||||
if item == 0:
|
||||
return 'BIO'
|
||||
if item == 1:
|
||||
return 'GLASS'
|
||||
if item == 2:
|
||||
return 'MIXED'
|
||||
if item == 3:
|
||||
return 'PAPER'
|
||||
if item == 4:
|
||||
return "PLASTIC_AND_METAL"
|
||||
print(type(prediction))
|
||||
return None
|
||||
|
||||
def _recognize_by_attributes(garbage: Garbage, model) -> str:
|
||||
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
|
||||
encoded = attributes_to_floats(attributes)
|
||||
return model.predict([encoded])[0]
|
||||
|
||||
|
||||
def attributes_to_floats(attributes: list[str]) -> list[float]:
|
||||
output: list[float] = []
|
||||
if attributes[0] == 'Longitiudonal':
|
||||
|
28
startup.py
28
startup.py
@ -36,30 +36,46 @@ def create_city() -> City:
|
||||
streets = create_streets()
|
||||
trashcans = create_trashcans()
|
||||
bumps = create_speed_bumps()
|
||||
garbage_pieces = create_garbage_pieces()
|
||||
garbage_pieces = _craete_garbage_with_attributes()
|
||||
garbage_pieces_counter = 0
|
||||
for s in streets:
|
||||
city.add_street(s)
|
||||
for t in trashcans:
|
||||
for i in range(4):
|
||||
for _ in range(4):
|
||||
t.add_garbage(garbage_pieces[garbage_pieces_counter])
|
||||
garbage_pieces_counter = garbage_pieces_counter + 1
|
||||
city.add_can(t)
|
||||
garbage_pieces = _create_garbage_with_images()
|
||||
garbage_pieces_counter = 0
|
||||
for t in trashcans:
|
||||
for _ in range(4):
|
||||
t.add_garbage(garbage_pieces[garbage_pieces_counter])
|
||||
garbage_pieces_counter = garbage_pieces_counter + 1
|
||||
for b in bumps:
|
||||
city.add_bump(b)
|
||||
return city
|
||||
|
||||
|
||||
def create_garbage_pieces() -> List[Garbage]:
|
||||
def _craete_garbage_with_attributes() -> list[Garbage]:
|
||||
garbage_pieces = []
|
||||
with open('machine_learning/garbage_infill.csv', 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines[1:]:
|
||||
param = line.strip().split(',')
|
||||
garbage_pieces.append(
|
||||
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
|
||||
return garbage_pieces
|
||||
Garbage(None, param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
|
||||
return garbage_pieces
|
||||
|
||||
def _create_garbage_with_images() -> list[Garbage]:
|
||||
garbage_pieces = []
|
||||
current_path_number = 3014
|
||||
for _ in range(0, 28):
|
||||
path = 'machine_learning/garbage_photos/photos_not_from_train_set/IMG_' + str(current_path_number) + '.jpg'
|
||||
new_garbage = Garbage(path, None, None, None, None, None, None, None, None)
|
||||
garbage_pieces.append(new_garbage)
|
||||
current_path_number = current_path_number + 1
|
||||
if current_path_number == 3025:
|
||||
current_path_number = current_path_number + 1
|
||||
return garbage_pieces
|
||||
|
||||
def create_streets() -> List[Street]:
|
||||
streets = []
|
||||
|
Loading…
Reference in New Issue
Block a user