recognize garbage by image
This commit is contained in:
parent
e4cd4e727c
commit
10f62ce529
54
movement.py
54
movement.py
@ -2,8 +2,9 @@ import joblib
|
|||||||
from sklearn.calibration import LabelEncoder
|
from sklearn.calibration import LabelEncoder
|
||||||
from agentActionType import AgentActionType
|
from agentActionType import AgentActionType
|
||||||
import time
|
import time
|
||||||
from garbage import GarbageType, RecognizedGarbage
|
from garbage import Garbage, GarbageType, RecognizedGarbage
|
||||||
from garbageCan import GarbageCan
|
from garbageCan import GarbageCan
|
||||||
|
from machine_learning.neuron_network import Net
|
||||||
from turnCar import turn_left_orientation, turn_right_orientation
|
from turnCar import turn_left_orientation, turn_right_orientation
|
||||||
from garbageTruck import GarbageTruck
|
from garbageTruck import GarbageTruck
|
||||||
from typing import Tuple, Dict
|
from typing import Tuple, Dict
|
||||||
@ -13,6 +14,9 @@ from agentOrientation import AgentOrientation
|
|||||||
import pygame
|
import pygame
|
||||||
from bfs import find_path_to_nearest_can
|
from bfs import find_path_to_nearest_can
|
||||||
from agentState import AgentState
|
from agentState import AgentState
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
def collect_garbage(game_context: GameContext) -> None:
|
def collect_garbage(game_context: GameContext) -> None:
|
||||||
@ -31,10 +35,18 @@ def collect_garbage(game_context: GameContext) -> None:
|
|||||||
|
|
||||||
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
|
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
|
||||||
loaded_model = joblib.load('machine_learning/model.pkl')
|
loaded_model = joblib.load('machine_learning/model.pkl')
|
||||||
|
|
||||||
|
checkpoint = torch.load('machine_learning/model.pt')
|
||||||
|
if 'module' in list(checkpoint['model_state_dict'].keys())[0]:
|
||||||
|
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
|
||||||
|
else:
|
||||||
|
checkpoint = checkpoint['model_state_dict']
|
||||||
|
neuron_model = Net()
|
||||||
|
neuron_model.load_state_dict(checkpoint)
|
||||||
|
neuron_model.eval()
|
||||||
|
|
||||||
for garbage in can.garbage:
|
for garbage in can.garbage:
|
||||||
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
|
predicted_class = _recognize_by_image(garbage, neuron_model) if garbage.img is not None else _recognize_by_attributes(garbage, loaded_model)
|
||||||
encoded = attributes_to_floats(attributes)
|
|
||||||
predicted_class = loaded_model.predict([encoded])[0]
|
|
||||||
garbage_type: GarbageType = None
|
garbage_type: GarbageType = None
|
||||||
if predicted_class == 'PAPER':
|
if predicted_class == 'PAPER':
|
||||||
garbage_type = GarbageType.PAPER
|
garbage_type = GarbageType.PAPER
|
||||||
@ -50,6 +62,40 @@ def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
|
|||||||
recognized_garbage = RecognizedGarbage(garbage, garbage_type)
|
recognized_garbage = RecognizedGarbage(garbage, garbage_type)
|
||||||
dust_car.sort_garbage(recognized_garbage)
|
dust_car.sort_garbage(recognized_garbage)
|
||||||
|
|
||||||
|
def _recognize_by_image(garbage: Garbage, model: Net) -> str:
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((64, 64)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
|
])
|
||||||
|
image = Image.open(garbage.img)
|
||||||
|
image = transform(image)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(image.unsqueeze(0))
|
||||||
|
_, predicted = torch.max(output.data, 1)
|
||||||
|
return _convert_image_prediction(predicted)
|
||||||
|
|
||||||
|
def _convert_image_prediction(prediction: torch.Tensor) -> str:
|
||||||
|
item = prediction.item()
|
||||||
|
if item == 0:
|
||||||
|
return 'BIO'
|
||||||
|
if item == 1:
|
||||||
|
return 'GLASS'
|
||||||
|
if item == 2:
|
||||||
|
return 'MIXED'
|
||||||
|
if item == 3:
|
||||||
|
return 'PAPER'
|
||||||
|
if item == 4:
|
||||||
|
return "PLASTIC_AND_METAL"
|
||||||
|
print(type(prediction))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _recognize_by_attributes(garbage: Garbage, model) -> str:
|
||||||
|
attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din]
|
||||||
|
encoded = attributes_to_floats(attributes)
|
||||||
|
return model.predict([encoded])[0]
|
||||||
|
|
||||||
|
|
||||||
def attributes_to_floats(attributes: list[str]) -> list[float]:
|
def attributes_to_floats(attributes: list[str]) -> list[float]:
|
||||||
output: list[float] = []
|
output: list[float] = []
|
||||||
if attributes[0] == 'Longitiudonal':
|
if attributes[0] == 'Longitiudonal':
|
||||||
|
28
startup.py
28
startup.py
@ -36,30 +36,46 @@ def create_city() -> City:
|
|||||||
streets = create_streets()
|
streets = create_streets()
|
||||||
trashcans = create_trashcans()
|
trashcans = create_trashcans()
|
||||||
bumps = create_speed_bumps()
|
bumps = create_speed_bumps()
|
||||||
garbage_pieces = create_garbage_pieces()
|
garbage_pieces = _craete_garbage_with_attributes()
|
||||||
garbage_pieces_counter = 0
|
garbage_pieces_counter = 0
|
||||||
for s in streets:
|
for s in streets:
|
||||||
city.add_street(s)
|
city.add_street(s)
|
||||||
for t in trashcans:
|
for t in trashcans:
|
||||||
for i in range(4):
|
for _ in range(4):
|
||||||
t.add_garbage(garbage_pieces[garbage_pieces_counter])
|
t.add_garbage(garbage_pieces[garbage_pieces_counter])
|
||||||
garbage_pieces_counter = garbage_pieces_counter + 1
|
garbage_pieces_counter = garbage_pieces_counter + 1
|
||||||
city.add_can(t)
|
city.add_can(t)
|
||||||
|
garbage_pieces = _create_garbage_with_images()
|
||||||
|
garbage_pieces_counter = 0
|
||||||
|
for t in trashcans:
|
||||||
|
for _ in range(4):
|
||||||
|
t.add_garbage(garbage_pieces[garbage_pieces_counter])
|
||||||
|
garbage_pieces_counter = garbage_pieces_counter + 1
|
||||||
for b in bumps:
|
for b in bumps:
|
||||||
city.add_bump(b)
|
city.add_bump(b)
|
||||||
return city
|
return city
|
||||||
|
|
||||||
|
def _craete_garbage_with_attributes() -> list[Garbage]:
|
||||||
def create_garbage_pieces() -> List[Garbage]:
|
|
||||||
garbage_pieces = []
|
garbage_pieces = []
|
||||||
with open('machine_learning/garbage_infill.csv', 'r') as file:
|
with open('machine_learning/garbage_infill.csv', 'r') as file:
|
||||||
lines = file.readlines()
|
lines = file.readlines()
|
||||||
for line in lines[1:]:
|
for line in lines[1:]:
|
||||||
param = line.strip().split(',')
|
param = line.strip().split(',')
|
||||||
garbage_pieces.append(
|
garbage_pieces.append(
|
||||||
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
|
Garbage(None, param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
|
||||||
return garbage_pieces
|
return garbage_pieces
|
||||||
|
|
||||||
|
def _create_garbage_with_images() -> list[Garbage]:
|
||||||
|
garbage_pieces = []
|
||||||
|
current_path_number = 3014
|
||||||
|
for _ in range(0, 28):
|
||||||
|
path = 'machine_learning/garbage_photos/photos_not_from_train_set/IMG_' + str(current_path_number) + '.jpg'
|
||||||
|
new_garbage = Garbage(path, None, None, None, None, None, None, None, None)
|
||||||
|
garbage_pieces.append(new_garbage)
|
||||||
|
current_path_number = current_path_number + 1
|
||||||
|
if current_path_number == 3025:
|
||||||
|
current_path_number = current_path_number + 1
|
||||||
|
return garbage_pieces
|
||||||
|
|
||||||
def create_streets() -> List[Street]:
|
def create_streets() -> List[Street]:
|
||||||
streets = []
|
streets = []
|
||||||
|
Loading…
Reference in New Issue
Block a user