wozek/ai-wozek/wozek.py

498 lines
17 KiB
Python
Raw Normal View History

2024-06-17 04:58:21 +02:00
import pygame
import sys
import random
import os
import time
from collections import deque
import heapq
import torch
import classes
from classes import *
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import math
from PIL import Image
class Node:
    """A search-tree node used by the forklift path planners (bfs / astar).

    Attributes:
        position: [x, y] tile of the forklift in this pose.
        rotation: heading of the forklift, one of 'N', 'E', 'S', 'W'.
        action: action that produced this node ('FW', 'L', 'R'), or a marker
            such as 'start' / 'final' for hand-built nodes.
        parent: predecessor Node, or None for the root of the search.
        cost: step cost of the action that led here; nodes order by it.
    """

    def __init__(self, position, rotation, action, parent, cost):
        self.position = position
        self.rotation = rotation
        self.action = action
        self.parent = parent
        self.cost = cost

    def __lt__(self, other):
        return self.cost < other.cost

    def __le__(self, other):
        return self.cost <= other.cost

    def __repr__(self):
        # The parent is left out on purpose: printing it would recurse through
        # the whole chain back to the root and make path dumps unreadable.
        return (f"{type(self).__name__}(position={self.position!r}, "
                f"rotation={self.rotation!r}, action={self.action!r}, "
                f"cost={self.cost!r})")
# --- Pygame and window setup -------------------------------------------------
pygame.init()

# Board geometry: a GRID_WIDTH x GRID_HEIGHT board of square tiles.
TILE_SIZE = 49  # edge length of one tile, in pixels
GRID_WIDTH = 16  # number of tile columns
GRID_HEIGHT = 16  # number of tile rows
SCREEN_WIDTH = GRID_WIDTH * TILE_SIZE
SCREEN_HEIGHT = GRID_HEIGHT * TILE_SIZE
FPS = 60  # frame-rate cap handed to clock.tick() in the main loop

# Open the window; it is sized to fit the whole board exactly.
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption('Forklift Game')

# Frame limiter for the main loop.
clock = pygame.time.Clock()
# Memo of already-loaded sprites, keyed by (file name, scale); see load_image().
_IMAGE_CACHE = {}


def load_image(name, scale=None):
    """Load an image file as a pygame Surface, optionally scaled.

    The draw functions call this on every frame for each freight sprite, so
    results are memoised per (name, scale). This avoids re-reading and
    re-converting the file from disk each frame. The returned Surface is
    shared between callers and must be treated as read-only (blit it or
    scale it into a new Surface, never draw onto it).

    Args:
        name: path of the image file.
        scale: optional (width, height) to scale the image to; a falsy value
            keeps the image's native size.

    Returns:
        The (possibly scaled) Surface, converted with convert_alpha().
    """
    key = (name, tuple(scale) if scale else None)
    image = _IMAGE_CACHE.get(key)
    if image is None:
        # convert_alpha() needs the video mode to be set already.
        image = pygame.image.load(name).convert_alpha()
        if scale:
            image = pygame.transform.scale(image, scale)
        _IMAGE_CACHE[key] = image
    return image
# Sprite placeholders; the Surfaces are created by load_images() once the
# video mode has been set.
forklift_image_full = None
freight_images_full = None

# --- Game state ---------------------------------------------------------------
forklift_pos = [0, 0]  # starting tile of the forklift, as [x, y]
rotation = 'N'  # current heading: one of 'N', 'E', 'S', 'W'
carrying_freight = False
carried_freight = None
current_freight = set()  # truck-bed tiles (x, y) queued as pickup targets
freight_content = []  # every Cargo created by reset_truck_bed_freight()
freight_positions = {}  # (x, y) tile -> Cargo currently lying on that tile

# Per-tile movement cost for the path planner; unlisted tiles cost 1.
tile_cost = {
    (8, 0): 10,
    (7, 1): 10,
    (6, 1): 3,
    (5, 0): 10,
}

# --- Freight image classifier (TorchScript) ----------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the saved tensors onto `device`, so a checkpoint that
# was saved on a GPU still loads on a CPU-only machine.
model = torch.jit.load('./siec/model.pt', map_location=device)
model.eval()
model.to(device)

