Sztuczna_Inteligencja_2020/MarcinDobrowolski.md
2020-05-27 09:20:38 +02:00

13 KiB

Podprojekt indywidualny - raport

Temat: Sugerowanie potraw Metoda uczenia: Decision Tree(CART) Autor: Marcin Jerzy Dobrowolski Okres trwwania prac: 29 kwietnia - 27 maja 2020

Link do repozytorium projektu: https://git.wmi.amu.edu.pl/s444427/Sztuczna_Inteligencja_2020

Wstęp

Tematem realizowanego projektu jest stworzenie sztucznej inteligencji, która na podstawie podanych atrybutów stwierdzi jaka potrawa będzie najlepszym wyborem dla gościa restauaracji. Atrybuty odpowiadają na takie pytania jak:

  • czy gość jest wegetarianinem? - {meat, vege}
  • jak bardzo chce się najeść/jak wielka ma być potrawa? - {small, regular, large}
  • ile gotówki może przeznaczyć na ten cel? - {<10;20>, <20;30>, <30;40>, <40;50>, <50;60>}
  • jak jaki jest jego apetyt? - {<10;25>, <25;40>, <40;55>}

Powyższymi atrybutami rządzą pewne zależności. Apetyt gościa ma wpływ na to jaka potrawa będzie dla niego mała, średnia bądź duża.

Do osiągnięcia celu wykorzystałem metodę drzew decyzyjnych CART oraz następujące biblioteki języka python:

  • csv
  • random

Uczenie modelu

Zbiór uczący

Do zbudowania drzewa decyzyjnego wykorzystałem spreparowany przez siebie zbiór przykładów atrybutów gości z odpowiadającymi im daniami. Zawarte jest w nim 213 przykładów dotyczących 20 różnych potraw. Całość znajduje się w pliku: trainingData.csv

Moduły

Struktura danych reprezentujących preferencje gościa:

type, size, money, appetite, name
['meat', 'small', 31, 55, 'schabowy']

Opis modułów

W pliku utility znajdują się podręczne funkcje użytkowe wykorzystywane przez pozostałe moduły programu.

def uniqueValues(rows, column):
    return set([row[column] for row in rows])

Znajduje unikalne wartości wyznaczonej kolumny z podanego zbioru danych.

def classCount(rows):
    counts = {}
    for row in rows:
        label = row[-1]
        if label not in counts:
            counts[label] = 0
        counts[label] += 1
    return counts

Podlicza ilość elementów każdej etykiety z danego zbioru danych.

def is_numeric(value):
    return isinstance(value, int) or isinstance(value, float)

Sprawdza czy podana wartość jest liczbą.

def partition(rows, question):
    trueRows, falseRows = [], []
    for row in rows:
        if question.match(row):
            trueRows.append(row)
        else:
            falseRows.append(row)
    return trueRows, falseRows

Dzieli dany zbiór ze względu na zadane pytanie.

def generateTestData(path, quantity):
    with open('path', 'w') as csvFile:
        csvWriter = csv.writer(csvFile, delimiter=',')
        ...

Generuje zbiór danych testowych o zadanej wielkości i zapisuje go w pliku o podanej ścieżce. Ze względu na długość kodu nie zamieściłem go w całości. Dane są generowane przy użyciu liczb pseudo losowych w taki sposób by były zgodne z prawami rządzącymi tym zbiorem danych.

def generateTestExample():
    example = []
    category = random.randrange(0, 1)
    size = random.randrange(0, 2)
    if category == 0:
        example.append('meat')
    ...

Działa na takiej samej zasadzie jak jej poprzedniczka. Z tym wyjątkiem, że generuje pojedynczy przykład.

Klasa Question reprezentuje pytanie, które służy do podziału danych w wierzchołkach decyzyjnych drzewa. Jej atrybuty oznajczają kolejno: etykiete kolumny zbioru danych, indeks wspomnianej kolumny oraz wartość, której dotyczy pytanie.

class Question:
    def __init__(self, columnLabel, column, value):
        self.columnLabel = columnLabel
        self.column = column
        self.value = value

    def match(self, example):
        val = example[self.column]
        if is_numeric(val):
            return val <= self.value
        else:
            return val == self.value

Metoda match wskazuje jak na pytanie odpowiada zadany przykład.

Jądro algorytmu znajduje się w klasie suggestionTree w pilku suggestionDecisionTree. Metoda readTrainingData służy do odczytania zbioru uczącego z pliku trainingData.csv i zwrócenia przechowywanych w nim danych wraz z ich etykietami.

def readTrainingData(path):
        with open(path) as csv_file:
            csvReader = csv.reader(csv_file, delimiter=',')
            lineCount = 0
            trainingData = []
            labels = []

            for row in csvReader:
                example = []
                for column in row:
                    if lineCount == 0:
                        labels.append(column)
                    else:
                        if column.isdigit():
                            example.append(int(column))
                        else:
                            example.append(column)
                if lineCount > 0:
                    trainingData.append(example)
                lineCount += 1

            print('Processed lines: ', lineCount)
            return trainingData, labels

Jest wykorzystana w inicjalizacji zmiennej globalnej trainigData oraz labels, z których korzystają niektóre z metod tej klasy.

trainingData, labels = SuggestionTree.readTrainingData(
    'src/SubprojectMarcinDobrowolski/Data/trainingData.csv')

Metoda gini oblicza tzw. gini impurity. Jest to miara, która mówi jak duże jest prawdopodobieństwo by losowo wybrany element ze zbioru został błędnie oznaczony.

