decisiontree with network

This commit is contained in:
Kosterix08 2024-06-12 00:24:57 +02:00
parent fc75709739
commit 1be5d0fe09
11 changed files with 91 additions and 41 deletions

38
app.py
View File

@ -11,10 +11,10 @@ import threading
import time
import random
from classes.data.klient import Klient
from classes.data.klient import Klient
from classes.data.klient import KlientCechy
import xml.etree.ElementTree as ET
from decisiontree import predict_client
from classes.Jimmy_Neuron.predict_image import predict_image
pygame.init()
window = pygame.display.set_mode((prefs.WIDTH, prefs.HEIGHT))
@ -102,9 +102,8 @@ agent = Agent(prefs.SPAWN_POINT[0], prefs.SPAWN_POINT[1], cells)
klient = Klient(prefs.GRID_SIZE-1, 17,cells)
target_x, target_y = klientx_target-1, klienty_target
def watekDlaSciezkiAgenta():
assigned = False
time.sleep(3)
while True:
if len(path) > 0:
@ -119,10 +118,26 @@ def watekDlaSciezkiAgenta():
elif isinstance(element, tuple): # Check if it's a tuple indicating movement coordinates
x, y = element
agent.moveto(x, y)
neighbors = agent.get_neighbors(agent.current_cell, agent.cells)
for neighbor in neighbors:
if neighbor == klient.current_cell:
if not assigned:
random_client_data = random.choice(clients)
glasses = predict_image(random_client_data.zdjecie)
prediction = predict_client(random_client_data, glasses)
print("\nClient data:")
print(random_client_data)
print("Prediction (Adult):", prediction)
assigned = True
break
if assigned:
break
time.sleep(1)
def watekDlaSciezkiKlienta():
assigned = False
time.sleep(3)
while True:
if len(path2) > 0:
@ -138,20 +153,12 @@ def watekDlaSciezkiKlienta():
x, y = element2
klient.moveto(x, y)
if not assigned and klient.current_cell == cells[klientx_target][klienty_target]:
if klient.current_cell == cells[klientx_target][klienty_target]:
klient.przyStoliku = True
klient.stolik = klient.current_cell
random_client_data = random.choice(clients)
prediction = predict_client(random_client_data)
print("\nClient data:")
print(random_client_data)
print("Prediction (Adult):", prediction)
assigned = True
if assigned:
break
time.sleep(1)
path2 = klient.bfs2(klientx_target, klienty_target)
print("Najkrótsza ścieżka:", path2)
@ -165,6 +172,7 @@ watek = threading.Thread(target=watekDlaSciezkiAgenta)
watek.daemon = True
watek.start()
running = True
while running:
for event in pygame.event.get():

View File

@ -0,0 +1,36 @@
import torch
from torchvision.transforms import Compose, Lambda
import torchvision.io as io
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
hidden_size = 135 * 64
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 6, 5),
torch.nn.MaxPool2d(2, 2),
torch.nn.Conv2d(6, 16, 5),
torch.nn.Flatten(),
torch.nn.Linear(53824, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, 32 * 32),
torch.nn.ReLU(),
torch.nn.Linear(32 * 32, 10),
torch.nn.LogSoftmax(dim=-1)
).to(device)
model.load_state_dict(torch.load('model.pt', map_location=device))
model.eval()
def predict_image(image_path):
transform = Compose([Lambda(lambda x: x.float())])
image = io.read_image(image_path, mode=io.ImageReadMode.UNCHANGED)
image = transform(image)
image = image.unsqueeze(0)
image = image.to(device)
with torch.no_grad():
output = model(image)
predicted_class = output.argmax(dim=1).item()
print(predicted_class)
return predicted_class

View File

@ -275,6 +275,8 @@ class Agent:
dx = abs(current[0] - target[0])
dy = abs(current[1] - target[1])
return dx + dy

View File

@ -85,8 +85,8 @@ for person in root.findall('person'):
outfit_element = person.find('Outfit')
outfit = outfit_element.text if outfit_element is not None else ''
glasses_element = person.find('Glasses')
glasses = glasses_element.text if glasses_element is not None else ''
image_element = person.find('Image')
image = image_element.text if image_element is not None else ''
tattoo_element = person.find('Tattoo')
tattoo = tattoo_element.text if tattoo_element is not None else ''
@ -107,10 +107,10 @@ for person in root.findall('person'):
'lysienie': balding,
'broda': beard,
'ubior': outfit,
'okulary': glasses,
'tatuaz': tattoo,
'wlosy': hair,
'zachowanie': behaviour
'zachowanie': behaviour,
'zdjecie': image
}
clients.append(KlientCechy(**person_data))