# Preprocessing applied to freight images before they are fed to the model.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
# Load images
def load_images():
    """(Re)load the forklift sprite matching the current heading.

    Reads the global ``rotation`` and stores a tile-sized Surface in the global
    ``forklift_image_full``. An unrecognised heading leaves the sprite as is.
    """
    global forklift_image_full
    # NOTE(review): heading 'N' uses forkliftS.png and 'S' uses forkliftN.png.
    # Presumably the sprite files are named after the opposite facing; confirm
    # against the artwork before "fixing" the swap.
    sprite_for_heading = {
        'E': 'forkliftE.png',
        'W': 'forkliftW.png',
        'N': 'forkliftS.png',
        'S': 'forkliftN.png',
    }
    sprite_file = sprite_for_heading.get(rotation)
    if sprite_file is not None:
        forklift_image_full = load_image(sprite_file, (TILE_SIZE, TILE_SIZE))
# Initialize or reset game elements
def init_game():
    """Prepare a fresh round.

    Empties the placed-freight map, then spawns new truck-bed freight, draws
    it once and loads the forklift sprite for the current heading. It loads
    images, so it must run after the video mode has been set.
    """
    freight_positions.clear()
    setup_steps = (reset_truck_bed_freight, draw_freight, load_images)
    for step in setup_steps:
        step()
# Reset freight on the truck bed
def reset_truck_bed_freight():
    """Spawn one randomly generated Cargo on each truck-bed tile (12..15, 0).

    Every new cargo is registered in three places:
    ``current_freight`` (tiles queued for pickup), ``freight_content`` (list
    of created cargo) and ``freight_positions`` (tile -> cargo lying there).
    """
    sizes = ['small', 'medium', 'big']
    yes_no = ['yes', 'no']
    for column in range(12, 16):
        # Keep this draw order (height, width, depth, weight, damage, label,
        # content, value): reordering it changes what a seeded run produces.
        height = random.choice(sizes)
        width = random.choice(sizes)
        depth = random.choice(sizes)
        weight = random.choice(['light', 'medium', 'medium-heavy', 'heavy'])
        damage = random.choice(yes_no)
        label_state = random.choice(yes_no)
        content = random.choice(['fruits', 'clothes', 'car_parts', 'nuclear_waste'])
        value = random.choice(['cheap', 'expensive'])
        tile = (column, 0)
        cargo = classes.Cargo(height, width, depth, weight, damage,
                              label_state, content, value, [column, 0])
        cargo.contentSplit()
        current_freight.add(tile)
        freight_content.append(cargo)
        freight_positions[tile] = cargo
# Drawing functions
def draw_board():
    """Clear the screen to white and outline every grid tile in black."""
    white, black = (255, 255, 255), (0, 0, 0)
    screen.fill(white)
    for column in range(GRID_WIDTH):
        left = column * TILE_SIZE
        for row in range(GRID_HEIGHT):
            tile_rect = pygame.Rect(left, row * TILE_SIZE, TILE_SIZE, TILE_SIZE)
            pygame.draw.rect(screen, black, tile_rect, 1)  # width 1 -> outline only
# Coloured floor zones painted by draw_truck_bed_and_racks(), in paint order,
# as (RGB colour, x range, y range). Each covers the same tiles as one of the
# initial free_* drop-off pools.
_ZONE_LAYOUT = (
    ((165, 42, 42), range(0, 6), range(10, 16)),    # free_red tiles
    ((255, 0, 255), range(0, 6), range(4, 10)),     # free_pink tiles
    ((191, 255, 0), range(10, 16), range(10, 16)),  # free_lime tiles
    ((0, 255, 255), range(10, 16), range(4, 10)),   # free_cyan tiles
)


def _cost_color(cost):
    """Return the fill colour used for a tile with the given movement cost.

    The red channel grows by 10 per cost unit. It is clamped to 0..255
    because pygame rejects colour components outside that range, so a
    cost above 25 would otherwise crash the renderer.
    """
    return (int(max(0, min(255, 10 * cost))), 130, 100)


def draw_truck_bed_and_racks():
    """Paint the four coloured zones, then tint every tile listed in tile_cost.

    Cost tiles are painted last so they remain visible on top of any zone.
    """
    for colour, xs, ys in _ZONE_LAYOUT:
        for y in ys:
            for x in xs:
                pygame.draw.rect(screen, colour,
                                 (x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE))
    for (x, y), cost in tile_cost.items():
        pygame.draw.rect(screen, _cost_color(cost),
                         (x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE))
def draw_forklift_and_freight():
    """Blit the forklift onto its tile, together with its load if it has one.

    Empty-handed, the full-size forklift sprite fills the tile. While carrying,
    the tile is split: a half-size forklift in the bottom-left quarter and a
    half-size sprite of the carried freight in the top-right quarter.
    """
    col, row = forklift_pos
    left, top = col * TILE_SIZE, row * TILE_SIZE
    if not carrying_freight:
        screen.blit(forklift_image_full, (left, top))
        return
    half = TILE_SIZE // 2
    small_size = (half, half)
    freight_sprite = load_image(carried_freight.image, (TILE_SIZE, TILE_SIZE))
    screen.blit(pygame.transform.scale(forklift_image_full, small_size), (left, top + half))
    screen.blit(pygame.transform.scale(freight_sprite, small_size), (left + half, top))
