dodanie decision tree
This commit is contained in:
parent
25f597dca2
commit
537173bc70
@ -2,9 +2,11 @@
|
||||
<project version="4">
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="">
|
||||
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/decision_tree.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/functions.py" beforeDir="false" afterPath="$PROJECT_DIR$/functions.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/Archiwum/frontend/js/main.js" beforeDir="false" afterPath="$PROJECT_DIR$/Archiwum/frontend/js/main.js" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/field.py" beforeDir="false" afterPath="$PROJECT_DIR$/field.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/main.py" beforeDir="false" afterPath="$PROJECT_DIR$/main.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
|
||||
@ -27,9 +29,14 @@
|
||||
<component name="PropertiesComponent">
|
||||
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
|
||||
<property name="WebServerToolWindowFactoryState" value="false" />
|
||||
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
|
||||
<property name="last_opened_file_path" value="$PROJECT_DIR$/decision_tree.py" />
|
||||
<property name="restartRequiresConfirmation" value="false" />
|
||||
</component>
|
||||
<component name="RecentsManager">
|
||||
<key name="MoveFile.RECENT_KEYS">
|
||||
<recent name="D:\Studia\Projects\AL-2020" />
|
||||
</key>
|
||||
</component>
|
||||
<component name="RunDashboard">
|
||||
<option name="ruleStates">
|
||||
<list>
|
||||
@ -65,7 +72,7 @@
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="wozek" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
@ -87,9 +94,12 @@
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<list>
|
||||
<item itemvalue="Python.main" />
|
||||
<item itemvalue="Python.board" />
|
||||
</list>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.main" />
|
||||
<item itemvalue="Python.board" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
@ -108,7 +118,9 @@
|
||||
<workItem from="1589233530634" duration="769000" />
|
||||
<workItem from="1589543001064" duration="78000" />
|
||||
<workItem from="1589543305930" duration="10474000" />
|
||||
<workItem from="1589561555146" duration="3374000" />
|
||||
<workItem from="1589561555146" duration="3518000" />
|
||||
<workItem from="1589727068958" duration="5729000" />
|
||||
<workItem from="1589796372999" duration="4340000" />
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
@ -120,15 +132,18 @@
|
||||
<map>
|
||||
<entry key="MAIN">
|
||||
<value>
|
||||
<State />
|
||||
<State>
|
||||
<option name="COLUMN_ORDER" />
|
||||
</State>
|
||||
</value>
|
||||
</entry>
|
||||
</map>
|
||||
</option>
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/AL_2020$decision_tree.coverage" NAME="main Coverage Results" MODIFIED="1589815629629" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/wozek$board.coverage" NAME="board Coverage Results" MODIFIED="1589210811600" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/AL_2020$main.coverage" NAME="main Coverage Results" MODIFIED="1589564478428" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/AL_2020$main.coverage" NAME="main Coverage Results" MODIFIED="1589729320403" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
<SUITE FILE_PATH="coverage/wozek$main.coverage" NAME="main Coverage Results" MODIFIED="1589556038208" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
||||
</component>
|
||||
</project>
|
File diff suppressed because it is too large
Load Diff
BIN
__pycache__/data.cpython-37.pyc
Normal file
BIN
__pycache__/data.cpython-37.pyc
Normal file
Binary file not shown.
BIN
__pycache__/decision_tree.cpython-37.pyc
Normal file
BIN
__pycache__/decision_tree.cpython-37.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
__pycache__/sweets.cpython-37.pyc
Normal file
BIN
__pycache__/sweets.cpython-37.pyc
Normal file
Binary file not shown.
9
candy.py
Normal file
9
candy.py
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
|
||||
class Candy:
|
||||
def __init__(self, producent, type, price):
|
||||
|
||||
self.producent = producent
|
||||
self.type = type
|
||||
self.price = price
|
||||
|
34
data.py
34
data.py
@ -25,4 +25,36 @@ def createDataSweets():
|
||||
sweet = Sweets('Maoam', 'truskawkowy', 'guma', 'maly', 0.25)
|
||||
allProducts.append(sweet)
|
||||
|
||||
return allProducts
|
||||
return allProducts
|
||||
|
||||
|
||||
learning_data = [
|
||||
# kolor, kształt, waga, rozmiar, nazwa
|
||||
['black', 'rectangle', 51, 'small', 'Mars'],
|
||||
['gold', 'pack', 100, 'big', 'Haribo'],
|
||||
['purple', 'rectangle', 100, 'big', 'Milka'],
|
||||
['brown', 'pack', 45, 'small', 'M&M'],
|
||||
['blue', 'rectangle', 50, 'medium', 'Bounty'],
|
||||
['blue', 'square', 40, 'small', 'Knoppers'],
|
||||
['blue', 'rectangle', 35, 'small', 'Milky-way'],
|
||||
['gold', 'rectangle', 40, 'medium', 'Twix'],
|
||||
['gold', 'rectangle', 50, 'medium', 'Prince-polo'],
|
||||
['brown', 'rectangle', 55, 'medium', 'Snickers'],
|
||||
['brown', 'rectangle', 45, 'medium', 'Lion'],
|
||||
['white', 'rectangle', 40, 'medium', 'Kinder-bueno'],
|
||||
['red', 'rectangle', 50, 'medium', 'Kit-kat'],
|
||||
['blue', 'rectangle', 115, 'big', 'Wedel'],
|
||||
['white', 'rectangle', 15, 'small', 'Krowka'],
|
||||
['red', 'pack', 70, 'medium', 'Skittles'],
|
||||
['orange', 'rectangle', 45, 'medium', 'Reeses'],
|
||||
['blue', 'rectangle', 55, 'medium', 'Oreo'],
|
||||
['gold', 'rectangle', 120, 'big', 'Ferrero-rocher'],
|
||||
['white', 'rectangle', 120, 'big', 'Rafaello'],
|
||||
['white', 'jar', 600, 'big', 'Nutella'],
|
||||
['white', 'rectangle', 25, 'small', 'Duplo'],
|
||||
['brown', 'jar', 500, 'big', 'GoOn'],
|
||||
['brown', 'jar', 470, 'big', 'Active Orzechowe'],
|
||||
['red', 'jar', 250, 'medium', 'Strawberry Jam'],
|
||||
['black', 'jar', 250, 'medium', 'Blackberry Jam'],
|
||||
['orange', 'jar', 250, 'medium', 'Peach Jam'],
|
||||
]
|
184
decision_tree.py
Normal file
184
decision_tree.py
Normal file
@ -0,0 +1,184 @@
|
||||
import data
|
||||
|
||||
training_data = data.learning_data
|
||||
|
||||
header = ['color', 'shape', 'weight', 'size', 'name']
|
||||
|
||||
|
||||
# funkcja która zwraca listę unikalnych wartości z każdej kolumny
|
||||
def uniqie_vals(rows, col):
|
||||
return set([row[col] for row in rows])
|
||||
|
||||
|
||||
# zliczamy liczbę wystąpień danego typu w zestawie danych
|
||||
def class_counts(rows):
|
||||
counts = {} # label -> count
|
||||
|
||||
for row in rows:
|
||||
name = row[-1]
|
||||
if name not in counts:
|
||||
counts[name] = 0
|
||||
counts[name] += 1
|
||||
return counts
|
||||
|
||||
|
||||
# funkcja do sprawdzania czy wartość jest wartością numeryczną
|
||||
def is_numeric(val):
|
||||
return isinstance(val, int) or isinstance(val, float)
|
||||
|
||||
|
||||
# klasa do zadawania pytań
|
||||
class Question:
|
||||
def __init__(self, column, value):
|
||||
self.column = column
|
||||
self.value = value
|
||||
|
||||
def match(self, example):
|
||||
val = example[self.column]
|
||||
|
||||
if is_numeric(val):
|
||||
return val >= self.value
|
||||
else:
|
||||
return val == self.value
|
||||
|
||||
def __repr__(self):
|
||||
condition = '=='
|
||||
if is_numeric(self.value):
|
||||
condition = '>='
|
||||
return "Is %s %s %s?" % (header[self.column], condition, str(self.value))
|
||||
|
||||
|
||||
def partition(rows, question):
|
||||
""" podział zbioru informacji
|
||||
dla każdego rzędu w zbiorze, sprawdź czy zgadza się z pytaniem, jeśli tak
|
||||
dodaj do 'true' inaczej dodaj do 'false' """
|
||||
true_rows, false_rows = [], []
|
||||
for row in rows:
|
||||
if question.match(row):
|
||||
true_rows.append(row)
|
||||
else:
|
||||
false_rows.append(row)
|
||||
return true_rows, false_rows
|
||||
|
||||
|
||||
def gini(rows):
|
||||
""" Gini impurity is a measure of how often a randomly chosen element from
|
||||
the set would be incorrectly labeled if it was randomly labeled according to
|
||||
the distribution of labels in the subset. """
|
||||
|
||||
counts = class_counts(rows)
|
||||
impurity = 1
|
||||
for lbl in counts:
|
||||
prob_of_lbl = counts[lbl] / float(len(rows))
|
||||
impurity -= prob_of_lbl ** 2
|
||||
return impurity
|
||||
|
||||
|
||||
def info_gain(left, right, current_uncertainty):
|
||||
p = float(len(left)) / (len(left) + len(right))
|
||||
return current_uncertainty - p * gini(left) - (1 - p) * gini(right)
|
||||
|
||||
|
||||
def find_best_split(rows):
|
||||
""" znajdź najlepsze możliwe pytanie do zadania, sprawdzając wszystkie
|
||||
właściwośći oraz licząc dla nich 'info_gain' """
|
||||
best_gain = 0
|
||||
best_question = None
|
||||
current_uncertainty = gini(rows)
|
||||
n_features = len(rows[0]) - 1
|
||||
|
||||
for col in range(n_features):
|
||||
values = set([row[col] for row in rows])
|
||||
|
||||
for val in values:
|
||||
question = Question(col, val)
|
||||
|
||||
true_rows, false_rows = partition(rows, question)
|
||||
|
||||
if len(true_rows) == 0 or len(false_rows) == 0:
|
||||
continue
|
||||
|
||||
gain = info_gain(true_rows, false_rows, current_uncertainty)
|
||||
|
||||
if gain > best_gain:
|
||||
best_gain, best_question = gain, question
|
||||
|
||||
return best_gain, best_question
|
||||
|
||||
|
||||
class Leaf:
|
||||
def __init__(self, rows):
|
||||
self.predicions = class_counts(rows)
|
||||
|
||||
|
||||
class DecisionNode:
|
||||
def __init__(self, question, true_branch, false_branch):
|
||||
self.question = question
|
||||
self.true_branch = true_branch
|
||||
self.false_branch = false_branch
|
||||
|
||||
|
||||
def build_tree(rows):
|
||||
gain, question = find_best_split(rows)
|
||||
|
||||
if gain == 0:
|
||||
return Leaf(rows)
|
||||
|
||||
true_rows, false_rows = partition(rows, question)
|
||||
|
||||
true_branch = build_tree(true_rows)
|
||||
|
||||
false_branch = build_tree(false_rows)
|
||||
|
||||
return DecisionNode(question, true_branch, false_branch)
|
||||
|
||||
|
||||
def print_tree(node, spacing=""):
|
||||
if isinstance(node, Leaf):
|
||||
print(spacing + "Predict", node.predicions)
|
||||
|
||||
else:
|
||||
print(spacing + str(node.question))
|
||||
|
||||
print(spacing + '--> True:')
|
||||
print_tree(node.true_branch, spacing + " ")
|
||||
|
||||
print(spacing + '--> False:')
|
||||
print_tree(node.false_branch, spacing + " ")
|
||||
|
||||
|
||||
def classify(row, node):
|
||||
if isinstance(node, Leaf):
|
||||
return node.predicions
|
||||
|
||||
if node.question.match(row):
|
||||
return classify(row, node.true_branch)
|
||||
else:
|
||||
return classify(row, node.false_branch)
|
||||
|
||||
|
||||
def print_leaf(counts):
|
||||
probs = []
|
||||
for lbl in counts.keys():
|
||||
probs.append(lbl)
|
||||
return probs
|
||||
|
||||
|
||||
# my_tree = build_tree(training_data)
|
||||
#
|
||||
# print_tree(my_tree)
|
||||
#
|
||||
# testing_data = [
|
||||
# ['gold', 'rectangle', 50, 'medium', 'Name'],
|
||||
# ['brown', 'rectangle', 55, 'medium', 'Snickers'],
|
||||
# ['white', 'rectangle', 120, 'big', 'Name']
|
||||
# ]
|
||||
#
|
||||
# test = ['white', 'rectangle', 120, 'big', 'Name']
|
||||
#
|
||||
# # for row in testing_data:
|
||||
# # print(print_leaf(classify(row, my_tree)))
|
||||
#
|
||||
# wynik = print_leaf(classify(test, my_tree))[0]
|
||||
# print(wynik)
|
||||
|
3
field.py
3
field.py
@ -21,6 +21,9 @@ class Field:
|
||||
self.f = 0
|
||||
self.previous = None
|
||||
|
||||
# Przedmiot, który podnosi agent
|
||||
self.item = []
|
||||
|
||||
# Te rzeczy są potrzebne do wyświetlenia pola
|
||||
self.image = pygame.image.load('img/Field.png')
|
||||
self.rect = self.image.get_rect()
|
||||
|
26
main.py
26
main.py
@ -2,11 +2,13 @@ import pygame
|
||||
import functions
|
||||
import sys
|
||||
import time
|
||||
import decision_tree
|
||||
import data
|
||||
|
||||
from agent import Agent
|
||||
from settings import Settings
|
||||
from board import create_board, draw_board
|
||||
from random import randint
|
||||
from random import randint, choice
|
||||
|
||||
|
||||
# Inicjalizacja programu i utworzenie obiektu ekrany
|
||||
@ -17,10 +19,11 @@ def run():
|
||||
pygame.display.set_caption("Inteligentny wózek widłowy")
|
||||
agent = Agent(screen, 50, 50, "Down")
|
||||
board = create_board(screen)
|
||||
my_tree = decision_tree.build_tree(data.learning_data)
|
||||
|
||||
for row in board:
|
||||
for field in row:
|
||||
print(field.cost_of_travel)
|
||||
# for row in board:
|
||||
# for field in row:
|
||||
# print(field.cost_of_travel)
|
||||
|
||||
path = []
|
||||
next_step = None
|
||||
@ -41,7 +44,10 @@ def run():
|
||||
agent.move_forward(board)
|
||||
print(agent.x, agent.y)
|
||||
elif event.key == pygame.K_SPACE:
|
||||
field = board[randint(0, 9)][randint(0, 9)]
|
||||
board[9][0].item = choice(data.learning_data)
|
||||
print("Wybrano: " + board[9][0].item[-1])
|
||||
board[9][0].item[-1] = 'Something'
|
||||
field = board[9][0]
|
||||
if not field.is_shelf:
|
||||
path = functions.a_star(board[agent.y][agent.x], field, board)
|
||||
path.pop(len(path) - 1)
|
||||
@ -61,12 +67,14 @@ def run():
|
||||
for field in row:
|
||||
if not field.is_shelf:
|
||||
field.image = pygame.image.load('img/Field.png')
|
||||
for row in board:
|
||||
for field in row:
|
||||
print(field.g, field.h, field.f, field.previous)
|
||||
|
||||
else:
|
||||
functions.change_turn(agent, next_step)
|
||||
print(agent.x, agent.y)
|
||||
|
||||
if board[agent.y][agent.x].item:
|
||||
prediction = decision_tree.print_leaf(decision_tree.classify(board[agent.y][agent.x].item, my_tree))
|
||||
print("Agent uważa, że przedmiot to: " + prediction[0])
|
||||
board[agent.y][agent.x].item = []
|
||||
|
||||
draw_board(board)
|
||||
agent.blitme()
|
||||
|
Loading…
Reference in New Issue
Block a user