decisiontree with network

This commit is contained in:
Kosterix08 2024-06-12 00:24:57 +02:00
parent fc75709739
commit 1be5d0fe09
11 changed files with 91 additions and 41 deletions

38
app.py
View File

@ -11,10 +11,10 @@ import threading
import time import time
import random import random
from classes.data.klient import Klient from classes.data.klient import Klient
from classes.data.klient import Klient from classes.data.klient import KlientCechy
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from decisiontree import predict_client from decisiontree import predict_client
from classes.Jimmy_Neuron.predict_image import predict_image
pygame.init() pygame.init()
window = pygame.display.set_mode((prefs.WIDTH, prefs.HEIGHT)) window = pygame.display.set_mode((prefs.WIDTH, prefs.HEIGHT))
@ -102,9 +102,8 @@ agent = Agent(prefs.SPAWN_POINT[0], prefs.SPAWN_POINT[1], cells)
klient = Klient(prefs.GRID_SIZE-1, 17,cells) klient = Klient(prefs.GRID_SIZE-1, 17,cells)
target_x, target_y = klientx_target-1, klienty_target target_x, target_y = klientx_target-1, klienty_target
def watekDlaSciezkiAgenta(): def watekDlaSciezkiAgenta():
assigned = False
time.sleep(3) time.sleep(3)
while True: while True:
if len(path) > 0: if len(path) > 0:
@ -119,10 +118,26 @@ def watekDlaSciezkiAgenta():
elif isinstance(element, tuple): # Check if it's a tuple indicating movement coordinates elif isinstance(element, tuple): # Check if it's a tuple indicating movement coordinates
x, y = element x, y = element
agent.moveto(x, y) agent.moveto(x, y)
neighbors = agent.get_neighbors(agent.current_cell, agent.cells)
for neighbor in neighbors:
if neighbor == klient.current_cell:
if not assigned:
random_client_data = random.choice(clients)
glasses = predict_image(random_client_data.zdjecie)
prediction = predict_client(random_client_data, glasses)
print("\nClient data:")
print(random_client_data)
print("Prediction (Adult):", prediction)
assigned = True
break
if assigned:
break
time.sleep(1) time.sleep(1)
def watekDlaSciezkiKlienta(): def watekDlaSciezkiKlienta():
assigned = False
time.sleep(3) time.sleep(3)
while True: while True:
if len(path2) > 0: if len(path2) > 0:
@ -138,21 +153,13 @@ def watekDlaSciezkiKlienta():
x, y = element2 x, y = element2
klient.moveto(x, y) klient.moveto(x, y)
if not assigned and klient.current_cell == cells[klientx_target][klienty_target]: if klient.current_cell == cells[klientx_target][klienty_target]:
klient.przyStoliku = True klient.przyStoliku = True
klient.stolik = klient.current_cell klient.stolik = klient.current_cell
random_client_data = random.choice(clients)
prediction = predict_client(random_client_data)
print("\nClient data:")
print(random_client_data)
print("Prediction (Adult):", prediction)
assigned = True
if assigned:
break
time.sleep(1) time.sleep(1)
path2 = klient.bfs2(klientx_target, klienty_target) path2 = klient.bfs2(klientx_target, klienty_target)
print("Najkrótsza ścieżka:", path2) print("Najkrótsza ścieżka:", path2)
watek = threading.Thread(target=watekDlaSciezkiKlienta) watek = threading.Thread(target=watekDlaSciezkiKlienta)
@ -165,6 +172,7 @@ watek = threading.Thread(target=watekDlaSciezkiAgenta)
watek.daemon = True watek.daemon = True
watek.start() watek.start()
running = True running = True
while running: while running:
for event in pygame.event.get(): for event in pygame.event.get():

View File