def draw_freight():
    """Blit the sprite of every placed freight item onto its tile."""
    for (col, row), cargo in freight_positions.items():
        sprite = load_image(cargo.image, (TILE_SIZE, TILE_SIZE))
        screen.blit(sprite, (col * TILE_SIZE, row * TILE_SIZE))
# Game mechanics
def move_forklift():
    """Advance the forklift one tile in the direction it is facing.

    'N' is screen-up (y - 1), 'S' screen-down (y + 1), 'E' right (x + 1) and
    'W' left (x - 1). A move that would leave the grid is silently ignored.
    """
    global forklift_pos
    dx, dy = {'E': (1, 0), 'W': (-1, 0), 'N': (0, -1), 'S': (0, 1)}[rotation]
    target = [forklift_pos[0] + dx, forklift_pos[1] + dy]
    on_board = 0 <= target[0] < GRID_WIDTH and 0 <= target[1] < GRID_HEIGHT
    if on_board:
        forklift_pos = target
def rotate_forklift(x):
    """Turn the forklift a quarter turn in place.

    Args:
        x: 'L' turns counter-clockwise (N -> W -> S -> E), 'R' turns clockwise
            (N -> E -> S -> W). Any other value is stored as the new heading
            as-is.
    """
    global rotation
    compass = ['N', 'E', 'S', 'W']
    idx = compass.index(rotation)  # validates the current heading first
    if x == 'L':
        rotation = compass[(idx - 1) % len(compass)]
    elif x == 'R':
        rotation = compass[(idx + 1) % len(compass)]
    else:
        rotation = x
def _zone_tiles(xs, ys):
    """Return the set of (x, y) tiles for every x in *xs* and y in *ys* (x outer)."""
    return {(x, y) for x in xs for y in ys}


# Pools of rack tiles not yet handed out as drop-off targets, one per coloured
# zone. The main loop pops a tile from the pool that matches the freight class.
free_lime = _zone_tiles(range(10, 16), range(10, 16))
free_cyan = _zone_tiles(range(10, 16), range(4, 10))
free_pink = _zone_tiles(range(0, 6), range(4, 10))
free_red = _zone_tiles(range(0, 6), range(10, 16))
def handle_freight():
    """Pick up or put down freight at the forklift's current tile.

    Empty-handed: pick up whatever freight lies on this tile, if any.
    Carrying: drop the load here if the tile is free, otherwise keep holding it.
    Dropping does not update the free_* zone pools.
    """
    global carrying_freight, carried_freight
    here = tuple(forklift_pos)
    if not carrying_freight:
        if here in freight_positions:
            carried_freight = freight_positions.pop(here)
            carrying_freight = True
        return
    if here in freight_positions:
        return  # tile already occupied: keep the load on the forks
    freight_positions[here] = carried_freight
    carrying_freight = False
    carried_freight = None
# Successor generation for the path planners.
# For each heading: the candidate (action, resulting heading) pairs, in the
# exact order succ() emits them ('FW' keeps the heading, 'L'/'R' turn in place).
_SUCC_ACTIONS = {
    'N': (('FW', 'N'), ('L', 'W'), ('R', 'E')),
    'S': (('FW', 'S'), ('L', 'E'), ('R', 'W')),
    'E': (('FW', 'E'), ('R', 'S'), ('L', 'N')),
    'W': (('FW', 'W'), ('R', 'N'), ('L', 'S')),
}
# One-tile grid step for each heading ('N' is screen-up, i.e. y - 1).
_SUCC_STEP = {'N': (0, -1), 'S': (0, 1), 'E': (1, 0), 'W': (-1, 0)}


