sztuczna_inteligencja_2023_.../movement.py
2023-06-05 10:17:28 +02:00

198 lines
7.2 KiB
Python

import joblib
from sklearn.calibration import LabelEncoder
from agentActionType import AgentActionType
import time
from garbage import Garbage, GarbageType, RecognizedGarbage
from garbageCan import GarbageCan
from machine_learning.neuron_network import Net
from turnCar import turn_left_orientation, turn_right_orientation
from garbageTruck import GarbageTruck
from typing import Tuple, Dict
from gridCellType import GridCellType
from gameContext import GameContext
from agentOrientation import AgentOrientation
import pygame
from bfs import find_path_to_nearest_can
from agentState import AgentState
import torch
import torchvision.transforms as transforms
from PIL import Image
def collect_garbage(game_context: GameContext) -> None:
    """Drive the dust car from can to can until no path to a can is found.

    Each round asks ``find_path_to_nearest_can`` for a list of actions from
    the car's current position and orientation, replays it with
    ``move_dust_car``, and then treats the cell directly in front of the car
    as the can: the can is flagged as visited, its grid cell is set to
    ``GridCellType.VISITED_GARBAGE_CAN`` and its contents are recognized and
    sorted into the car.

    Args:
        game_context: Shared game state; its ``dust_car``, ``grid`` and
            ``city`` are read and updated in place.
    """
    while True:
        start_agent_state = AgentState(
            game_context.dust_car.position,
            game_context.dust_car.orientation,
        )
        path = find_path_to_nearest_can(
            start_agent_state, game_context.grid, game_context.city
        )
        # A missing (None) or empty path means there is nothing left to do.
        if not path:
            break
        move_dust_car(path, game_context)
        next_position = calculate_next_position(game_context.dust_car)
        # Look the can up before touching the grid, so that a position which
        # is not a known can fails without leaving the grid half-updated.
        can = game_context.city.cans_dict[next_position]
        game_context.grid[next_position] = GridCellType.VISITED_GARBAGE_CAN
        can.is_visited = True
        _recognize_garbage(game_context.dust_car, can)