@ -0,0 +1,36 @@
import torch
from torchvision.transforms import Compose, Lambda
import torchvision.io as io
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
hidden_size = 135 * 64
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 6, 5),
torch.nn.MaxPool2d(2, 2),
torch.nn.Conv2d(6, 16, 5),
torch.nn.Flatten(),
torch.nn.Linear(53824, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, 32 * 32),
torch.nn.ReLU(),
torch.nn.Linear(32 * 32, 10),
torch.nn.LogSoftmax(dim=-1)
).to(device)
model.load_state_dict(torch.load('model.pt', map_location=device))
model.eval()
def predict_image(image_path):
transform = Compose([Lambda(lambda x: x.float())])
image = io.read_image(image_path, mode=io.ImageReadMode.UNCHANGED)
image = transform(image)
image = image.unsqueeze(0)
image = image.to(device)
with torch.no_grad():
output = model(image)
predicted_class = output.argmax(dim=1).item()
print(predicted_class)
return predicted_class

View File

@ -292,3 +292,5 @@ class Agent:

View File

@ -85,8 +85,8 @@ for person in root.findall('person'):
outfit_element = person.find('Outfit') outfit_element = person.find('Outfit')
outfit = outfit_element.text if outfit_element is not None else '' outfit = outfit_element.text if outfit_element is not None else ''
glasses_element = person.find('Glasses') image_element = person.find('Image')
glasses = glasses_element.text if glasses_element is not None else '' image = image_element.text if image_element is not None else ''
tattoo_element = person.find('Tattoo') tattoo_element = person.find('Tattoo')
tattoo = tattoo_element.text if tattoo_element is not None else '' tattoo = tattoo_element.text if tattoo_element is not None else ''
@ -107,10 +107,10 @@ for person in root.findall('person'):
'lysienie': balding, 'lysienie': balding,
'broda': beard, 'broda': beard,
'ubior': outfit, 'ubior': outfit,
'okulary': glasses,
'tatuaz': tattoo, 'tatuaz': tattoo,
'wlosy': hair, 'wlosy': hair,
'zachowanie': behaviour 'zachowanie': behaviour,
'zdjecie': image
} }
clients.append(KlientCechy(**person_data)) clients.append(KlientCechy(**person_data))

View File

@ -14,7 +14,7 @@ class Klient:
self.current_cell = cells[x][y] self.current_cell = cells[x][y]
self.current_x = x self.current_x = x
self.current_y = y self.current_y = y
przyStoliku = False self.przyStoliku = False
self.cells = cells self.cells = cells
self.X = x self.X = x
self.Y = y self.Y = y
@ -183,7 +183,7 @@ class Klient:
class KlientCechy: class KlientCechy:
def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, okulary, tatuaz, wlosy, zachowanie): def __init__(self,imie,nazwisko,wiek,ulubiony_posilek,restrykcje_dietowe,zmarszczki, lysienie, broda, ubior, tatuaz, wlosy, zachowanie, zdjecie):
self.imie = imie self.imie = imie
self.nazwisko = nazwisko self.nazwisko = nazwisko
self.wiek = wiek self.wiek = wiek
@ -193,7 +193,7 @@ class KlientCechy:
self.lysienie = lysienie self.lysienie = lysienie
self.broda = broda self.broda = broda
self.ubior = ubior self.ubior = ubior
self.okulary = okulary self.zdjecie = zdjecie
self.tatuaz = tatuaz self.tatuaz = tatuaz
self.wlosy = wlosy self.wlosy = wlosy
self.zachowanie = zachowanie self.zachowanie = zachowanie
@ -210,4 +210,4 @@ class KlientCechy:
print("Klient ma juz przypisany stolik.") print("Klient ma juz przypisany stolik.")
def __str__(self): def __str__(self):
return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, okulary: {self.okulary}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}" return f"Klient: {self.imie} {self.nazwisko} {self.wiek}, ulubione Danie: {self.ulubiony_posilek}, restrykcje diet: {self.restrykcje_dietowe}. Jego cechy to: zmarszczki: {self.zmarszczki}, lysienie: {self.lysienie}, broda: {self.broda}, ubior: {self.ubior}, tatuaz: {self.tatuaz}, wlosy: {self.wlosy}, zachowanie: {self.zachowanie}"

View File

