Marcin Dobrowolski razport z podprojketu
This commit is contained in:
parent
68104e25c6
commit
d89e37e6eb
312
MarcinDobrowolski.md
Normal file
312
MarcinDobrowolski.md
Normal file
@ -0,0 +1,312 @@
|
||||
# Podprojekt indywidualny - raport
|
||||
|
||||
|
||||
**Temat:** Sugerowanie potraw
|
||||
**Metoda uczenia:** Decision Tree(CART)
|
||||
**Autor:** Marcin Jerzy Dobrowolski
|
||||
**Okres trwwania prac:** 29 kwietnia - 27 maja 2020
|
||||
|
||||
**Link do repozytorium projektu:** https://git.wmi.amu.edu.pl/s444427/Sztuczna_Inteligencja_2020
|
||||
|
||||
## Wstęp
|
||||
|
||||
Tematem realizowanego projektu jest stworzenie sztucznej inteligencji, która na podstawie podanych atrybutów stwierdzi jaka potrawa będzie najlepszym wyborem dla gościa restauaracji. Atrybuty odpowiadają na takie pytania jak:
|
||||
|
||||
* czy gość jest wegetarianinem? - {meat, vege}
|
||||
* jak bardzo chce się najeść/jak wielka ma być potrawa? - {small, regular, large}
|
||||
* ile gotówki może przeznaczyć na ten cel? - {<10;20>, <20;30>, <30;40>, <40;50>, <50;60>}
|
||||
* jak jaki jest jego apetyt? - {<10;25>, <25;40>, <40;55>}
|
||||
|
||||
Powyższymi atrybutami rządzą pewne zależności. Apetyt gościa ma wpływ na to jaka potrawa będzie dla niego mała, średnia bądź duża.
|
||||
|
||||
Do osiągnięcia celu wykorzystałem metodę drzew decyzyjnych CART oraz następujące biblioteki języka python:
|
||||
|
||||
* csv
|
||||
* random
|
||||
|
||||
## Uczenie modelu
|
||||
|
||||
### Zbiór uczący
|
||||
|
||||
Do zbudowania drzewa decyzyjnego wykorzystałem spreparowany przez siebie zbiór przykładów atrybutów gości z odpowiadającymi im daniami. Zawarte jest w nim 213 przykładów dotyczących 20 różnych potraw. Całość znajduje się w pliku: [trainingData.csv](src/SubprojectMarcinDobrowolski/Data/trainingData.csv)
|
||||
|
||||
### Moduły
|
||||
|
||||
* [suggestionDecisionTree.py](src/SubprojectMarcinDobrowolski/suggestionDecisionTree.py) - klasa drzew decyzyjnego
|
||||
* [question.py](src/SubprojectMarcinDobrowolski/question.py) - klasa pytania
|
||||
* [leaf.py](src/SubprojectMarcinDobrowolski/leaf.py) - klasa liścia
|
||||
* [decisionNode.py](src/SubprojectMarcinDobrowolski/decisionNode.py) - klasa wierzchołka
|
||||
* [utility.py](src/SubprojectMarcinDobrowolski/utility.py) - klasa użytkowa
|
||||
* [testData.csv](src/SubprojectMarcinDobrowolski/Data/testData.csv) - zbiór testowy
|
||||
|
||||
Struktura danych reprezentujących preferencje gościa:
|
||||
|
||||
type, size, money, appetite, name
|
||||
['meat', 'small', 31, 55, 'schabowy']
|
||||
|
||||
### Opis modułów
|
||||
|
||||
W pliku *utility* znajdują się podręczne funkcje użytkowe wykorzystywane przez pozostałe moduły programu.
|
||||
|
||||
def uniqueValues(rows, column):
|
||||
return set([row[column] for row in rows])
|
||||
|
||||
Znajduje unikalne wartości wyznaczonej kolumny z podanego zbioru danych.
|
||||
|
||||
def classCount(rows):
|
||||
counts = {}
|
||||
for row in rows:
|
||||
label = row[-1]
|
||||
if label not in counts:
|
||||
counts[label] = 0
|
||||
counts[label] += 1
|
||||
return counts
|
||||
|
||||
Podlicza ilość elementów każdej etykiety z danego zbioru danych.
|
||||
|
||||
def is_numeric(value):
|
||||
return isinstance(value, int) or isinstance(value, float)
|
||||
|
||||
Sprawdza czy podana wartość jest liczbą.
|
||||
|
||||
def partition(rows, question):
|
||||
trueRows, falseRows = [], []
|
||||
for row in rows:
|
||||
if question.match(row):
|
||||
trueRows.append(row)
|
||||
else:
|
||||
falseRows.append(row)
|
||||
return trueRows, falseRows
|
||||
|
||||
Dzieli dany zbiór ze względu na zadane pytanie.
|
||||
|
||||
def generateTestData(path, quantity):
|
||||
with open('path', 'w') as csvFile:
|
||||
csvWriter = csv.writer(csvFile, delimiter=',')
|
||||
...
|
||||
|
||||
Generuje zbiór danych testowych o zadanej wielkości i zapisuje go w pliku o podanej ścieżce. Ze względu na długość kodu nie zamieściłem go w całości. Dane są generowane przy użyciu liczb pseudo losowych w taki sposób by były zgodne z prawami rządzącymi tym zbiorem danych.
|
||||
|
||||
def generateTestExample():
|
||||
example = []
|
||||
category = random.randrange(0, 1)
|
||||
size = random.randrange(0, 2)
|
||||
if category == 0:
|
||||
example.append('meat')
|
||||
...
|
||||
|
||||
Działa na takiej samej zasadzie jak jej poprzedniczka. Z tym wyjątkiem, że generuje pojedynczy przykład.
|
||||
|
||||
Klasa *Question* reprezentuje pytanie, które służy do podziału danych w wierzchołkach decyzyjnych drzewa. Jej atrybuty oznajczają kolejno: etykiete kolumny zbioru danych, indeks wspomnianej kolumny oraz wartość, której dotyczy pytanie.
|
||||
|
||||
class Question:
|
||||
def __init__(self, columnLabel, column, value):
|
||||
self.columnLabel = columnLabel
|
||||
self.column = column
|
||||
self.value = value
|
||||
|
||||
def match(self, example):
|
||||
val = example[self.column]
|
||||
if is_numeric(val):
|
||||
return val <= self.value
|
||||
else:
|
||||
return val == self.value
|
||||
|
||||
Metoda match wskazuje jak na pytanie odpowiada zadany przykład.
|
||||
|
||||
Jądro algorytmu znajduje się w klasie *suggestionTree* w pilku *suggestionDecisionTree*. Metoda *readTrainingData* służy do odczytania zbioru uczącego z pliku *trainingData.csv* i zwrócenia przechowywanych w nim danych wraz z ich etykietami.
|
||||
|
||||
|
||||
def readTrainingData(path):
|
||||
with open(path) as csv_file:
|
||||
csvReader = csv.reader(csv_file, delimiter=',')
|
||||
lineCount = 0
|
||||
trainingData = []
|
||||
labels = []
|
||||
|
||||
for row in csvReader:
|
||||
example = []
|
||||
for column in row:
|
||||
if lineCount == 0:
|
||||
labels.append(column)
|
||||
else:
|
||||
if column.isdigit():
|
||||
example.append(int(column))
|
||||
else:
|
||||
example.append(column)
|
||||
if lineCount > 0:
|
||||
trainingData.append(example)
|
||||
lineCount += 1
|
||||
|
||||
print('Processed lines: ', lineCount)
|
||||
return trainingData, labels
|
||||
|
||||
Jest wykorzystana w inicjalizacji zmiennej globalnej *trainigData* oraz *labels*, z których korzystają niektóre z metod tej klasy.
|
||||
|
||||
```
|
||||
trainingData, labels = SuggestionTree.readTrainingData(
|
||||
'src/SubprojectMarcinDobrowolski/Data/trainingData.csv')
|
||||
```
|
||||
|
||||
Metoda *gini* oblicza tzw. *gini impurity*. Jest to miara, która mówi jak duże jest prawdopodobieństwo by losowo wybrany element ze zbioru został błędnie oznaczony.
|
||||
|
||||
def gini(rows):
|
||||
counts = classCount(rows)
|
||||
impurity = 1
|
||||
for lbl in counts:
|
||||
prob_of_lbl = counts[lbl] / float(len(rows))
|
||||
impurity -= prob_of_lbl**2
|
||||
return impurity
|
||||
|
||||
def infoGain(left, right, currentUncertainty):
|
||||
p = float(len(left)) / (len(left) + len(right))
|
||||
return currentUncertainty - p * SuggestionTree.gini(left) - (1 - p) * SuggestionTree.gini(right)
|
||||
|
||||
Natomiast metoda *infoGain* oblicza w jakim stopniu zmniejsza się wskaźnik *gini impurity* ze względu na wybór lewego i prawego podziału zbioru.
|
||||
|
||||
Metoda *findBestSplit* znajduje pytanie, które w danej chwili spowoduje jak największy przyrost informacji.
|
||||
|
||||
def findBestSplit(rows):
|
||||
bestGain = 0
|
||||
bestQuestion = None
|
||||
currentUncertainty = SuggestionTree.gini(rows)
|
||||
nFeatures = len(labels) - 1
|
||||
|
||||
for column in range(nFeatures):
|
||||
|
||||
values = set([row[column] for row in rows])
|
||||
|
||||
for value in values:
|
||||
|
||||
question = Question(labels[column], column, value)
|
||||
|
||||
trueRows, falseRows = partition(rows, question)
|
||||
|
||||
if len(trueRows) == 0 or len(falseRows) == 0:
|
||||
continue
|
||||
|
||||
gain = SuggestionTree.infoGain(
|
||||
trueRows, falseRows, currentUncertainty)
|
||||
|
||||
if gain > bestGain:
|
||||
bestGain, bestQuestion = gain, question
|
||||
|
||||
return bestGain, bestQuestion
|
||||
|
||||
Klasa *Leaf* jest reprezentacją liścia drzewa czyli najbardziej zewnętrznego wierzchołka. W atrybucie *predictions* przechowuje oszacowanie etykiet/y, które może przyjąć przykład, który do niego trafi. Metoda *printLeaf* służy do eleganckiego wyświetlania oszacowania.
|
||||
|
||||
class Leaf:
|
||||
def __init__(self, rows):
|
||||
self.predictions = classCount(rows)
|
||||
|
||||
def printLeaf(self):
|
||||
total = sum(self.predictions.values()) * 1.0
|
||||
probs = {}
|
||||
for label in self.predictions.keys():
|
||||
probs[label] = str(
|
||||
int(self.predictions[label] / total * 100)) + '%'
|
||||
return probs
|
||||
|
||||
Klasa *DecisionNode* reprezentuje wierzchołek drzewa, w którym następuje zadanie pytania oraz podział danych. Atrybuty *trueBranch* oraz *falseBranch* reprezentują odpowiednio lewego i prawego potomka.
|
||||
|
||||
class DecisionNode:
|
||||
def __init__(self,
|
||||
question,
|
||||
trueBranch,
|
||||
falseBranch):
|
||||
self.question = question
|
||||
self.trueBranch = trueBranch
|
||||
self.falseBranch = falseBranch
|
||||
|
||||
Metoda *buildTree* odpowiada za rekurencyjne skonstruowanie drzewa na bazie przekazanego zbioru danych. Zwraca korzeń drzewa.
|
||||
|
||||
def buildTree(rows):
|
||||
gain, question = SuggestionTree.findBestSplit(rows)
|
||||
|
||||
if gain == 0:
|
||||
return Leaf(rows)
|
||||
|
||||
trueRows, falseRows = partition(rows, question)
|
||||
|
||||
trueBranch = SuggestionTree.buildTree(trueRows)
|
||||
falseBranch = SuggestionTree.buildTree(falseRows)
|
||||
|
||||
return DecisionNode(question, trueBranch, falseBranch)
|
||||
|
||||
Metoda *printTree* wyświetla strukturę drzewa w postaci tekstu w wierszu poleceń.
|
||||
|
||||
def printTree(node, spacing=' '):
|
||||
|
||||
if isinstance(node, Leaf):
|
||||
print(spacing + 'Predict', node.predictions)
|
||||
return
|
||||
|
||||
print(spacing + str(node.question))
|
||||
|
||||
print(spacing + '--> True:')
|
||||
SuggestionTree.printTree(node.trueBranch, spacing + ' ')
|
||||
|
||||
print(spacing + '--> False:')
|
||||
SuggestionTree.printTree(node.falseBranch, spacing + ' ')
|
||||
|
||||
def classify(row, node):
|
||||
|
||||
if isinstance(node, Leaf):
|
||||
return node
|
||||
|
||||
if node.question.match(row):
|
||||
return SuggestionTree.classify(row, node.trueBranch)
|
||||
else:
|
||||
return SuggestionTree.classify(row, node.falseBranch)
|
||||
|
||||
Metoda *classify* przyporządkowuje danemu przykładowi odpowiadającą mu etykiete w drzewie.
|
||||
|
||||
### Działanie
|
||||
|
||||
Drzewo jest budowane na podstawie zbioru uczącego. Algorytm podejmuje decyzje o tym jak podzielić zbiór na podstawie przyrostu informacji wynikającego z potencjalnego podziału. Przerywa działanie gdy przyrost wynosi *0*, co oznacza, że dotarł do liścia drzewa i dalszy podział nie jest możliwy. Kończąc swoje działanie zwraca korzeń klasy *DecisionNode*, który zawiera w atrybutach potomków> Umożliwia to rekurencyjne schodzenie po drzewie, które jest realizowane przez metodę *classify*.
|
||||
|
||||
### Integracja z projektem
|
||||
|
||||
Deklaracja oraz inicjalizacja korzenia drzewa:
|
||||
|
||||
suggestionTreeRoot = SuggestionTree.buildTree(trainingData)
|
||||
|
||||
W projekcie można sprawdzić działanie podprojektu poprzez wciśnięcie, któregoś z klawiszy:
|
||||
|
||||
* 0 - sprawdzenie poprawności algorytmu na losowo wygenerowanym zbiorze testowym.
|
||||
|
||||
generateTestData('testData.csv', 100)
|
||||
testData = []
|
||||
with open('src/SubprojectMarcinDobrowolski/Data/testData.csv') as csv_file:
|
||||
csvReader = csv.reader(csv_file, delimiter=',')
|
||||
lineCount = 0
|
||||
for row in csvReader:
|
||||
example = []
|
||||
for column in row:
|
||||
if column.isdigit():
|
||||
example.append(int(column))
|
||||
else:
|
||||
example.append(column)
|
||||
if lineCount > 0:
|
||||
testData.append(example)
|
||||
lineCount += 1
|
||||
|
||||
print('Processed lines: ', lineCount)
|
||||
|
||||
print('Test examples predictions:')
|
||||
for example in testData:
|
||||
print('{} - {}'.format(example, SuggestionTree.classify(
|
||||
example, suggestionTreeRoot).printLeaf()))
|
||||
|
||||
* 1 - sprawdzenie poprawności algorytmu na losowo wygenerowanym przykładzie.
|
||||
|
||||
example = generateTestExample()
|
||||
print('Test example prediction: ')
|
||||
print('{} - {}'.format(example, SuggestionTree.classify(
|
||||
example, suggestionTreeRoot).printLeaf()))
|
||||
|
||||
* 2 - wyświetlenie struktury drzewa
|
||||
|
||||
SuggestionTree.printTree(suggestionTreeRoot)
|
||||
|
||||
Rezultaty zostają wypisane w konsoli.
|
Loading…
Reference in New Issue
Block a user