def _recognize_garbage(dust_car: GarbageTruck, can: GarbageCan) -> None:
    """Classify every item in ``can`` and pass it to ``dust_car`` for sorting.

    Items whose ``img`` is not None are classified by the neural network
    restored from ``machine_learning/model.pt``; all other items are
    classified from their attributes by the model stored in
    ``machine_learning/model.pkl``. The predicted label is translated to a
    ``GarbageType``; a label outside the five known ones leaves the type as
    None.

    Args:
        dust_car: Truck that receives each ``RecognizedGarbage``.
        can: Can whose ``garbage`` items are classified.
    """
    # NOTE(review): joblib.load and torch.load both unpickle their input;
    # only ever point them at trusted model files.
    loaded_model = joblib.load('machine_learning/model.pkl')
    # The network below is fed CPU tensors, so map the weights to the CPU;
    # this also lets a checkpoint saved on a GPU machine load without CUDA.
    checkpoint = torch.load('machine_learning/model.pt', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    # Checkpoints whose first key contains 'module' carry a 'module.' prefix
    # on their keys, which Net() does not expect; strip it.
    if 'module' in list(state_dict.keys())[0]:
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    neuron_model = Net()
    neuron_model.load_state_dict(state_dict)
    neuron_model.eval()

    label_to_type = {
        'PAPER': GarbageType.PAPER,
        'PLASTIC_AND_METAL': GarbageType.PLASTIC_AND_METAL,
        'GLASS': GarbageType.GLASS,
        'BIO': GarbageType.BIO,
        'MIXED': GarbageType.MIXED,
    }

    for garbage in can.garbage:
        if garbage.img is not None:
            predicted_class = _recognize_by_image(garbage, neuron_model)
        else:
            predicted_class = _recognize_by_attributes(garbage, loaded_model)
        # Unknown (or None) labels map to None, i.e. "type not recognized".
        garbage_type: GarbageType = label_to_type.get(predicted_class)
        print(predicted_class)
        recognized_garbage = RecognizedGarbage(garbage, garbage_type)
        dust_car.sort_garbage(recognized_garbage)
def _recognize_by_image(garbage: Garbage, model: Net) -> str:
    """Classify the image of ``garbage`` with ``model``.

    The image is resized to 64x64, converted to a tensor and normalized with
    mean and standard deviation 0.5 per channel, then run through ``model``
    as a batch of one with gradients disabled.

    Args:
        garbage: Item whose ``img`` is passed to ``PIL.Image.open``.
        model: Network called on the preprocessed image.

    Returns:
        The label produced by ``_convert_image_prediction`` for the
        highest-scoring class index.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Image.open is lazy and keeps the file open; the context manager releases
    # it once the pixels have been turned into a tensor. convert('RGB')
    # guarantees the three channels Normalize is configured for, even when
    # the file is grayscale, palette-based or has an alpha channel.
    with Image.open(garbage.img) as raw_image:
        image = transform(raw_image.convert('RGB'))
    with torch.no_grad():
        output = model(image.unsqueeze(0))
        _, predicted = torch.max(output.data, 1)
    return _convert_image_prediction(predicted)
def _convert_image_prediction(prediction: torch.Tensor) -> str:
item = prediction.item()
if item == 0:
return 'BIO'
if item == 1:
return 'GLASS'
if item == 2:
return 'MIXED'
if item == 3:
return 'PAPER'
if item == 4:
return "PLASTIC_AND_METAL"
print(type(prediction))
return None
def _recognize_by_attributes(garbage: Garbage, model) -> str:
    """Classify ``garbage`` from its descriptive attributes.

    The eight attributes are encoded with ``attributes_to_floats`` and passed
    to ``model.predict`` as a single-row batch.

    Args:
        garbage: Item providing shape, flexibility, does_smell, weight, size,
            color, softness and does_din.
        model: Object exposing ``predict``.

    Returns:
        The first element of the prediction for the single row.
    """
    feature_row = attributes_to_floats([
        garbage.shape,
        garbage.flexibility,
        garbage.does_smell,
        garbage.weight,
        garbage.size,
        garbage.color,
        garbage.softness,
        garbage.does_din,
    ])
    predictions = model.predict([feature_row])
    return predictions[0]
_LEVEL_CODES = {'Low': 0, 'Medium': 1, 'High': 2}

# One entry per feature, in the positional order of the attribute list:
# (feature name, value -> code table). A table of None marks a yes/no
# feature, where "Yes" encodes to 0 and anything else to 1.
# NOTE(review): 'Longitiudonal' is spelled as in the original code;
# presumably it matches the data set - confirm before correcting it.
_ATTRIBUTE_CODES = (
    ('shape', {'Longitiudonal': 0, 'Round': 1, 'Flat': 2, 'Irregular': 3}),
    ('flexibility', _LEVEL_CODES),
    ('does_smell', None),
    ('weight', _LEVEL_CODES),
    ('size', _LEVEL_CODES),
    ('color', {'Transparent': 0, 'Light': 1, 'Dark': 2, 'Colorful': 3}),
    ('softness', _LEVEL_CODES),
    ('does_din', None),
)


def attributes_to_floats(attributes: list[str]) -> list[float]:
    """Encode the eight categorical garbage attributes as numeric codes.

    Args:
        attributes: Values in the order shape, flexibility, does_smell,
            weight, size, color, softness, does_din. Elements beyond the
            eighth are ignored.

    Returns:
        A list of eight codes, one per attribute, in the same order.

    Raises:
        ValueError: If shape, flexibility, weight, size, color or softness
            holds a value with no known code. Skipping such a value would
            yield a shorter vector whose remaining features are shifted.
        IndexError: If fewer than eight attributes are supplied.
    """
    output: list[float] = []
    for index, (name, codes) in enumerate(_ATTRIBUTE_CODES):
        value = attributes[index]
        if codes is None:
            output.append(0 if value == "Yes" else 1)
        elif value in codes:
            output.append(codes[value])
        else:
            raise ValueError(
                f"unknown value {value!r} for attribute '{name}' "
                f"(expected one of {sorted(codes)})"
            )
    return output
def move_dust_car(actions: list[AgentActionType], game_context: GameContext) -> None:
    """Replay ``actions`` on the dust car, redrawing after every step.

    Turns update the car's orientation and a forward move updates its
    position. After each action the car is rendered; when the car moved, the
    cell it left is redrawn with the matching street image if that cell is a
    horizontal street, vertical street or speed bump. The display is then
    updated and execution pauses for 0.15 seconds.

    Args:
        actions: Actions to perform, in order.
        game_context: Game state holding the car and the grid, and used for
            rendering.
    """
    street_images = (
        (GridCellType.STREET_HORIZONTAL, "imgs/street_horizontal.png"),
        (GridCellType.STREET_VERTICAL, "imgs/street_vertical.png"),
        (GridCellType.SPEED_BUMP, "imgs/speed_bump.png"),
    )

    for step in actions:
        truck = game_context.dust_car
        # Remember where the car stood before this step was applied.
        previous_cell = truck.position
        left_cell = False

        if step == AgentActionType.TURN_LEFT:
            truck.orientation = turn_left_orientation(truck.orientation)
        elif step == AgentActionType.TURN_RIGHT:
            truck.orientation = turn_right_orientation(truck.orientation)
        elif step == AgentActionType.MOVE_FORWARD:
            truck.position = calculate_next_position(truck)
            left_cell = True

        truck.render(game_context)

        if left_cell:
            # Repaint the vacated cell with the first matching street image.
            cell_type = game_context.grid[previous_cell]
            for street_type, image_path in street_images:
                if cell_type == street_type:
                    game_context.render_in_cell(previous_cell, image_path)
                    break

        pygame.display.update()
        time.sleep(0.15)
def calculate_next_position(car: GarbageTruck) -> Tuple[int, int]:
    """Return the cell one step ahead of ``car``, or None at the border.

    UP and DOWN change ``position[1]`` by -1 and +1; LEFT changes
    ``position[0]`` by -1; any other orientation changes ``position[0]`` by
    +1. Only the coordinate that changes is checked: the step is rejected
    when it would fall below 1 (UP, LEFT) or rise above 27 (DOWN, otherwise).

    Args:
        car: Truck providing ``position`` and ``orientation``.

    Returns:
        The new ``(position[0], position[1])`` pair, or None when the step
        is rejected.
    """
    column, row = car.position[0], car.position[1]
    facing = car.orientation

    if facing == AgentOrientation.UP:
        target_row = row - 1
        return None if target_row < 1 else (column, target_row)

    if facing == AgentOrientation.DOWN:
        target_row = row + 1
        return None if target_row > 27 else (column, target_row)

    if facing == AgentOrientation.LEFT:
        target_column = column - 1
        return None if target_column < 1 else (target_column, row)

    # Every remaining orientation moves towards increasing position[0].
    target_column = column + 1
    return None if target_column > 27 else (target_column, row)