Marcin Dobrowolski razport z podprojketu
This commit is contained in:
parent
68104e25c6
commit
d89e37e6eb
312
MarcinDobrowolski.md
Normal file
312
MarcinDobrowolski.md
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
# Podprojekt indywidualny - raport
|
||||||
|
|
||||||
|
|
||||||
|
**Temat:** Sugerowanie potraw
|
||||||
|
**Metoda uczenia:** Decision Tree(CART)
|
||||||
|
**Autor:** Marcin Jerzy Dobrowolski
|
||||||
|
**Okres trwwania prac:** 29 kwietnia - 27 maja 2020
|
||||||
|
|
||||||
|
**Link do repozytorium projektu:** https://git.wmi.amu.edu.pl/s444427/Sztuczna_Inteligencja_2020
|
||||||
|
|
||||||
|
## Wstęp
|
||||||
|
|
||||||
|
Tematem realizowanego projektu jest stworzenie sztucznej inteligencji, która na podstawie podanych atrybutów stwierdzi jaka potrawa będzie najlepszym wyborem dla gościa restauaracji. Atrybuty odpowiadają na takie pytania jak:
|
||||||
|
|
||||||
|
* czy gość jest wegetarianinem? - {meat, vege}
|
||||||
|
* jak bardzo chce się najeść/jak wielka ma być potrawa? - {small, regular, large}
|
||||||
|
* ile gotówki może przeznaczyć na ten cel? - {<10;20>, <20;30>, <30;40>, <40;50>, <50;60>}
|
||||||
|
* jak jaki jest jego apetyt? - {<10;25>, <25;40>, <40;55>}
|
||||||
|
|
||||||
|
Powyższymi atrybutami rządzą pewne zależności. Apetyt gościa ma wpływ na to jaka potrawa będzie dla niego mała, średnia bądź duża.
|
||||||
|
|
||||||
|
Do osiągnięcia celu wykorzystałem metodę drzew decyzyjnych CART oraz następujące biblioteki języka python:
|
||||||
|
|
||||||
|
* csv
|
||||||
|
* random
|
||||||
|
|
||||||
|
## Uczenie modelu
|
||||||
|
|
||||||
|
### Zbiór uczący
|
||||||
|
|
||||||
|
Do zbudowania drzewa decyzyjnego wykorzystałem spreparowany przez siebie zbiór przykładów atrybutów gości z odpowiadającymi im daniami. Zawarte jest w nim 213 przykładów dotyczących 20 różnych potraw. Całość znajduje się w pliku: [trainingData.csv](src/SubprojectMarcinDobrowolski/Data/trainingData.csv)
|
||||||
|
|
||||||
|
### Moduły
|
||||||
|
|
||||||
|
* [suggestionDecisionTree.py](src/SubprojectMarcinDobrowolski/suggestionDecisionTree.py) - klasa drzew decyzyjnego
|
||||||
|
* [question.py](src/SubprojectMarcinDobrowolski/question.py) - klasa pytania
|
||||||
|
* [leaf.py](src/SubprojectMarcinDobrowolski/leaf.py) - klasa liścia
|
||||||
|
* [decisionNode.py](src/SubprojectMarcinDobrowolski/decisionNode.py) - klasa wierzchołka
|
||||||
|
* [utility.py](src/SubprojectMarcinDobrowolski/utility.py) - klasa użytkowa
|
||||||
|
* [testData.csv](src/SubprojectMarcinDobrowolski/Data/testData.csv) - zbiór testowy
|
||||||
|
|
||||||
|
Struktura danych reprezentujących preferencje gościa:
|
||||||
|
|
||||||
|
type, size, money, appetite, name
|
||||||
|
['meat', 'small', 31, 55, 'schabowy']
|
||||||
|
|
||||||
|
### Opis modułów
|
||||||
|
|
||||||
|
W pliku *utility* znajdują się podręczne funkcje użytkowe wykorzystywane przez pozostałe moduły programu.
|
||||||
|
|
||||||
|
def uniqueValues(rows, column):
|
||||||
|
return set([row[column] for row in rows])
|
||||||
|
|
||||||
|
Znajduje unikalne wartości wyznaczonej kolumny z podanego zbioru danych.
|
||||||
|
|
||||||
|
def classCount(rows):
|
||||||
|
counts = {}
|
||||||
|
for row in rows:
|
||||||
|
label = row[-1]
|
||||||
|
if label not in counts:
|
||||||
|
counts[label] = 0
|
||||||
|
counts[label] += 1
|
||||||
|
return counts
|
||||||
|
|
||||||
|
Podlicza ilość elementów każdej etykiety z danego zbioru danych.
|
||||||
|
|
||||||
|
def is_numeric(value):
|
||||||
|
return isinstance(value, int) or isinstance(value, float)
|
||||||
|
|
||||||
|
Sprawdza czy podana wartość jest liczbą.
|
||||||
|
|
||||||
|
def partition(rows, question):
|
||||||
|
trueRows, falseRows = [], []
|
||||||
|
for row in rows:
|
||||||
|
if question.match(row):
|
||||||
|
trueRows.append(row)
|
||||||
|
else:
|
||||||
|
falseRows.append(row)
|
||||||
|
return trueRows, falseRows
|
||||||
|
|
||||||
|
Dzieli dany zbiór ze względu na zadane pytanie.
|
||||||
|
|
||||||
|
def generateTestData(path, quantity):
|
||||||
|
with open('path', 'w') as csvFile:
|
||||||
|
csvWriter = csv.writer(csvFile, delimiter=',')
|
||||||
|
...
|
||||||
|
|
||||||
|
Generuje zbiór danych testowych o zadanej wielkości i zapisuje go w pliku o podanej ścieżce. Ze względu na długość kodu nie zamieściłem go w całości. Dane są generowane przy użyciu liczb pseudo losowych w taki sposób by były zgodne z prawami rządzącymi tym zbiorem danych.
|
||||||
|
|
||||||
|
def generateTestExample():
|
||||||
|
example = []
|
||||||
|
category = random.randrange(0, 1)
|
||||||
|
size = random.randrange(0, 2)
|
||||||
|
if category == 0:
|
||||||
|
example.append('meat')
|
||||||
|
...
|
||||||
|
|
||||||
|
Działa na takiej samej zasadzie jak jej poprzedniczka. Z tym wyjątkiem, że generuje pojedynczy przykład.
|
||||||
|
|
||||||
|
Klasa *Question* reprezentuje pytanie, które służy do podziału danych w wierzchołkach decyzyjnych drzewa. Jej atrybuty oznajczają kolejno: etykiete kolumny zbioru danych, indeks wspomnianej kolumny oraz wartość, której dotyczy pytanie.
|
||||||
|
|
||||||
|
class Question:
|
||||||
|
def __init__(self, columnLabel, column, value):
|
||||||
|
self.columnLabel = columnLabel
|
||||||
|
self.column = column
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def match(self, example):
|
||||||
|
val = example[self.column]
|
||||||
|
if is_numeric(val):
|
||||||
|
return val <= self.value
|
||||||
|
else:
|
||||||
|
return val == self.value
|
||||||
|
|
||||||
|
Metoda match wskazuje jak na pytanie odpowiada zadany przykład.
|
||||||
|
|
||||||
|
Jądro algorytmu znajduje się w klasie *suggestionTree* w pilku *suggestionDecisionTree*. Metoda *readTrainingData* służy do odczytania zbioru uczącego z pliku *trainingData.csv* i zwrócenia przechowywanych w nim danych wraz z ich etykietami.
|
||||||
|
|
||||||
|
|
||||||
|
def readTrainingData(path):
|
||||||
|
with open(path) as csv_file:
|
||||||
|
csvReader = csv.reader(csv_file, delimiter=',')
|
||||||
|
lineCount = 0
|
||||||
|
trainingData = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for row in csvReader:
|
||||||
|
example = []
|
||||||
|
for column in row:
|
||||||
|
if lineCount == 0:
|
||||||
|
labels.append(column)
|
||||||
|
else:
|
||||||
|
if column.isdigit():
|
||||||
|
example.append(int(column))
|
||||||
|
else:
|
||||||
|
example.append(column)
|
||||||
|
if lineCount > 0:
|
||||||
|
trainingData.append(example)
|
||||||
|
lineCount += 1
|
||||||
|
|
||||||
|
print('Processed lines: ', lineCount)
|
||||||
|
return trainingData, labels
|
||||||
|
|
||||||
|
Jest wykorzystana w inicjalizacji zmiennej globalnej *trainigData* oraz *labels*, z których korzystają niektóre z metod tej klasy.
|
||||||
|
|
||||||
|
```
|
||||||
|
trainingData, labels = SuggestionTree.readTrainingData(
|
||||||
|
'src/SubprojectMarcinDobrowolski/Data/trainingData.csv')
|
||||||
|
```
|
||||||
|
|
||||||
|
Metoda *gini* oblicza tzw. *gini impurity*. Jest to miara, która mówi jak duże jest prawdopodobieństwo by losowo wybrany element ze zbioru został błędnie oznaczony.
|
||||||
|
|
||||||
|
def gini(rows):
|
||||||
|
counts = classCount(rows)
|
||||||
|
impurity = 1
|
||||||
|
for lbl in counts:
|
||||||
|
prob_of_lbl = counts[lbl] / float(len(rows))
|
||||||
|
impurity -= prob_of_lbl**2
|
||||||
|
return impurity
|
||||||
|
|
||||||
|
def infoGain(left, right, currentUncertainty):
|
||||||
|
p = float(len(left)) / (len(left) + len(right))
|
||||||
|
return currentUncertainty - p * SuggestionTree.gini(left) - (1 - p) * SuggestionTree.gini(right)
|
||||||
|
|
||||||
|
Natomiast metoda *infoGain* oblicza w jakim stopniu zmniejsza się wskaźnik *gini impurity* ze względu na wybór lewego i prawego podziału zbioru.
|
||||||
|
|
||||||
|
Metoda *findBestSplit* znajduje pytanie, które w danej chwili spowoduje jak największy przyrost informacji.
|
||||||
|
|
||||||
|
def findBestSplit(rows):
|
||||||
|
bestGain = 0
|
||||||
|
bestQuestion = None
|
||||||
|
currentUncertainty = SuggestionTree.gini(rows)
|
||||||
|
nFeatures = len(labels) - 1
|
||||||
|
|
||||||
|
for column in range(nFeatures):
|
||||||
|
|
||||||
|
values = set([row[column] for row in rows])
|
||||||
|
|
||||||
|
for value in values:
|
||||||
|
|
||||||
|
question = Question(labels[column], column, value)
|
||||||
|
|
||||||
|
trueRows, falseRows = partition(rows, question)
|
||||||
|
|
||||||
|
if len(trueRows) == 0 or len(falseRows) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
gain = SuggestionTree.infoGain(
|
||||||
|
trueRows, falseRows, currentUncertainty)
|
||||||
|
|
||||||
|
if gain > bestGain:
|
||||||
|
bestGain, bestQuestion = gain, question
|
||||||
|
|
||||||
|
return bestGain, bestQuestion
|
||||||
|
|
||||||
|
Klasa *Leaf* jest reprezentacją liścia drzewa czyli najbardziej zewnętrznego wierzchołka. W atrybucie *predictions* przechowuje oszacowanie etykiet/y, które może przyjąć przykład, który do niego trafi. Metoda *printLeaf* służy do eleganckiego wyświetlania oszacowania.
|
||||||
|
|
||||||
|
class Leaf:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self.predictions = classCount(rows)
|
||||||
|
|
||||||
|
def printLeaf(self):
|
||||||
|
total = sum(self.predictions.values()) * 1.0
|
||||||
|
probs = {}
|
||||||
|
for label in self.predictions.keys():
|
||||||
|
probs[label] = str(
|
||||||
|
int(self.predictions[label] / total * 100)) + '%'
|
||||||
|
return probs
|
||||||
|
|
||||||
|
Klasa *DecisionNode* reprezentuje wierzchołek drzewa, w którym następuje zadanie pytania oraz podział danych. Atrybuty *trueBranch* oraz *falseBranch* reprezentują odpowiednio lewego i prawego potomka.
|
||||||
|
|
||||||
|
class DecisionNode:
|
||||||
|
def __init__(self,
|
||||||
|
question,
|
||||||
|
trueBranch,
|
||||||
|
falseBranch):
|
||||||
|
self.question = question
|
||||||
|
self.trueBranch = trueBranch
|
||||||
|
self.falseBranch = falseBranch
|
||||||
|
|
||||||
|
Metoda *buildTree* odpowiada za rekurencyjne skonstruowanie drzewa na bazie przekazanego zbioru danych. Zwraca korzeń drzewa.
|
||||||
|
|
||||||
|
def buildTree(rows):
|
||||||
|
gain, question = SuggestionTree.findBestSplit(rows)
|
||||||
|
|
||||||
|
if gain == 0:
|
||||||
|
return Leaf(rows)
|
||||||
|
|
||||||
|
trueRows, falseRows = partition(rows, question)
|
||||||
|
|
||||||
|
trueBranch = SuggestionTree.buildTree(trueRows)
|
||||||
|
falseBranch = SuggestionTree.buildTree(falseRows)
|
||||||
|
|
||||||
|
return DecisionNode(question, trueBranch, falseBranch)
|
||||||
|
|
||||||
|
Metoda *printTree* wyświetla strukturę drzewa w postaci tekstu w wierszu poleceń.
|
||||||
|
|
||||||
|
def printTree(node, spacing=' '):
|
||||||
|
|
||||||
|
if isinstance(node, Leaf):
|
||||||
|
print(spacing + 'Predict', node.predictions)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(spacing + str(node.question))
|
||||||
|
|
||||||
|
print(spacing + '--> True:')
|
||||||
|
SuggestionTree.printTree(node.trueBranch, spacing + ' ')
|
||||||
|
|
||||||
|
print(spacing + '--> False:')
|
||||||
|
SuggestionTree.printTree(node.falseBranch, spacing + ' ')
|
||||||
|
|
||||||
|
def classify(row, node):
|
||||||
|
|
||||||
|
if isinstance(node, Leaf):
|
||||||
|
return node
|
||||||
|
|
||||||
|
if node.question.match(row):
|
||||||
|
return SuggestionTree.classify(row, node.trueBranch)
|
||||||
|
else:
|
||||||
|
return SuggestionTree.classify(row, node.falseBranch)
|
||||||
|
|
||||||
|
Metoda *classify* przyporządkowuje danemu przykładowi odpowiadającą mu etykiete w drzewie.
|
||||||
|
|
||||||
|
### Działanie
|
||||||
|
|
||||||
|
Drzewo jest budowane na podstawie zbioru uczącego. Algorytm podejmuje decyzje o tym jak podzielić zbiór na podstawie przyrostu informacji wynikającego z potencjalnego podziału. Przerywa działanie gdy przyrost wynosi *0*, co oznacza, że dotarł do liścia drzewa i dalszy podział nie jest możliwy. Kończąc swoje działanie zwraca korzeń klasy *DecisionNode*, który zawiera w atrybutach potomków> Umożliwia to rekurencyjne schodzenie po drzewie, które jest realizowane przez metodę *classify*.
|
||||||
|
|
||||||
|
### Integracja z projektem
|
||||||
|
|
||||||
|
Deklaracja oraz inicjalizacja korzenia drzewa:
|
||||||
|
|
||||||
|
suggestionTreeRoot = SuggestionTree.buildTree(trainingData)
|
||||||
|
|
||||||
|
W projekcie można sprawdzić działanie podprojektu poprzez wciśnięcie, któregoś z klawiszy:
|
||||||
|
|
||||||
|
* 0 - sprawdzenie poprawności algorytmu na losowo wygenerowanym zbiorze testowym.
|
||||||
|
|
||||||
|
generateTestData('testData.csv', 100)
|
||||||
|
testData = []
|
||||||
|
with open('src/SubprojectMarcinDobrowolski/Data/testData.csv') as csv_file:
|
||||||
|
csvReader = csv.reader(csv_file, delimiter=',')
|
||||||
|
lineCount = 0
|
||||||
|
for row in csvReader:
|
||||||
|
example = []
|
||||||
|
for column in row:
|
||||||
|
if column.isdigit():
|
||||||
|
example.append(int(column))
|
||||||
|
else:
|
||||||
|
example.append(column)
|
||||||
|
if lineCount > 0:
|
||||||
|
testData.append(example)
|
||||||
|
lineCount += 1
|
||||||
|
|
||||||
|
print('Processed lines: ', lineCount)
|
||||||
|
|
||||||
|
print('Test examples predictions:')
|
||||||
|
for example in testData:
|
||||||
|
print('{} - {}'.format(example, SuggestionTree.classify(
|
||||||
|
example, suggestionTreeRoot).printLeaf()))
|
||||||
|
|
||||||
|
* 1 - sprawdzenie poprawności algorytmu na losowo wygenerowanym przykładzie.
|
||||||
|
|
||||||
|
example = generateTestExample()
|
||||||
|
print('Test example prediction: ')
|
||||||
|
print('{} - {}'.format(example, SuggestionTree.classify(
|
||||||
|
example, suggestionTreeRoot).printLeaf()))
|
||||||
|
|
||||||
|
* 2 - wyświetlenie struktury drzewa
|
||||||
|
|
||||||
|
SuggestionTree.printTree(suggestionTreeRoot)
|
||||||
|
|
||||||
|
Rezultaty zostają wypisane w konsoli.
|
2
main.py
2
main.py
@ -65,7 +65,7 @@ if __name__ == "__main__":
|
|||||||
if [waiter.X, waiter.Y] in tabPos:
|
if [waiter.X, waiter.Y] in tabPos:
|
||||||
model = 'waiter_' + waiter.direction
|
model = 'waiter_' + waiter.direction
|
||||||
for x in range(-1, 2):
|
for x in range(-1, 2):
|
||||||
waiterX = waiter.X+(x*0.1)
|
waiterX = waiter.X + (x * 0.1)
|
||||||
print(waiterX)
|
print(waiterX)
|
||||||
graphics.clear(waiterX, waiter.Y - 1)
|
graphics.clear(waiterX, waiter.Y - 1)
|
||||||
rand = randrange(9)
|
rand = randrange(9)
|
||||||
|
Loading…
Reference in New Issue
Block a user