diff --git a/movement.py b/movement.py index e726f90..301a702 100644 --- a/movement.py +++ b/movement.py @@ -2,8 +2,9 @@ import joblib from sklearn.calibration import LabelEncoder from agentActionType import AgentActionType import time -from garbage import GarbageType, RecognizedGarbage +from garbage import Garbage, GarbageType, RecognizedGarbage from garbageCan import GarbageCan +from machine_learning.neuron_network import Net from turnCar import turn_left_orientation, turn_right_orientation from garbageTruck import GarbageTruck from typing import Tuple, Dict @@ -13,6 +14,9 @@ from agentOrientation import AgentOrientation import pygame from bfs import find_path_to_nearest_can from agentState import AgentState +import torch +import torchvision.transforms as transforms +from PIL import Image def collect_garbage(game_context: GameContext) -> None: @@ -31,10 +35,18 @@ def collect_garbage(game_context: GameContext) -> None: def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None: loaded_model = joblib.load('machine_learning/model.pkl') + + checkpoint = torch.load('machine_learning/model.pt') + if 'module' in list(checkpoint['model_state_dict'].keys())[0]: + checkpoint = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()} + else: + checkpoint = checkpoint['model_state_dict'] + neuron_model = Net() + neuron_model.load_state_dict(checkpoint) + neuron_model.eval() + for garbage in can.garbage: - attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din] - encoded = attributes_to_floats(attributes) - predicted_class = loaded_model.predict([encoded])[0] + predicted_class = _recognize_by_image(garbage, neuron_model) if garbage.img is not None else _recognize_by_attributes(garbage, loaded_model) garbage_type: GarbageType = None if predicted_class == 'PAPER': garbage_type = GarbageType.PAPER @@ -50,6 +62,40 @@ def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None: recognized_garbage = RecognizedGarbage(garbage, garbage_type) dust_car.sort_garbage(recognized_garbage) +def _recognize_by_image(garbage: Garbage, model: Net) -> str: + transform = transforms.Compose([ + transforms.Resize((64, 64)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + image = Image.open(garbage.img) + image = transform(image) + with torch.no_grad(): + output = model(image.unsqueeze(0)) + _, predicted = torch.max(output.data, 1) + return _convert_image_prediction(predicted) + +def _convert_image_prediction(prediction: torch.Tensor) -> str: + item = prediction.item() + if item == 0: + return 'BIO' + if item == 1: + return 'GLASS' + if item == 2: + return 'MIXED' + if item == 3: + return 'PAPER' + if item == 4: + return "PLASTIC_AND_METAL" + print(type(prediction)) + return None + +def _recognize_by_attributes(garbage: Garbage, model) -> str: + attributes = [garbage.shape, garbage.flexibility, garbage.does_smell, garbage.weight, garbage.size, garbage.color, garbage.softness, garbage.does_din] + encoded = attributes_to_floats(attributes) + return model.predict([encoded])[0] + + def attributes_to_floats(attributes: list[str]) -> list[float]: output: list[float] = [] if attributes[0] == 'Longitiudonal': diff --git a/startup.py b/startup.py index 6cc7222..1bd9a86 100644 --- a/startup.py +++ b/startup.py @@ -36,30 +36,46 @@ def create_city() -> City: streets = create_streets() trashcans = create_trashcans() bumps = create_speed_bumps() - garbage_pieces = create_garbage_pieces() + garbage_pieces = _craete_garbage_with_attributes() garbage_pieces_counter = 0 for s in streets: city.add_street(s) for t in trashcans: - for i in range(4): + for _ in range(4): t.add_garbage(garbage_pieces[garbage_pieces_counter]) garbage_pieces_counter = garbage_pieces_counter + 1 city.add_can(t) + garbage_pieces = _create_garbage_with_images() + garbage_pieces_counter = 0 + for t in trashcans: + for _ in range(4): + t.add_garbage(garbage_pieces[garbage_pieces_counter]) + garbage_pieces_counter = garbage_pieces_counter + 1 for b in bumps: city.add_bump(b) return city - -def create_garbage_pieces() -> List[Garbage]: +def _craete_garbage_with_attributes() -> list[Garbage]: garbage_pieces = [] with open('machine_learning/garbage_infill.csv', 'r') as file: lines = file.readlines() for line in lines[1:]: param = line.strip().split(',') garbage_pieces.append( - Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip())) - return garbage_pieces - + Garbage(None, param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip())) + return garbage_pieces + +def _create_garbage_with_images() -> list[Garbage]: + garbage_pieces = [] + current_path_number = 3014 + for _ in range(0, 28): + path = 'machine_learning/garbage_photos/photos_not_from_train_set/IMG_' + str(current_path_number) + '.jpg' + new_garbage = Garbage(path, None, None, None, None, None, None, None, None) + garbage_pieces.append(new_garbage) + current_path_number = current_path_number + 1 + if current_path_number == 3025: + current_path_number = current_path_number + 1 + return garbage_pieces def create_streets() -> List[Street]: streets = []