decisiontree with network
This commit is contained in:
parent
fc75709739
commit
1be5d0fe09
38
app.py
38
app.py
@ -11,10 +11,10 @@ import threading
|
||||
import time
|
||||
import random
|
||||
from classes.data.klient import Klient
|
||||
from classes.data.klient import Klient
|
||||
from classes.data.klient import KlientCechy
|
||||
import xml.etree.ElementTree as ET
|
||||
from decisiontree import predict_client
|
||||
|
||||
from classes.Jimmy_Neuron.predict_image import predict_image
|
||||
|
||||
pygame.init()
|
||||
window = pygame.display.set_mode((prefs.WIDTH, prefs.HEIGHT))
|
||||
@ -102,9 +102,8 @@ agent = Agent(prefs.SPAWN_POINT[0], prefs.SPAWN_POINT[1], cells)
|
||||
klient = Klient(prefs.GRID_SIZE-1, 17,cells)
|
||||
target_x, target_y = klientx_target-1, klienty_target
|
||||
|
||||
|
||||
|
||||
def watekDlaSciezkiAgenta():
|
||||
assigned = False
|
||||
time.sleep(3)
|
||||
while True:
|
||||
if len(path) > 0:
|
||||
@ -119,10 +118,26 @@ def watekDlaSciezkiAgenta():
|
||||
elif isinstance(element, tuple): # Check if it's a tuple indicating movement coordinates
|
||||
x, y = element
|
||||
agent.moveto(x, y)
|
||||
|
||||
neighbors = agent.get_neighbors(agent.current_cell, agent.cells)
|
||||
for neighbor in neighbors:
|
||||
if neighbor == klient.current_cell:
|
||||
if not assigned:
|
||||
random_client_data = random.choice(clients)
|
||||
glasses = predict_image(random_client_data.zdjecie)
|
||||
prediction = predict_client(random_client_data, glasses)
|
||||
print("\nClient data:")
|
||||
print(random_client_data)
|
||||
print("Prediction (Adult):", prediction)
|
||||
assigned = True
|
||||
break
|
||||
if assigned:
|
||||
break
|
||||
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
def watekDlaSciezkiKlienta():
|
||||
assigned = False
|
||||
time.sleep(3)
|
||||
while True:
|
||||
if len(path2) > 0:
|
||||
@ -138,20 +153,12 @@ def watekDlaSciezkiKlienta():
|
||||
x, y = element2
|
||||
klient.moveto(x, y)
|
||||
|
||||
if not assigned and klient.current_cell == cells[klientx_target][klienty_target]:
|
||||
if klient.current_cell == cells[klientx_target][klienty_target]:
|
||||
klient.przyStoliku = True
|
||||
klient.stolik = klient.current_cell
|
||||
random_client_data = random.choice(clients)
|
||||
prediction = predict_client(random_client_data)
|
||||
print("\nClient data:")
|
||||
print(random_client_data)
|
||||
print("Prediction (Adult):", prediction)
|
||||
assigned = True
|
||||
|
||||
if assigned:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
path2 = klient.bfs2(klientx_target, klienty_target)
|
||||
print("Najkrótsza ścieżka:", path2)
|
||||
@ -165,6 +172,7 @@ watek = threading.Thread(target=watekDlaSciezkiAgenta)
|
||||
watek.daemon = True
|
||||
watek.start()
|
||||
|
||||
|
||||
running = True
|
||||
while running:
|
||||
for event in pygame.event.get():
|
||||
|
36
classes/Jimmy_Neuron/predict_image.py
Normal file
36
classes/Jimmy_Neuron/predict_image.py
Normal file
@ -0,0 +1,36 @@
|
||||
import torch
|
||||
from torchvision.transforms import Compose, Lambda
|
||||
import torchvision.io as io
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
hidden_size = 135 * 64
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(3, 6, 5),
|
||||
torch.nn.MaxPool2d(2, 2),
|
||||
torch.nn.Conv2d(6, 16, 5),
|
||||
torch.nn.Flatten(),
|
||||
torch.nn.Linear(53824, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, 32 * 32),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(32 * 32, 10),
|
||||
torch.nn.LogSoftmax(dim=-1)
|
||||
).to(device)
|
||||
|
||||
model.load_state_dict(torch.load('model.pt', map_location=device))
|
||||
model.eval()
|
||||
|
||||
def predict_image(image_path):
|
||||
transform = Compose([Lambda(lambda x: x.float())])
|
||||
image = io.read_image(image_path, mode=io.ImageReadMode.UNCHANGED)
|
||||
image = transform(image)
|
||||
image = image.unsqueeze(0)
|
||||
image = image.to(device)
|
||||
with torch.no_grad():
|
||||
output = model(image)
|
||||
predicted_class = output.argmax(dim=1).item()
|
||||
print(predicted_class)
|
||||
|
||||
return predicted_class
|
||||
|
@ -275,6 +275,8 @@ class Agent:
|
||||
dx = abs(current[0] - target[0])
|
||||
dy = abs(current[1] - target[1])
|
||||
return dx + dy
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -85,8 +85,8 @@ for person in root.findall('person'):
|
||||
outfit_element = person.find('Outfit')
|
||||
outfit = outfit_element.text if outfit_element is not None else ''
|
||||
|
||||
glasses_element = person.find('Glasses')
|
||||
glasses = glasses_element.text if glasses_element is not None else ''
|
||||
image_element = person.find('Image')
|
||||
image = image_element.text if image_element is not None else ''
|
||||
|
||||
tattoo_element = person.find('Tattoo')
|
||||
tattoo = tattoo_element.text if tattoo_element is not None else ''
|
||||
@ -107,10 +107,10 @@ for person in root.findall('person'):
|
||||
'lysienie': balding,
|
||||
'broda': beard,
|
||||
'ubior': outfit,
|
||||
'okulary': glasses,
|
||||
'tatuaz': tattoo,
|
||||
'wlosy': hair,
|
||||
'zachowanie': behaviour
|
||||
'zachowanie': behaviour,
|
||||
'zdjecie': image
|
||||
}
|
||||
|
||||
clients.append(KlientCechy(**person_data))
|
||||
|
@ -14,7 +14,7 @@ class Klient:
|
||||
self.current_cell = cells[x][y]
|
||||
self.current_x = x
|
||||
self.current_y = y
|
||||
przyStoliku = False
|
||||
self.przyStoliku = False
|
||||
self.cells = cells
|
||||
self.X = x
|
||||
self.Y = y
|
||||
@ -183,7 +183,7 @@ class Klient:
|
||||
|
||||
|
||||
class KlientCechy:
|
||||
def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, okulary, tatuaz, wlosy, zachowanie):
|
||||
def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, tatuaz, wlosy, zachowanie, zdjecie):
|
||||
self.imie = imie
|
||||
self.nazwisko = nazwisko
|
||||
self.wiek = wiek
|
||||
@ -193,7 +193,7 @@ class KlientCechy:
|
||||
self.lysienie = lysienie
|
||||
self.broda = broda
|
||||
self.ubior = ubior
|
||||
self.okulary = okulary
|
||||
self.zdjecie = zdjecie
|
||||
self.tatuaz = tatuaz
|
||||
self.wlosy = wlosy
|
||||
self.zachowanie = zachowanie
|
||||
@ -210,4 +210,4 @@ class KlientCechy:
|
||||
print("Klient ma juz przypisany stolik.")
|
||||
|
||||
def __str__(self):
|
||||
return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, okulary: {self.okulary}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}"
|
||||
return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}"
|
||||
|
@ -6,13 +6,14 @@
|
||||
<favoriteMeal>Tatar</favoriteMeal>
|
||||
<restrictions>Meat</restrictions>
|
||||
<Wrinkles>No</Wrinkles>
|
||||
<Balding>Yes</Balding>
|
||||
<Beard>Yes</Beard>
|
||||
<Balding>No</Balding>
|
||||
<Beard>No</Beard>
|
||||
<Outfit>Messy</Outfit>
|
||||
<Glasses>No</Glasses>
|
||||
<Glasses>Yes</Glasses>
|
||||
<Tattoo>No</Tattoo>
|
||||
<Hair>Color</Hair>
|
||||
<Hair>Natural</Hair>
|
||||
<Behaviour>Energetic</Behaviour>
|
||||
<Image>database\\clientsimg\\DavidBowie.png</Image>
|
||||
</person>
|
||||
<person>
|
||||
<name>Kamil</name>
|
||||
@ -23,11 +24,12 @@
|
||||
<Wrinkles>No</Wrinkles>
|
||||
<Balding>No</Balding>
|
||||
<Beard>No</Beard>
|
||||
<Outfit>Messy</Outfit>
|
||||
<Outfit>Casual</Outfit>
|
||||
<Glasses>No</Glasses>
|
||||
<Tattoo>No</Tattoo>
|
||||
<Hair>Color</Hair>
|
||||
<Hair>Natural</Hair>
|
||||
<Behaviour>Energetic</Behaviour>
|
||||
<Image>database\\clientsimg\\DavidBowie.png</Image>
|
||||
</person>
|
||||
<person>
|
||||
<name>Jon</name>
|
||||
@ -36,14 +38,15 @@
|
||||
<favoriteMeal>Grochówka</favoriteMeal>
|
||||
<restrictions>Vegan</restrictions>
|
||||
<Wrinkles>No</Wrinkles>
|
||||
<Balding>No</Balding>
|
||||
<Beard>No</Beard>
|
||||
<Balding>Yes</Balding>
|
||||
<Beard>Yes</Beard>
|
||||
<Outfit>Messy</Outfit>
|
||||
<Glasses>No</Glasses>
|
||||
<Tattoo>No</Tattoo>
|
||||
<Tattoo>Yes</Tattoo>
|
||||
<Hair>Color</Hair>
|
||||
<Behaviour>Energetic</Behaviour>
|
||||
</person>
|
||||
<Behaviour>Stressed</Behaviour>
|
||||
<Image>database\\clientsimg\\DavidBowie.png</Image>
|
||||
</person>
|
||||
<person>
|
||||
<name>Andrzej</name>
|
||||
<surname>Kowalski</surname>
|
||||
@ -52,11 +55,12 @@
|
||||
<restrictions>Meat</restrictions>
|
||||
<Wrinkles>No</Wrinkles>
|
||||
<Balding>No</Balding>
|
||||
<Beard>No</Beard>
|
||||
<Outfit>Messy</Outfit>
|
||||
<Glasses>No</Glasses>
|
||||
<Beard>Yes</Beard>
|
||||
<Outfit>Formal</Outfit>
|
||||
<Glasses>Yes</Glasses>
|
||||
<Tattoo>No</Tattoo>
|
||||
<Hair>Color</Hair>
|
||||
<Behaviour>Energetic</Behaviour>
|
||||
<Hair>Grey</Hair>
|
||||
<Behaviour>Calm</Behaviour>
|
||||
<Image>database\\clientsimg\\AndrzejKowalski.png</Image>
|
||||
</person>
|
||||
</people>
|
BIN
database/clientsimg/AndrzejKowalski.png
Normal file
BIN
database/clientsimg/AndrzejKowalski.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 26 KiB |
BIN
database/clientsimg/DavidBowie.png
Normal file
BIN
database/clientsimg/DavidBowie.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
BIN
database/clientsimg/JonSnow.png
Normal file
BIN
database/clientsimg/JonSnow.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
BIN
database/clientsimg/KamilStop.png
Normal file
BIN
database/clientsimg/KamilStop.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
@ -64,13 +64,13 @@ dot_data = tree.export_graphviz(clf_en, out_file=None,
|
||||
"Behaviour": random.choice(['Energetic', 'Stressed', 'Calm'])
|
||||
} """
|
||||
|
||||
def predict_client(client_data):
|
||||
def predict_client(client_data,glasses):
|
||||
new_client = {
|
||||
"Wrinkles": client_data.zmarszczki,
|
||||
"Balding": client_data.lysienie,
|
||||
"Beard": client_data.broda,
|
||||
"Outfit": client_data.ubior,
|
||||
"Glasses": client_data.okulary,
|
||||
"Glasses": 'Yes' if glasses == 1 else 'No',
|
||||
"Tattoo": client_data.tatuaz,
|
||||
"Hair": client_data.wlosy,
|
||||
"Behaviour": client_data.zachowanie
|
||||
|
Loading…
Reference in New Issue
Block a user