AL-2020/Raporty/raport_444428.md

4.9 KiB

Wojciech Łukasik - drzewa decyzyjne, algorytm CART

Opis podprojektu

Podprojekt implementuje tworzenie drzewa decyzyjnego w oparciu o algorytm CART (Classification And Regression Tree), które pomaga Agentowi w rozpoznaniu słodyczy na podstawie ich cech fizycznych (kolor, kształt, masa, rozmiar).

Wszystkie funkcje oraz klasy wykorzystywane w tym podprojekcie znajdują się w pliku decision_tree.py, dane uczące znajdują się w pliku data.py w liście learning_data

Tworzenie drzewa decyzyjnego

Główną funkcją jest

build_tree(rows) która jak wskazuje nazwa tworzy drzewo.

Funkcja przyjmuje jako argument listę zawierającą zestaw danych, w tym przypadku będą to słodycze o różnych właściwościach.

def build_tree(rows):
    gain, question = find_best_split(rows)

    if gain == 0:
        return Leaf(rows)

    true_rows, false_rows = partition(rows, question)

    true_branch = build_tree(true_rows)

    false_branch = build_tree(false_rows)

    return DecisionNode(question, true_branch, false_branch)

Drzewo jest budowane w oparciu o najlepsze możlwe podziały (najbardziej korzystne 'pytanie', które można zadać). Zajmuje się tym funkcja

find_best_split(rows) która dla wszystkich właściwości przekazanego zestawu informacji wylicza dla nich 'zysk informacji'.

Jeżeli nie otrzymujemy żadnych informacji (gain == 0) to znaczy, że znajdujemy się w liściu drzewa.

def find_best_split(rows):
    best_gain = 0
    best_question = None
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1

    for col in range(n_features):
        values = set([row[col] for row in rows])

        for val in values:
            question = Question(col, val)

            true_rows, false_rows = partition(rows, question)

            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            gain = info_gain(true_rows, false_rows, current_uncertainty)

            if gain > best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

Zysk informacji z danego podziału otrzymujemy obliczając wartość 'Gini Impurity'. Jest to miara tego jak często losowo wybrany element zbioru byłby źle skategoryzowany, gdyby przypisać mu losową kategorię spośród wszystkich kategorii znajdujących się w danym zbiorze.

def gini(rows):
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl ** 2
    return impurity

class_counts(rows) to funkcja, która dla danego zestawu danych zwraca wszystkie unikalne 'kategorie' oraz liczbę ich wystąpień.

Dla przykładu, w zestawie w którym wszystkie elementy podchodzą pod tę samą kategorię wartość Gini Impurity będzie równa zero, natomiast w zbiorze w którym znajdują się dwie kategorie wartość ta wyniesie 0,5. Im więcej różnych kategorii tym bardziej wartość Gini Impurity będzie zbliżała się do 1.

Po znalezieniu najbardziej optymalnego pytania, algorytm dzieli zestaw na elementy, dla których pytanie jest prawdziwe (true_rows), oraz te dla których jest fałszywe (false_rows). Następnie wykonuje rekurencyjnie procedurę build_tree dla obu poddrzew tak długo aż nie dojdzie do liści.

Element o zadanym zestawie cech, zostaje odnaleziony w drzewie dzięki prostej procedurze

classify(row, node) 'row' to lista cech elementu, natomiast 'node' na początu jest korzeniem już zbudowanego drzewa.

Element jest odnaleziony dzięki rekurencyjnym porównaniom atrybutów elementu z pytaniami w kolejnych węzłach drzewa.

def classify(row, node):
    if isinstance(node, Leaf):
        return node.predicions

    if node.question.match(row):
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

Zestaw uczący

Zestaw budujący drzewo to lista zawierająca 27 przykładowych słodyczy. Ich atrybuty zapisane są w formacie ['kolor', 'kształt', 'masa', 'wielkość', 'nazwa']. Oczywiście przy wyszukiwaniu elementu w drzewie jego nazwa nie jest potrzebna ponieważ to jej szukamy. Przykładowe elementy z zestawu uczącego:

    ['black',  'rectangle', 51,  'small', 'Mars'],
    ['gold',   'pack',      100, 'big',   'Haribo'],
    ['purple', 'rectangle', 100, 'big',   'Milka'],
    ['brown',  'pack',      45,  'small', 'M&M'],

Implementacja w projekcie

Przy rozpoczęciu głównej pętli programu w pliku main.py drzewo my_tree zostaje zbudowane w oparciu o dane data.learning_data.

Gdy program już działa, po wciśnięciu spacji jeden ze słodyczy zostanie losowo wybrany z zestawu data.learning_data oraz umieszczony na polu board[9][0], a jego nazwa zostanie wypisana w konsoli. Następnie Agent przemieszcza się do punktu board[9][0] i rozpoczne procedurę wyszukiwania elementu w zbudowanym drzewie. Na końcu wypisze w konsoli nazwę produktu.