View File

@ -14,7 +14,7 @@ class Klient:
self.current_cell = cells[x][y]
self.current_x = x
self.current_y = y
przyStoliku = False
self.przyStoliku = False
self.cells = cells
self.X = x
self.Y = y
@ -183,7 +183,7 @@ class Klient:
class KlientCechy:
def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, okulary, tatuaz, wlosy, zachowanie):
def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, tatuaz, wlosy, zachowanie, zdjecie):
self.imie = imie
self.nazwisko = nazwisko
self.wiek = wiek
@ -193,7 +193,7 @@ class KlientCechy:
self.lysienie = lysienie
self.broda = broda
self.ubior = ubior
self.okulary = okulary
self.zdjecie = zdjecie
self.tatuaz = tatuaz
self.wlosy = wlosy
self.zachowanie = zachowanie
@ -210,4 +210,4 @@ class KlientCechy:
print("Klient ma juz przypisany stolik.")
def __str__(self):
return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, okulary: {self.okulary}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}"
return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}"

View File

@ -6,13 +6,14 @@
<favoriteMeal>Tatar</favoriteMeal>
<restrictions>Meat</restrictions>
<Wrinkles>No</Wrinkles>
<Balding>Yes</Balding>
<Beard>Yes</Beard>
<Balding>No</Balding>
<Beard>No</Beard>
<Outfit>Messy</Outfit>
<Glasses>No</Glasses>
<Glasses>Yes</Glasses>
<Tattoo>No</Tattoo>
<Hair>Color</Hair>
<Hair>Natural</Hair>
<Behaviour>Energetic</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person>
<person>
<name>Kamil</name>
@ -23,11 +24,12 @@
<Wrinkles>No</Wrinkles>
<Balding>No</Balding>
<Beard>No</Beard>
<Outfit>Messy</Outfit>
<Outfit>Casual</Outfit>
<Glasses>No</Glasses>
<Tattoo>No</Tattoo>
<Hair>Color</Hair>
<Hair>Natural</Hair>
<Behaviour>Energetic</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person>
<person>
<name>Jon</name>
@ -36,14 +38,15 @@
<favoriteMeal>Grochówka</favoriteMeal>
<restrictions>Vegan</restrictions>
<Wrinkles>No</Wrinkles>
<Balding>No</Balding>
<Beard>No</Beard>
<Balding>Yes</Balding>
<Beard>Yes</Beard>
<Outfit>Messy</Outfit>
<Glasses>No</Glasses>
<Tattoo>No</Tattoo>
<Tattoo>Yes</Tattoo>
<Hair>Color</Hair>
<Behaviour>Energetic</Behaviour>
</person>
<Behaviour>Stressed</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person>
<person>
<name>Andrzej</name>
<surname>Kowalski</surname>
@ -52,11 +55,12 @@
<restrictions>Meat</restrictions>
<Wrinkles>No</Wrinkles>
<Balding>No</Balding>
<Beard>No</Beard>
<Outfit>Messy</Outfit>
<Glasses>No</Glasses>
<Beard>Yes</Beard>
<Outfit>Formal</Outfit>
<Glasses>Yes</Glasses>
<Tattoo>No</Tattoo>
<Hair>Color</Hair>
<Behaviour>Energetic</Behaviour>
<Hair>Grey</Hair>
<Behaviour>Calm</Behaviour>
<Image>database\\clientsimg\\AndrzejKowalski.png</Image>
</person>
</people>

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

View File

@ -64,13 +64,13 @@ dot_data = tree.export_graphviz(clf_en, out_file=None,
"Behaviour": random.choice(['Energetic', 'Stressed', 'Calm'])
} """
def predict_client(client_data):
def predict_client(client_data,glasses):
new_client = {
"Wrinkles": client_data.zmarszczki,
"Balding": client_data.lysienie,
"Beard": client_data.broda,
"Outfit": client_data.ubior,
"Glasses": client_data.okulary,
"Glasses": 'Yes' if glasses == 1 else 'No',
"Tattoo": client_data.tatuaz,
"Hair": client_data.wlosy,
"Behaviour": client_data.zachowanie