def gini(rows):
    counts = classCount(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl**2
    return impurity

def infoGain(left, right, currentUncertainty):
    p = float(len(left)) / (len(left) + len(right))
    return currentUncertainty - p * SuggestionTree.gini(left) - (1 - p) * SuggestionTree.gini(right)

Natomiast metoda infoGain oblicza w jakim stopniu zmniejsza się wskaźnik gini impurity ze względu na wybór lewego i prawego podziału zbioru.

Metoda findBestSplit znajduje pytanie, które w danej chwili spowoduje jak największy przyrost informacji.

def findBestSplit(rows):
        bestGain = 0
        bestQuestion = None
        currentUncertainty = SuggestionTree.gini(rows)
        nFeatures = len(labels) - 1

        for column in range(nFeatures):

            values = set([row[column] for row in rows])

            for value in values:

                question = Question(labels[column], column, value)

                trueRows, falseRows = partition(rows, question)

                if len(trueRows) == 0 or len(falseRows) == 0:
                    continue

                gain = SuggestionTree.infoGain(
                    trueRows, falseRows, currentUncertainty)

                if gain > bestGain:
                    bestGain, bestQuestion = gain, question

        return bestGain, bestQuestion

Klasa Leaf jest reprezentacją liścia drzewa czyli najbardziej zewnętrznego wierzchołka. W atrybucie predictions przechowuje oszacowanie etykiet/y, które może przyjąć przykład, który do niego trafi. Metoda printLeaf służy do eleganckiego wyświetlania oszacowania.

class Leaf:
    def __init__(self, rows):
        self.predictions = classCount(rows)

    def printLeaf(self):
        total = sum(self.predictions.values()) * 1.0
        probs = {}
        for label in self.predictions.keys():
            probs[label] = str(
                int(self.predictions[label] / total * 100)) + '%'
        return probs

Klasa DecisionNode reprezentuje wierzchołek drzewa, w którym następuje zadanie pytania oraz podział danych. Atrybuty trueBranch oraz falseBranch reprezentują odpowiednio lewego i prawego potomka.

class DecisionNode:
    def __init__(self,
                 question,
                 trueBranch,
                 falseBranch):
        self.question = question
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch

Metoda buildTree odpowiada za rekurencyjne skonstruowanie drzewa na bazie przekazanego zbioru danych. Zwraca korzeń drzewa.

def buildTree(rows):
        gain, question = SuggestionTree.findBestSplit(rows)

        if gain == 0:
            return Leaf(rows)

        trueRows, falseRows = partition(rows, question)

        trueBranch = SuggestionTree.buildTree(trueRows)
        falseBranch = SuggestionTree.buildTree(falseRows)

        return DecisionNode(question, trueBranch, falseBranch)

Metoda printTree wyświetla strukturę drzewa w postaci tekstu w wierszu poleceń.

def printTree(node, spacing=' '):

        if isinstance(node, Leaf):
            print(spacing + 'Predict', node.predictions)
            return

        print(spacing + str(node.question))

        print(spacing + '--> True:')
        SuggestionTree.printTree(node.trueBranch, spacing + ' ')

        print(spacing + '--> False:')
        SuggestionTree.printTree(node.falseBranch, spacing + ' ')

    def classify(row, node):

        if isinstance(node, Leaf):
            return node

        if node.question.match(row):
            return SuggestionTree.classify(row, node.trueBranch)
        else:
            return SuggestionTree.classify(row, node.falseBranch)

Metoda classify przyporządkowuje danemu przykładowi odpowiadającą mu etykiete w drzewie.

Działanie

Drzewo jest budowane na podstawie zbioru uczącego. Algorytm podejmuje decyzje o tym jak podzielić zbiór na podstawie przyrostu informacji wynikającego z potencjalnego podziału. Przerywa działanie gdy przyrost wynosi 0, co oznacza, że dotarł do liścia drzewa i dalszy podział nie jest możliwy. Kończąc swoje działanie zwraca korzeń klasy DecisionNode, który zawiera w atrybutach potomków> Umożliwia to rekurencyjne schodzenie po drzewie, które jest realizowane przez metodę classify.

Integracja z projektem

Deklaracja oraz inicjalizacja korzenia drzewa:

suggestionTreeRoot = SuggestionTree.buildTree(trainingData)

W projekcie można sprawdzić działanie podprojektu poprzez wciśnięcie, któregoś z klawiszy:

  • 0 - sprawdzenie poprawności algorytmu na losowo wygenerowanym zbiorze testowym.

      generateTestData('testData.csv', 100)
      testData = []
      with open('src/SubprojectMarcinDobrowolski/Data/testData.csv') as csv_file:
          csvReader = csv.reader(csv_file, delimiter=',')
          lineCount = 0
          for row in csvReader:
              example = []
              for column in row:
                  if column.isdigit():
                      example.append(int(column))
                  else:
                      example.append(column)
              if lineCount > 0:
                  testData.append(example)
              lineCount += 1
    
          print('Processed lines: ', lineCount)
    
      print('Test examples predictions:')
      for example in testData:
          print('{} - {}'.format(example, SuggestionTree.classify(
              example, suggestionTreeRoot).printLeaf()))
    
  • 1 - sprawdzenie poprawności algorytmu na losowo wygenerowanym przykładzie.

      example = generateTestExample()
      print('Test example prediction: ')
      print('{} - {}'.format(example, SuggestionTree.classify(
          example, suggestionTreeRoot).printLeaf()))
    
  • 2 - wyświetlenie struktury drzewa

      SuggestionTree.printTree(suggestionTreeRoot)
    

Rezultaty zostają wypisane w konsoli.