2
0
forked from s444420/AL-2020
AL-2020/Raporty/raport_444428.md

5.0 KiB

Wojciech Lukasik - drzewa decyzyjne, algorytm CART

Opis podprojektu

Podprojekt implementuje tworzenie drzewa decyzyjnego w oparciu o algorytm CART (Classification And Regression Tree), które pomaga Agentowi w rozpoznaniu słodyczy na podstawie ich cech fizycznych (kolor, kształt, masa, rozmiar).

Wszystkie funkcje oraz klasy wykorzystywane w tym podprojekcie znajdują się w pliku decision_tree.py, dane uczące znajdują się w pliku data.py w liście learning_data

Tworzenie drzewa decyzyjnego

Główną funkcją jest build_tree(rows), która jak wskazuje nazwa tworzy drzewo. Funkcja przyjmuje jako argument listę zawierającą zestaw danych, w tym przypadku będą to słodycze o różnych właściwościach.

def build_tree(rows):
    gain, question = find_best_split(rows)

    if gain == 0:
        return Leaf(rows)

    true_rows, false_rows = partition(rows, question)

    true_branch = build_tree(true_rows)

    false_branch = build_tree(false_rows)

    return DecisionNode(question, true_branch, false_branch)

Drzewo jest budowane w oparciu o najlepsze możlwe podziały (najbardziej korzystne 'pytanie', które można zadać). Zajmuje się tym funkcja

find_best_split(rows) która dla wszystkich właściwości przekazanego zestawu informacji wylicza dla nich 'zysk informacji'.

Jeżeli nie otrzymujemy żadnych informacji (gain == 0) to znaczy, że znajdujemy się w liściu drzewa.

def find_best_split(rows):
    """ znajdź najlepsze możliwe pytanie do zadania, sprawdzając wszystkie
        właściwośći oraz licząc dla nich 'info_gain' """
    best_gain = 0
    best_question = None
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1

    for col in range(n_features):
        values = set([row[col] for row in rows])

        for val in values:
            question = Question(col, val)

            true_rows, false_rows = partition(rows, question)

            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            gain = info_gain(true_rows, false_rows, current_uncertainty)

            if gain > best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

Zysk informacji z danego podziału otrzymujemy obliczając wartość 'Gini Impurity'. Jest to miara tego jak często losowo wybrany element zbioru byłby źle skategoryzowany, gdyby przypisać mu losową kategorię spośród wszystkich kategorii znajdujących się w danym zbiorze.

def gini(rows):
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl ** 2
    return impurity

class_counts(rows) to funkcja, która dla danego zestawu danych zwraca wszystkie unikalne klasy oraz liczbę ich wystąpień.

Dla przykładu, dla zestawu w którym wszystkie elementy podchodzą pod tę samą kategorię wartość Gini będzie równa zero, natomiast dla zbioru w którym znajdują się dwie kategorie wartość ta wyniesie 0,5.

Po znalezieniu najbardziej optymalnego pytania, algorytm dzieli zestaw na elementy, dla których pytanie jest prawdziwe (true_rows), oraz te dla których jest fałszywe (false_rows). Następnie wykonuje rekurencyjnie procedurę build_tree dla obu poddrzew tak długo aż nie dojdzie do liści.

Element o zadanym zestawie cech, zostaje odnaleziony w drzewie dzięki prostej procedurze

classify(row, node) 'row' to lista cech elementu, natomiast 'node' na początu jest korzeniem już zbudowanego drzewa.

Element jest odnaleziony dzięki rekurencyjnym porównaniom atrybutów elementu z pytaniami w kolejnych węzłach drzewa.

def classify(row, node):
    if isinstance(node, Leaf):
        return node.predicions

    if node.question.match(row):
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

Zestaw uczący

Zestaw budujący drzewo to lista zawierająca 27 przykładowych słodyczy. Ich atrybuty zapisane są w formacie ['kolor', 'kształt', 'masa', 'wielkość', 'nazwa']. Oczywiście przy wyszukiwaniu elementu w drzewie jego nazwa nie jest potrzebna ponieważ to jej szukamy. Przykładowe elementy z zestawu uczącego:

    ['black',  'rectangle', 51,  'small', 'Mars'],
    ['gold',   'pack',      100, 'big',   'Haribo'],
    ['purple', 'rectangle', 100, 'big',   'Milka'],
    ['brown',  'pack',      45,  'small', 'M&M'],

Implementacja w projekcie

Przy rozpoczęciu głównej pętli programu w pliku main.py drzewo my_tree zostaje zbudowane w oparciu o dane data.learning_data.

Gdy program już działa, po wciśnięciu spacji jeden ze słodyczy zostanie losowo wybrany z zestawu data.learning_data oraz umieszczony na polu board[9][0], a jego nazwa zostanie wypisana w konsoli. Następnie Agent przemieszcza się do punktu board[9][0] i rozpoczne procedurę wyszukiwania elementu w zbudowanym drzewie. Na końcu wypisze w konsoli nazwę produktu.