4.9 KiB
Wojciech Łukasik - drzewa decyzyjne, algorytm CART
Opis podprojektu
Podprojekt implementuje tworzenie drzewa decyzyjnego w oparciu o algorytm CART (Classification And Regression Tree), które pomaga Agentowi w rozpoznaniu słodyczy na podstawie ich cech fizycznych (kolor, kształt, masa, rozmiar).
Wszystkie funkcje oraz klasy wykorzystywane w tym podprojekcie znajdują się w pliku decision_tree.py, dane uczące znajdują się w pliku data.py w liście learning_data
Tworzenie drzewa decyzyjnego
Główną funkcją jest
build_tree(rows)
która jak wskazuje nazwa tworzy drzewo.
Funkcja przyjmuje jako argument listę zawierającą zestaw danych, w tym przypadku będą to słodycze o różnych właściwościach.
def build_tree(rows):
gain, question = find_best_split(rows)
if gain == 0:
return Leaf(rows)
true_rows, false_rows = partition(rows, question)
true_branch = build_tree(true_rows)
false_branch = build_tree(false_rows)
return DecisionNode(question, true_branch, false_branch)
Drzewo jest budowane w oparciu o najlepsze możlwe podziały (najbardziej korzystne 'pytanie', które można zadać). Zajmuje się tym funkcja
find_best_split(rows)
która dla wszystkich właściwości przekazanego zestawu informacji
wylicza dla nich 'zysk informacji'.
Jeżeli nie otrzymujemy żadnych informacji (gain == 0) to znaczy, że znajdujemy się w liściu drzewa.
def find_best_split(rows):
best_gain = 0
best_question = None
current_uncertainty = gini(rows)
n_features = len(rows[0]) - 1
for col in range(n_features):
values = set([row[col] for row in rows])
for val in values:
question = Question(col, val)
true_rows, false_rows = partition(rows, question)
if len(true_rows) == 0 or len(false_rows) == 0:
continue
gain = info_gain(true_rows, false_rows, current_uncertainty)
if gain > best_gain:
best_gain, best_question = gain, question
return best_gain, best_question
Zysk informacji z danego podziału otrzymujemy obliczając wartość 'Gini Impurity'. Jest to miara tego jak często losowo wybrany element zbioru byłby źle skategoryzowany, gdyby przypisać mu losową kategorię spośród wszystkich kategorii znajdujących się w danym zbiorze.
def gini(rows):
counts = class_counts(rows)
impurity = 1
for lbl in counts:
prob_of_lbl = counts[lbl] / float(len(rows))
impurity -= prob_of_lbl ** 2
return impurity
class_counts(rows)
to funkcja, która dla danego zestawu danych zwraca wszystkie unikalne 'kategorie' oraz liczbę ich wystąpień.
Dla przykładu, w zestawie w którym wszystkie elementy podchodzą pod tę samą kategorię wartość Gini Impurity będzie równa zero, natomiast w zbiorze w którym znajdują się dwie kategorie wartość ta wyniesie 0,5. Im więcej różnych kategorii tym bardziej wartość Gini Impurity będzie zbliżała się do 1.
Po znalezieniu najbardziej optymalnego pytania, algorytm dzieli zestaw na elementy, dla których pytanie jest prawdziwe
(true_rows), oraz te dla których jest fałszywe (false_rows). Następnie wykonuje rekurencyjnie procedurę build_tree
dla obu poddrzew tak długo aż nie dojdzie do liści.
Element o zadanym zestawie cech, zostaje odnaleziony w drzewie dzięki prostej procedurze
classify(row, node)
'row' to lista cech elementu, natomiast 'node' na początu jest korzeniem już zbudowanego drzewa.
Element jest odnaleziony dzięki rekurencyjnym porównaniom atrybutów elementu z pytaniami w kolejnych węzłach drzewa.
def classify(row, node):
if isinstance(node, Leaf):
return node.predicions
if node.question.match(row):
return classify(row, node.true_branch)
else:
return classify(row, node.false_branch)
Zestaw uczący
Zestaw budujący drzewo to lista zawierająca 27 przykładowych słodyczy. Ich atrybuty zapisane są w formacie ['kolor', 'kształt', 'masa', 'wielkość', 'nazwa']. Oczywiście przy wyszukiwaniu elementu w drzewie jego nazwa nie jest potrzebna ponieważ to jej szukamy. Przykładowe elementy z zestawu uczącego:
['black', 'rectangle', 51, 'small', 'Mars'],
['gold', 'pack', 100, 'big', 'Haribo'],
['purple', 'rectangle', 100, 'big', 'Milka'],
['brown', 'pack', 45, 'small', 'M&M'],
Implementacja w projekcie
Przy rozpoczęciu głównej pętli programu w pliku main.py
drzewo my_tree
zostaje zbudowane w oparciu o dane
data.learning_data
.
Gdy program już działa, po wciśnięciu spacji
jeden ze słodyczy zostanie losowo wybrany z zestawu data.learning_data
oraz umieszczony na polu board[9][0]
, a jego nazwa zostanie wypisana w konsoli. Następnie Agent przemieszcza się do
punktu board[9][0]
i rozpoczne procedurę wyszukiwania elementu w zbudowanym drzewie. Na końcu wypisze w
konsoli nazwę produktu.