Sztuczna_Inteligencja_2020/MarcinDobrowolski.md
2020-05-27 09:20:38 +02:00

312 lines
13 KiB
Markdown

# Podprojekt indywidualny - raport
**Temat:** Sugerowanie potraw
**Metoda uczenia:** Decision Tree(CART)
**Autor:** Marcin Jerzy Dobrowolski
**Okres trwwania prac:** 29 kwietnia - 27 maja 2020
**Link do repozytorium projektu:** https://git.wmi.amu.edu.pl/s444427/Sztuczna_Inteligencja_2020
## Wstęp
Tematem realizowanego projektu jest stworzenie sztucznej inteligencji, która na podstawie podanych atrybutów stwierdzi jaka potrawa będzie najlepszym wyborem dla gościa restauaracji. Atrybuty odpowiadają na takie pytania jak:
* czy gość jest wegetarianinem? - {meat, vege}
* jak bardzo chce się najeść/jak wielka ma być potrawa? - {small, regular, large}
* ile gotówki może przeznaczyć na ten cel? - {<10;20>, <20;30>, <30;40>, <40;50>, <50;60>}
* jak jaki jest jego apetyt? - {<10;25>, <25;40>, <40;55>}
Powyższymi atrybutami rządzą pewne zależności. Apetyt gościa ma wpływ na to jaka potrawa będzie dla niego mała, średnia bądź duża.
Do osiągnięcia celu wykorzystałem metodę drzew decyzyjnych CART oraz następujące biblioteki języka python:
* csv
* random
## Uczenie modelu
### Zbiór uczący
Do zbudowania drzewa decyzyjnego wykorzystałem spreparowany przez siebie zbiór przykładów atrybutów gości z odpowiadającymi im daniami. Zawarte jest w nim 213 przykładów dotyczących 20 różnych potraw. Całość znajduje się w pliku: [trainingData.csv](src/SubprojectMarcinDobrowolski/Data/trainingData.csv)
### Moduły
* [suggestionDecisionTree.py](src/SubprojectMarcinDobrowolski/suggestionDecisionTree.py) - klasa drzew decyzyjnego
* [question.py](src/SubprojectMarcinDobrowolski/question.py) - klasa pytania
* [leaf.py](src/SubprojectMarcinDobrowolski/leaf.py) - klasa liścia
* [decisionNode.py](src/SubprojectMarcinDobrowolski/decisionNode.py) - klasa wierzchołka
* [utility.py](src/SubprojectMarcinDobrowolski/utility.py) - klasa użytkowa
* [testData.csv](src/SubprojectMarcinDobrowolski/Data/testData.csv) - zbiór testowy
Struktura danych reprezentujących preferencje gościa:
type, size, money, appetite, name
['meat', 'small', 31, 55, 'schabowy']
### Opis modułów
W pliku *utility* znajdują się podręczne funkcje użytkowe wykorzystywane przez pozostałe moduły programu.
def uniqueValues(rows, column):
return set([row[column] for row in rows])
Znajduje unikalne wartości wyznaczonej kolumny z podanego zbioru danych.
def classCount(rows):
counts = {}
for row in rows:
label = row[-1]
if label not in counts:
counts[label] = 0
counts[label] += 1
return counts
Podlicza ilość elementów każdej etykiety z danego zbioru danych.
def is_numeric(value):
return isinstance(value, int) or isinstance(value, float)
Sprawdza czy podana wartość jest liczbą.
def partition(rows, question):
trueRows, falseRows = [], []
for row in rows:
if question.match(row):
trueRows.append(row)
else:
falseRows.append(row)
return trueRows, falseRows
Dzieli dany zbiór ze względu na zadane pytanie.
def generateTestData(path, quantity):
with open('path', 'w') as csvFile:
csvWriter = csv.writer(csvFile, delimiter=',')
...
Generuje zbiór danych testowych o zadanej wielkości i zapisuje go w pliku o podanej ścieżce. Ze względu na długość kodu nie zamieściłem go w całości. Dane są generowane przy użyciu liczb pseudo losowych w taki sposób by były zgodne z prawami rządzącymi tym zbiorem danych.
def generateTestExample():
example = []
category = random.randrange(0, 1)
size = random.randrange(0, 2)
if category == 0:
example.append('meat')
...
Działa na takiej samej zasadzie jak jej poprzedniczka. Z tym wyjątkiem, że generuje pojedynczy przykład.
Klasa *Question* reprezentuje pytanie, które służy do podziału danych w wierzchołkach decyzyjnych drzewa. Jej atrybuty oznajczają kolejno: etykiete kolumny zbioru danych, indeks wspomnianej kolumny oraz wartość, której dotyczy pytanie.
class Question:
def __init__(self, columnLabel, column, value):
self.columnLabel = columnLabel
self.column = column
self.value = value
def match(self, example):
val = example[self.column]
if is_numeric(val):
return val <= self.value
else:
return val == self.value
Metoda match wskazuje jak na pytanie odpowiada zadany przykład.
Jądro algorytmu znajduje się w klasie *suggestionTree* w pilku *suggestionDecisionTree*. Metoda *readTrainingData* służy do odczytania zbioru uczącego z pliku *trainingData.csv* i zwrócenia przechowywanych w nim danych wraz z ich etykietami.
def readTrainingData(path):
with open(path) as csv_file:
csvReader = csv.reader(csv_file, delimiter=',')
lineCount = 0
trainingData = []
labels = []
for row in csvReader:
example = []
for column in row:
if lineCount == 0:
labels.append(column)
else:
if column.isdigit():
example.append(int(column))
else:
example.append(column)
if lineCount > 0:
trainingData.append(example)
lineCount += 1
print('Processed lines: ', lineCount)
return trainingData, labels
Jest wykorzystana w inicjalizacji zmiennej globalnej *trainigData* oraz *labels*, z których korzystają niektóre z metod tej klasy.
```
trainingData, labels = SuggestionTree.readTrainingData(
'src/SubprojectMarcinDobrowolski/Data/trainingData.csv')
```
Metoda *gini* oblicza tzw. *gini impurity*. Jest to miara, która mówi jak duże jest prawdopodobieństwo by losowo wybrany element ze zbioru został błędnie oznaczony.
def gini(rows):
counts = classCount(rows)
impurity = 1
for lbl in counts:
prob_of_lbl = counts[lbl] / float(len(rows))
impurity -= prob_of_lbl**2
return impurity
def infoGain(left, right, currentUncertainty):
p = float(len(left)) / (len(left) + len(right))
return currentUncertainty - p * SuggestionTree.gini(left) - (1 - p) * SuggestionTree.gini(right)
Natomiast metoda *infoGain* oblicza w jakim stopniu zmniejsza się wskaźnik *gini impurity* ze względu na wybór lewego i prawego podziału zbioru.
Metoda *findBestSplit* znajduje pytanie, które w danej chwili spowoduje jak największy przyrost informacji.
def findBestSplit(rows):
bestGain = 0
bestQuestion = None
currentUncertainty = SuggestionTree.gini(rows)
nFeatures = len(labels) - 1
for column in range(nFeatures):
values = set([row[column] for row in rows])
for value in values:
question = Question(labels[column], column, value)
trueRows, falseRows = partition(rows, question)
if len(trueRows) == 0 or len(falseRows) == 0:
continue
gain = SuggestionTree.infoGain(
trueRows, falseRows, currentUncertainty)
if gain > bestGain:
bestGain, bestQuestion = gain, question
return bestGain, bestQuestion
Klasa *Leaf* jest reprezentacją liścia drzewa czyli najbardziej zewnętrznego wierzchołka. W atrybucie *predictions* przechowuje oszacowanie etykiet/y, które może przyjąć przykład, który do niego trafi. Metoda *printLeaf* służy do eleganckiego wyświetlania oszacowania.
class Leaf:
def __init__(self, rows):
self.predictions = classCount(rows)
def printLeaf(self):
total = sum(self.predictions.values()) * 1.0
probs = {}
for label in self.predictions.keys():
probs[label] = str(
int(self.predictions[label] / total * 100)) + '%'
return probs
Klasa *DecisionNode* reprezentuje wierzchołek drzewa, w którym następuje zadanie pytania oraz podział danych. Atrybuty *trueBranch* oraz *falseBranch* reprezentują odpowiednio lewego i prawego potomka.
class DecisionNode:
def __init__(self,
question,
trueBranch,
falseBranch):
self.question = question
self.trueBranch = trueBranch
self.falseBranch = falseBranch
Metoda *buildTree* odpowiada za rekurencyjne skonstruowanie drzewa na bazie przekazanego zbioru danych. Zwraca korzeń drzewa.
def buildTree(rows):
gain, question = SuggestionTree.findBestSplit(rows)
if gain == 0:
return Leaf(rows)
trueRows, falseRows = partition(rows, question)
trueBranch = SuggestionTree.buildTree(trueRows)
falseBranch = SuggestionTree.buildTree(falseRows)
return DecisionNode(question, trueBranch, falseBranch)
Metoda *printTree* wyświetla strukturę drzewa w postaci tekstu w wierszu poleceń.
def printTree(node, spacing=' '):
if isinstance(node, Leaf):
print(spacing + 'Predict', node.predictions)
return
print(spacing + str(node.question))
print(spacing + '--> True:')
SuggestionTree.printTree(node.trueBranch, spacing + ' ')
print(spacing + '--> False:')
SuggestionTree.printTree(node.falseBranch, spacing + ' ')
def classify(row, node):
if isinstance(node, Leaf):
return node
if node.question.match(row):
return SuggestionTree.classify(row, node.trueBranch)
else:
return SuggestionTree.classify(row, node.falseBranch)
Metoda *classify* przyporządkowuje danemu przykładowi odpowiadającą mu etykiete w drzewie.
### Działanie
Drzewo jest budowane na podstawie zbioru uczącego. Algorytm podejmuje decyzje o tym jak podzielić zbiór na podstawie przyrostu informacji wynikającego z potencjalnego podziału. Przerywa działanie gdy przyrost wynosi *0*, co oznacza, że dotarł do liścia drzewa i dalszy podział nie jest możliwy. Kończąc swoje działanie zwraca korzeń klasy *DecisionNode*, który zawiera w atrybutach potomków> Umożliwia to rekurencyjne schodzenie po drzewie, które jest realizowane przez metodę *classify*.
### Integracja z projektem
Deklaracja oraz inicjalizacja korzenia drzewa:
suggestionTreeRoot = SuggestionTree.buildTree(trainingData)
W projekcie można sprawdzić działanie podprojektu poprzez wciśnięcie, któregoś z klawiszy:
* 0 - sprawdzenie poprawności algorytmu na losowo wygenerowanym zbiorze testowym.
generateTestData('testData.csv', 100)
testData = []
with open('src/SubprojectMarcinDobrowolski/Data/testData.csv') as csv_file:
csvReader = csv.reader(csv_file, delimiter=',')
lineCount = 0
for row in csvReader:
example = []
for column in row:
if column.isdigit():
example.append(int(column))
else:
example.append(column)
if lineCount > 0:
testData.append(example)
lineCount += 1
print('Processed lines: ', lineCount)
print('Test examples predictions:')
for example in testData:
print('{} - {}'.format(example, SuggestionTree.classify(
example, suggestionTreeRoot).printLeaf()))
* 1 - sprawdzenie poprawności algorytmu na losowo wygenerowanym przykładzie.
example = generateTestExample()
print('Test example prediction: ')
print('{} - {}'.format(example, SuggestionTree.classify(
example, suggestionTreeRoot).printLeaf()))
* 2 - wyświetlenie struktury drzewa
SuggestionTree.printTree(suggestionTreeRoot)
Rezultaty zostają wypisane w konsoli.