Merge pull request 'Add decision tree' (#20) from decision_tree into main
Reviewed-on: #20 Reviewed-by: Tim Barvenov <timbar@st.amu.edu.pl>
This commit is contained in:
commit
3283eaa46d
22
decisionTree/data.csv
Normal file
22
decisionTree/data.csv
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
1-2-3-4-5;1-green 2-yellow 3-orange 4-black 5-while 6-blue;in dB 0-100;0-24;0/1;in cm;in C;0/1
|
||||||
|
Size;Color;Sound;Time;Smell;Height;Temperature;ToRemove
|
||||||
|
1;2;0;16;1;10;25;1
|
||||||
|
2;1;0;12;0;50;24;0
|
||||||
|
2;3;30;13;1;38;38;0
|
||||||
|
1;4;0;7;1;5;27;1
|
||||||
|
1;2;0;16;1;10;25;1
|
||||||
|
2;1;0;12;0;50;24;0
|
||||||
|
2;3;30;13;1;38;38;0
|
||||||
|
1;4;0;7;1;5;27;1
|
||||||
|
1;2;0;16;1;10;25;1
|
||||||
|
2;1;0;12;0;50;24;0
|
||||||
|
2;3;30;13;1;38;38;0
|
||||||
|
1;4;0;7;1;5;27;1
|
||||||
|
1;2;0;16;1;10;25;1
|
||||||
|
2;1;0;12;0;50;24;0
|
||||||
|
2;3;30;13;1;38;38;0
|
||||||
|
1;4;0;7;1;5;27;1
|
||||||
|
1;2;0;16;1;10;25;1
|
||||||
|
2;1;0;12;0;50;24;0
|
||||||
|
2;3;30;13;1;38;38;0
|
||||||
|
1;4;0;7;1;5;27;1
|
|
BIN
decisionTree/decision_tree_model.pkl
Normal file
BIN
decisionTree/decision_tree_model.pkl
Normal file
Binary file not shown.
9
decisionTree/evaluate.py
Normal file
9
decisionTree/evaluate.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import joblib
|
||||||
|
|
||||||
|
def evaluate(data):
|
||||||
|
# Load the model
|
||||||
|
clf = joblib.load('decisionTree/decision_tree_model.pkl')
|
||||||
|
|
||||||
|
# Make a prediction
|
||||||
|
prediction = clf.predict(data)
|
||||||
|
return prediction
|
21
decisionTree/prepare.py
Normal file
21
decisionTree/prepare.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn import metrics
|
||||||
|
import joblib
|
||||||
|
|
||||||
|
pima = pd.read_csv("data.csv", header=1, delimiter=';')
|
||||||
|
|
||||||
|
feature_cols = ['Size', 'Color', 'Sound', 'Time','Smell', 'Height','Temperature']
|
||||||
|
X = pima[feature_cols]
|
||||||
|
y = pima.ToRemove
|
||||||
|
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
|
||||||
|
clf = DecisionTreeClassifier()
|
||||||
|
clf = clf.fit(X_train,y_train)
|
||||||
|
|
||||||
|
joblib.dump(clf, 'decision_tree_model.pkl')
|
||||||
|
|
||||||
|
y_pred = clf.predict(X_test)
|
||||||
|
|
||||||
|
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
|
@ -13,3 +13,5 @@ class Cat(Entity):
|
|||||||
self.busy = False
|
self.busy = False
|
||||||
self.sleeping = False
|
self.sleeping = False
|
||||||
self.direction = 0
|
self.direction = 0
|
||||||
|
|
||||||
|
self.props = [1,2,0,16,1,10,25]
|
8
domain/entities/earring.py
Normal file
8
domain/entities/earring.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from domain.entities.entity import Entity
|
||||||
|
from domain.world import World
|
||||||
|
|
||||||
|
|
||||||
|
class Earring(Entity):
|
||||||
|
def __init__(self, x: int, y: int):
|
||||||
|
super().__init__(x, y, "EARRING")
|
||||||
|
self.props = [2,1,0,12,0,50,24]
|
@ -1,11 +1,11 @@
|
|||||||
from domain.entities.entity import Entity
|
from domain.entities.entity import Entity
|
||||||
from domain.world import World
|
|
||||||
|
|
||||||
|
|
||||||
class Garbage(Entity):
|
class Garbage(Entity):
|
||||||
def __init__(self, x: int, y: int):
|
def __init__(self, x: int, y: int):
|
||||||
super().__init__(x, y, "GARBAGE")
|
super().__init__(x, y, "PEEL")
|
||||||
self.wet = False
|
self.wet = False
|
||||||
self.size = 0
|
self.size = 0
|
||||||
|
self.props = [1,2,0,16,1,10,25]
|
||||||
|
|
||||||
# TODO GARBAGE: add more properties
|
# TODO GARBAGE: add more properties
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from decisionTree.evaluate import evaluate
|
||||||
from domain.entities.entity import Entity
|
from domain.entities.entity import Entity
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +16,8 @@ class World:
|
|||||||
def add_entity(self, entity: Entity):
|
def add_entity(self, entity: Entity):
|
||||||
if entity.type == "PEEL":
|
if entity.type == "PEEL":
|
||||||
self.dust[entity.x][entity.y].append(entity)
|
self.dust[entity.x][entity.y].append(entity)
|
||||||
|
elif entity.type == "EARRING":
|
||||||
|
self.dust[entity.x][entity.y].append(entity)
|
||||||
elif entity.type == "VACUUM":
|
elif entity.type == "VACUUM":
|
||||||
self.vacuum = entity
|
self.vacuum = entity
|
||||||
elif entity.type == "DOC_STATION":
|
elif entity.type == "DOC_STATION":
|
||||||
@ -29,7 +32,10 @@ class World:
|
|||||||
return bool(self.obstacles[x][y])
|
return bool(self.obstacles[x][y])
|
||||||
|
|
||||||
def is_garbage_at(self, x: int, y: int) -> bool:
|
def is_garbage_at(self, x: int, y: int) -> bool:
|
||||||
return bool(self.dust[x][y])
|
if len(self.dust[x][y]) == 0:
|
||||||
|
return False
|
||||||
|
tmp = evaluate([self.dust[x][y][0].props])
|
||||||
|
return bool(tmp[0])
|
||||||
|
|
||||||
def is_docking_station_at(self, x: int, y: int) -> bool:
|
def is_docking_station_at(self, x: int, y: int) -> bool:
|
||||||
return bool(self.doc_station.x == x and self.doc_station.y == y)
|
return bool(self.doc_station.x == x and self.doc_station.y == y)
|
||||||
|
5
main.py
5
main.py
@ -8,6 +8,8 @@ from domain.commands.vacuum_move_command import VacuumMoveCommand
|
|||||||
from domain.entities.cat import Cat
|
from domain.entities.cat import Cat
|
||||||
from domain.entities.entity import Entity
|
from domain.entities.entity import Entity
|
||||||
from domain.entities.vacuum import Vacuum
|
from domain.entities.vacuum import Vacuum
|
||||||
|
from domain.entities.garbage import Garbage
|
||||||
|
from domain.entities.earring import Earring
|
||||||
from domain.entities.docking_station import Doc_Station
|
from domain.entities.docking_station import Doc_Station
|
||||||
from domain.world import World
|
from domain.world import World
|
||||||
from view.renderer import Renderer
|
from view.renderer import Renderer
|
||||||
@ -135,7 +137,7 @@ def generate_world(tiles_x: int, tiles_y: int) -> World:
|
|||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
temp_x = randint(0, tiles_x - 1)
|
temp_x = randint(0, tiles_x - 1)
|
||||||
temp_y = randint(0, tiles_y - 1)
|
temp_y = randint(0, tiles_y - 1)
|
||||||
world.add_entity(Entity(temp_x, temp_y, "PEEL"))
|
world.add_entity(Garbage(temp_x, temp_y))
|
||||||
world.vacuum = Vacuum(1, 1)
|
world.vacuum = Vacuum(1, 1)
|
||||||
world.doc_station = Doc_Station(9, 8)
|
world.doc_station = Doc_Station(9, 8)
|
||||||
if config.getboolean("APP", "cat"):
|
if config.getboolean("APP", "cat"):
|
||||||
@ -146,6 +148,7 @@ def generate_world(tiles_x: int, tiles_y: int) -> World:
|
|||||||
world.add_entity(Entity(3, 4, "PLANT2"))
|
world.add_entity(Entity(3, 4, "PLANT2"))
|
||||||
world.add_entity(Entity(8, 8, "PLANT2"))
|
world.add_entity(Entity(8, 8, "PLANT2"))
|
||||||
world.add_entity(Entity(9, 3, "PLANT3"))
|
world.add_entity(Entity(9, 3, "PLANT3"))
|
||||||
|
world.add_entity(Earring(5, 5))
|
||||||
return world
|
return world
|
||||||
|
|
||||||
|
|
||||||
|
BIN
media/sprites/earrings.webp
Normal file
BIN
media/sprites/earrings.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 5.3 KiB |
@ -1,3 +1,6 @@
|
|||||||
pygame
|
pygame
|
||||||
configparser
|
configparser
|
||||||
formaFormatting: Provider - black
|
pandas
|
||||||
|
scikit-learn
|
||||||
|
joblib
|
||||||
|
# formaFormatting: Provider - black
|
@ -94,6 +94,13 @@ class Renderer:
|
|||||||
self.tile_height + self.tile_height / 4,
|
self.tile_height + self.tile_height / 4,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
"EARRING": pygame.transform.scale(
|
||||||
|
pygame.image.load("media/sprites/earrings.webp"),
|
||||||
|
(
|
||||||
|
self.tile_width + self.tile_width / 4,
|
||||||
|
self.tile_height + self.tile_height / 4,
|
||||||
|
),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.cat_direction_sprite = {
|
self.cat_direction_sprite = {
|
||||||
|
Loading…
Reference in New Issue
Block a user