2
0
forked from s444420/AL-2020
AL-2020/Raporty/raport_444428.md

132 lines
4.9 KiB
Markdown

# Wojciech Łukasik - drzewa decyzyjne, algorytm CART
### Opis podprojektu
Podprojekt implementuje tworzenie drzewa decyzyjnego w oparciu o algorytm CART
(Classification And Regression Tree), które pomaga Agentowi w rozpoznaniu słodyczy na podstawie
ich cech fizycznych (kolor, kształt, masa, rozmiar).
Wszystkie funkcje oraz klasy wykorzystywane w tym podprojekcie znajdują się w pliku decision_tree.py,
dane uczące znajdują się w pliku data.py w liście learning_data
### Tworzenie drzewa decyzyjnego
Główną funkcją jest
`build_tree(rows)` która jak wskazuje nazwa tworzy drzewo.
Funkcja przyjmuje jako argument listę zawierającą zestaw danych, w tym przypadku będą to słodycze o różnych właściwościach.
```python
def build_tree(rows):
gain, question = find_best_split(rows)
if gain == 0:
return Leaf(rows)
true_rows, false_rows = partition(rows, question)
true_branch = build_tree(true_rows)
false_branch = build_tree(false_rows)
return DecisionNode(question, true_branch, false_branch)
```
Drzewo jest budowane w oparciu o najlepsze możlwe podziały (najbardziej korzystne 'pytanie', które można zadać).
Zajmuje się tym funkcja
`find_best_split(rows)` która dla wszystkich właściwości przekazanego zestawu informacji
wylicza dla nich 'zysk informacji'.
Jeżeli nie otrzymujemy żadnych informacji (gain == 0) to znaczy, że znajdujemy
się w liściu drzewa.
```python
def find_best_split(rows):
best_gain = 0
best_question = None
current_uncertainty = gini(rows)
n_features = len(rows[0]) - 1
for col in range(n_features):
values = set([row[col] for row in rows])
for val in values:
question = Question(col, val)
true_rows, false_rows = partition(rows, question)
if len(true_rows) == 0 or len(false_rows) == 0:
continue
gain = info_gain(true_rows, false_rows, current_uncertainty)
if gain > best_gain:
best_gain, best_question = gain, question
return best_gain, best_question
```
Zysk informacji z danego podziału otrzymujemy obliczając wartość 'Gini Impurity'. Jest to miara tego jak często losowo
wybrany element zbioru byłby źle skategoryzowany, gdyby przypisać mu losową kategorię spośród wszystkich kategorii
znajdujących się w danym zbiorze.
```python
def gini(rows):
counts = class_counts(rows)
impurity = 1
for lbl in counts:
prob_of_lbl = counts[lbl] / float(len(rows))
impurity -= prob_of_lbl ** 2
return impurity
```
`class_counts(rows)` to funkcja, która dla danego zestawu danych zwraca wszystkie unikalne 'kategorie' oraz liczbę ich wystąpień.
Dla przykładu, w zestawie w którym wszystkie elementy podchodzą pod tę samą kategorię wartość Gini Impurity będzie równa zero, natomiast w zbiorze w którym znajdują się dwie kategorie wartość ta wyniesie 0,5. Im więcej różnych kategorii tym bardziej wartość Gini Impurity będzie zbliżała się do 1.
Po znalezieniu najbardziej optymalnego pytania, algorytm dzieli zestaw na elementy, dla których pytanie jest prawdziwe
(true_rows), oraz te dla których jest fałszywe (false_rows). Następnie wykonuje rekurencyjnie procedurę `build_tree` dla obu poddrzew tak długo aż nie dojdzie do liści.
Element o zadanym zestawie cech, zostaje odnaleziony w drzewie dzięki prostej procedurze
`classify(row, node)` 'row' to lista cech elementu, natomiast 'node' na początu jest korzeniem już zbudowanego drzewa.
Element jest odnaleziony dzięki
rekurencyjnym porównaniom atrybutów elementu z pytaniami w kolejnych węzłach drzewa.
```python
def classify(row, node):
if isinstance(node, Leaf):
return node.predicions
if node.question.match(row):
return classify(row, node.true_branch)
else:
return classify(row, node.false_branch)
```
### Zestaw uczący
Zestaw budujący drzewo to lista zawierająca 27 przykładowych słodyczy. Ich atrybuty zapisane są w formacie ['kolor',
'kształt', 'masa', 'wielkość', 'nazwa']. Oczywiście przy wyszukiwaniu elementu w drzewie jego nazwa nie jest potrzebna
ponieważ to jej szukamy. Przykładowe elementy z zestawu uczącego:
```python
['black', 'rectangle', 51, 'small', 'Mars'],
['gold', 'pack', 100, 'big', 'Haribo'],
['purple', 'rectangle', 100, 'big', 'Milka'],
['brown', 'pack', 45, 'small', 'M&M'],
```
### Implementacja w projekcie
Przy rozpoczęciu głównej pętli programu w pliku `main.py` drzewo `my_tree` zostaje zbudowane w oparciu o dane
`data.learning_data`.
Gdy program już działa, po wciśnięciu `spacji` jeden ze słodyczy zostanie losowo wybrany z zestawu `data.learning_data`
oraz umieszczony na polu `board[9][0]`, a jego nazwa zostanie wypisana w konsoli. Następnie Agent przemieszcza się do
punktu `board[9][0]` i rozpoczne procedurę wyszukiwania elementu w zbudowanym drzewie. Na końcu wypisze w
konsoli nazwę produktu.