@ -6,13 +6,14 @@
<favoriteMeal>Tatar</favoriteMeal> <favoriteMeal>Tatar</favoriteMeal>
<restrictions>Meat</restrictions> <restrictions>Meat</restrictions>
<Wrinkles>No</Wrinkles> <Wrinkles>No</Wrinkles>
<Balding>Yes</Balding> <Balding>No</Balding>
<Beard>Yes</Beard> <Beard>No</Beard>
<Outfit>Messy</Outfit> <Outfit>Messy</Outfit>
<Glasses>No</Glasses> <Glasses>Yes</Glasses>
<Tattoo>No</Tattoo> <Tattoo>No</Tattoo>
<Hair>Color</Hair> <Hair>Natural</Hair>
<Behaviour>Energetic</Behaviour> <Behaviour>Energetic</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person> </person>
<person> <person>
<name>Kamil</name> <name>Kamil</name>
@ -23,11 +24,12 @@
<Wrinkles>No</Wrinkles> <Wrinkles>No</Wrinkles>
<Balding>No</Balding> <Balding>No</Balding>
<Beard>No</Beard> <Beard>No</Beard>
<Outfit>Messy</Outfit> <Outfit>Casual</Outfit>
<Glasses>No</Glasses> <Glasses>No</Glasses>
<Tattoo>No</Tattoo> <Tattoo>No</Tattoo>
<Hair>Color</Hair> <Hair>Natural</Hair>
<Behaviour>Energetic</Behaviour> <Behaviour>Energetic</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person> </person>
<person> <person>
<name>Jon</name> <name>Jon</name>
@ -36,13 +38,14 @@
<favoriteMeal>Grochówka</favoriteMeal> <favoriteMeal>Grochówka</favoriteMeal>
<restrictions>Vegan</restrictions> <restrictions>Vegan</restrictions>
<Wrinkles>No</Wrinkles> <Wrinkles>No</Wrinkles>
<Balding>No</Balding> <Balding>Yes</Balding>
<Beard>No</Beard> <Beard>Yes</Beard>
<Outfit>Messy</Outfit> <Outfit>Messy</Outfit>
<Glasses>No</Glasses> <Glasses>No</Glasses>
<Tattoo>No</Tattoo> <Tattoo>Yes</Tattoo>
<Hair>Color</Hair> <Hair>Color</Hair>
<Behaviour>Energetic</Behaviour> <Behaviour>Stressed</Behaviour>
<Image>database\\clientsimg\\DavidBowie.png</Image>
</person> </person>
<person> <person>
<name>Andrzej</name> <name>Andrzej</name>
@ -52,11 +55,12 @@
<restrictions>Meat</restrictions> <restrictions>Meat</restrictions>
<Wrinkles>No</Wrinkles> <Wrinkles>No</Wrinkles>
<Balding>No</Balding> <Balding>No</Balding>
<Beard>No</Beard> <Beard>Yes</Beard>
<Outfit>Messy</Outfit> <Outfit>Formal</Outfit>
<Glasses>No</Glasses> <Glasses>Yes</Glasses>
<Tattoo>No</Tattoo> <Tattoo>No</Tattoo>
<Hair>Color</Hair> <Hair>Grey</Hair>
<Behaviour>Energetic</Behaviour> <Behaviour>Calm</Behaviour>
<Image>database\\clientsimg\\AndrzejKowalski.png</Image>
</person> </person>
</people> </people>

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

View File

@ -64,13 +64,13 @@ dot_data = tree.export_graphviz(clf_en, out_file=None,
"Behaviour": random.choice(['Energetic', 'Stressed', 'Calm']) "Behaviour": random.choice(['Energetic', 'Stressed', 'Calm'])
} """ } """
def predict_client(client_data): def predict_client(client_data,glasses):
new_client = { new_client = {
"Wrinkles": client_data.zmarszczki, "Wrinkles": client_data.zmarszczki,
"Balding": client_data.lysienie, "Balding": client_data.lysienie,
"Beard": client_data.broda, "Beard": client_data.broda,
"Outfit": client_data.ubior, "Outfit": client_data.ubior,
"Glasses": client_data.okulary, "Glasses": 'Yes' if glasses == 1 else 'No',
"Tattoo": client_data.tatuaz, "Tattoo": client_data.tatuaz,
"Hair": client_data.wlosy, "Hair": client_data.wlosy,
"Behaviour": client_data.zachowanie "Behaviour": client_data.zachowanie