def succ(current_node):
    """Return the search nodes reachable from *current_node* with one action.

    Actions are 'FW' (drive one tile ahead) and 'L' / 'R' (quarter turn in
    place). 'FW' is only offered when the tile ahead exists on the board, and
    a turn only when the tile ahead in the new heading exists. The planner
    therefore never turns to face the board edge. Bounds come from
    GRID_WIDTH / GRID_HEIGHT, the same limits move_forklift() enforces.

    Every successor's ``cost`` is the movement cost of the tile the forklift
    is currently on (``tile_cost``, default 1).
    NOTE(review): this charges the tile being left, not the tile being
    entered, and charges turns at that tile's cost too. Confirm that this is
    the intended cost model.

    Returns:
        A list of new Node objects (possibly empty for an unknown heading).
    """
    x, y = current_node.position[0], current_node.position[1]
    step_cost = tile_cost.get((x, y), 1)
    successors = []
    for action, heading in _SUCC_ACTIONS.get(current_node.rotation, ()):
        dx, dy = _SUCC_STEP[heading]
        nx, ny = x + dx, y + dy
        # Only the axis that actually changes is bounds-checked.
        if dx and not 0 <= nx < GRID_WIDTH:
            continue
        if dy and not 0 <= ny < GRID_HEIGHT:
            continue
        position = [nx, ny] if action == 'FW' else [x, y]
        successors.append(Node(position, heading, action, current_node, step_cost))
    return successors
def preprocess_image(image_path, transform):
    """Load an image file and turn it into a single-item batch for the model.

    Args:
        image_path: path of the image file to load.
        transform: callable applied to the RGB PIL image that returns a
            tensor (e.g. the module-level torchvision ``transform``).

    Returns:
        (image, batch): the RGB PIL image and ``transform(image)`` with a
        leading batch dimension of size 1 added by ``unsqueeze(0)``.
    """
    # Image.open is lazy and keeps the file handle open. convert() loads the
    # pixels into a new image, so the source file can be closed right away.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    return image, transform(image).unsqueeze(0)
def distance(current_node, target):
    """Manhattan (L1) distance between the positions of two nodes.

    Headings are ignored; astar() uses this as its heuristic.
    """
    dx = current_node.position[0] - target.position[0]
    dy = current_node.position[1] - target.position[1]
    return abs(dx) + abs(dy)
# Breadth-first search
def bfs(isstate, final):
    """Breadth-first search over forklift poses (position + heading).

    Finds a plan with the fewest actions from ``isstate`` to any pose whose
    position equals ``final.position``; the goal heading is ignored.

    Returns:
        The plan's nodes from the goal node back to the first move (the start
        node is excluded); reverse it to execute. ``False`` when no pose at
        the target position is reachable.
    """
    def pose(node):
        return (node.position[0], node.position[1], node.rotation)

    goal = (final.position[0], final.position[1])
    fringe = deque([isstate])
    # Node defines no __eq__/__hash__, so "node in list" only matches the very
    # same object and never detects a revisited pose. Track poses instead.
    seen = {pose(isstate)}
    while fringe:
        node = fringe.popleft()
        if (node.position[0], node.position[1]) == goal:
            path = []
            while node.parent is not None:
                path.append(node)
                node = node.parent
            return path
        for successor in succ(node):
            key = pose(successor)
            if key not in seen:
                seen.add(key)
                fringe.append(successor)
    return False
def astar(isstate, final):
    """A* search over forklift poses (position + heading).

    Step costs come from each successor's ``cost``. The heuristic is the
    Manhattan ``distance()`` to the target, which never overestimates while
    every step costs at least 1. The goal test compares positions only; the
    goal heading is ignored.

    Returns:
        The nodes of the cheapest plan from the goal node back to the first
        move (start node excluded); reverse it to execute. ``False`` when no
        path exists.
    """
    def pose(node):
        return (node.position[0], node.position[1], node.rotation)

    # Heap entries are (f = g + h, g, node). Keeping g in the entry avoids
    # hashing Node objects, which only hash by identity.
    fringe = [(0, 0, isstate)]
    best_cost = {pose(isstate): 0}
    closed = set()
    while fringe:
        _, cost_so_far, node = heapq.heappop(fringe)
        key = pose(node)
        if key in closed:
            continue  # stale entry: this pose was already expanded more cheaply
        closed.add(key)
        if node.position[0] == final.position[0] and node.position[1] == final.position[1]:
            path = []
            while node.parent is not None:
                path.append(node)
                node = node.parent
            return path
        for successor in succ(node):
            next_key = pose(successor)
            if next_key in closed:
                continue
            new_cost = cost_so_far + successor.cost
            if new_cost < best_cost.get(next_key, float('inf')):
                best_cost[next_key] = new_cost
                priority = new_cost + distance(successor, final)
                heapq.heappush(fringe, (priority, new_cost, successor))
    return False
# Decision tree ("drzewko"): table read from paczki.csv, presumably the
# training data for the entropy / information-gain helpers below. TODO confirm.
tree_data_base = pd.read_csv(filepath_or_buffer='paczki.csv')
def entropy(data):
    """Shannon entropy (in bits) of the class labels in ``data``.

    The labels are taken from the last column of the DataFrame. pandas has
    already consumed the CSV header row, so no row is skipped here. Each
    label's probability is its count divided by the total number of rows.
    """
    class_labels = data.iloc[:, -1]
    row_count = len(class_labels)
    shares = class_labels.value_counts() / row_count
    return -sum(shares * np.log2(shares))
def information_gain(data, attribute):
    """Information gain from splitting ``data`` on the column ``attribute``.

    Computed as entropy(data) minus the size-weighted average entropy of the
    row subsets that share each distinct value of ``attribute`` (for example
    small / medium / big).
    """
    total_entropy = entropy(data)
    column = data[attribute]
    row_count = len(data)
    # One subset per distinct value, holding all rows with that value.
    subsets = (data[column == value] for value in column.unique())
    weighted_entropy = sum(len(subset) / row_count * entropy(subset) for subset in subsets)
    return total_entropy - weighted_entropy
# Main game loop
def _plan_route(target):
    """Plan a path from the forklift's current pose to tile *target*.

    Returns (goal_node, actions), where actions are A* nodes in execution
    order. An unreachable target yields an empty action list instead of a
    crash, because astar() signals "no path" with False.
    """
    start = Node(forklift_pos, rotation, 'start', None, 0)
    goal = Node([target[0], target[1]], 'N', 'final', None, 0)
    path = astar(start, goal)
    if path is False:
        print('no route to', target)
        path = []
    path.reverse()  # astar lists nodes goal-first
    return goal, path


def _choose_drop_tile(cargo):
    """Pick, and remove from its pool, the rack tile where *cargo* goes.

    Car parts always go to the cyan zone. Other cargo is classified from its
    image by the CNN: class 0 -> 'flammable' (lime), 1 -> 'fragile' (pink),
    2 -> 'toxic' (red). Ties resolve to the lowest class index; the original
    strict ">" comparisons selected no zone on a tie and then crashed.
    NOTE: set.pop() raises KeyError once the chosen zone pool is empty.
    """
    if cargo.content == 'car_parts':
        return free_cyan.pop()
    _, tensor = preprocess_image(cargo.image, transform)
    with torch.no_grad():
        outputs = model(tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    score = probabilities.cpu().numpy().flatten()
    zones = (('flammable', free_lime), ('fragile', free_pink), ('toxic', free_red))
    label, zone = zones[int(np.argmax(score[:3]))]
    print(label)
    return zone.pop()


def game_loop():
    """Run the game until the window is closed, then exit the process.

    An autopilot runs alongside the manual controls (up arrow moves,
    left/right rotate, space picks up or drops, 'r' respawns truck-bed
    freight). It drives to a queued freight tile and picks the freight up.
    It then chooses a rack tile for it, drives there and drops it, and
    repeats with the next queued freight. One planned action runs per frame,
    and each frame is followed by a 500 ms pause.
    """
    init_game()
    final, path = _plan_route(current_freight.pop())
    for node in path:
        print(node.action)
    step = 0
    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_LEFT:
                    rotate_forklift('L')
                    load_images()
                elif event.key == pygame.K_RIGHT:
                    rotate_forklift('R')
                    load_images()
                elif event.key == pygame.K_UP:
                    move_forklift()
                elif event.key == pygame.K_SPACE:
                    handle_freight()
                elif event.key == pygame.K_r:
                    reset_truck_bed_freight()
        draw_board()
        draw_truck_bed_and_racks()
        draw_freight()
        draw_forklift_and_freight()
        pygame.display.flip()
        clock.tick(FPS)
        # Refill the truck bed once every queued freight tile has been targeted.
        if not current_freight:
            reset_truck_bed_freight()
        if forklift_pos == final.position:
            # Arrived: pick up / drop off, then plan the next leg.
            handle_freight()
            if carrying_freight:
                destination = _choose_drop_tile(carried_freight)
            else:
                destination = current_freight.pop()
            print(destination)
            final, path = _plan_route(destination)
            print(path)
            step = 0
        if step < len(path):
            action = path[step].action
            if action == 'FW':
                move_forklift()
            elif action in ('L', 'R'):
                rotate_forklift(action)
                load_images()
            step += 1
        pygame.time.wait(500)
    pygame.quit()
    sys.exit()
# Script entry point: start the game as soon as this module runs.
# NOTE(review): this also fires on `import`, which already opens the window and
# loads the model at module level. Consider an `if __name__ == "__main__":`
# guard if the module is ever imported.
game_loop()