11 KiB
Podprojekt indywidualny - raport
Temat: Sugerowanie potraw Metoda uczenia: Decision Tree(CART) Autor: Marcin Jerzy Dobrowolski Źródła: https://github.com/random-forests/tutorials/blob/master/decision_tree.ipynb
Link do repozytorium projektu: https://git.wmi.amu.edu.pl/s444427/Sztuczna_Inteligencja_2020
Wstęp
Projekt rozwiązuje problem zasugerowania dania przez kelnera na podstawie danych reprezentujących preferencje gościa. W tym celu wykorzystałem metodę uczenia drzew decyzyjnych CART, którą zaprogramowałem w czystym języku python. Struktura podprojektu:
- suggestionDecisionTree.py - klasa drzew decyzyjnego
- question.py - klasa pytania
- leaf.py - klasa liścia
- decisionNode.py - klasa wierzchołka pytającego
- utility.py - klasa użytkowa
- trainingData.csv - zbiór uczący
- testData.csv - klasa użytkowa
Struktura danych reprezentujących preferencje gościa:
type, size, money, appetite, name
['meat', 'small', 31, 55, 'schabowy']
- type - typ potrawy: wegańska lub z mięsem
- appetite - apetyt gościa, świadczy o tym jak duża potrawa spełni jego oczekiwania,
np. duże danie będzie inne dla osób o dużym i małym apetycie - size - rozmiar dania
- money - zawartość portfela gościa, czyli na jak dużo może sobie pozwolić
- name - etykieta potrawy, dopasowana w taki sposób by spełnić powyższe wymagania
Opis modułów
W pliku utility znajdują się podręczne funkcje użytkowe wykorzystywane przez pozostałe moduły programu.
def uniqueValues(rows, column):
return set([row[column] for row in rows])
Znajduje unikalne wartości wyznaczonej kolumny z podanego zbioru danych.
def classCount(rows):
counts = {}
for row in rows:
label = row[-1]
if label not in counts:
counts[label] = 0
counts[label] += 1
return counts
Podlicza ilość elementów każdej etykiety z danego zbioru danych.
def is_numeric(value):
return isinstance(value, int) or isinstance(value, float)
Sprawdza czy podana wartość jest liczbą.
def partition(rows, question):
trueRows, falseRows = [], []
for row in rows:
if question.match(row):
trueRows.append(row)
else:
falseRows.append(row)
return trueRows, falseRows
Dzieli dany zbiór ze względu na zadane pytanie.
def generateTestData(path, quantity):
with open('path', 'w') as csvFile:
csvWriter = csv.writer(csvFile, delimiter=',')
...
Generuje zbiór danych testowych o zadanej wielkości i zapisuje go w pliku o podanej ścieżce. Ze względu na długość kodu nie zamieściłem go w całości. Dane są generowane przy użyciu liczb pseudo losowych w taki sposób by były zgodne z prawami rządzącymi tym zbiorem danych.
def generateTestExample():
example = []
category = random.randrange(0, 1)
size = random.randrange(0, 2)
if category == 0:
example.append('meat')
...
Działa na takiej samej zasadzie jak jej poprzedniczka. Z tym wyjątkiem, że generuje pojedynczy przykład.
Klasa Question reprezentuje pytanie, które służy do podziału danych w wierzchołkach decyzyjnych drzewa. Jej atrybuty oznajczają kolejno: etykiete kolumny zbioru danych, indeks wspomnianej kolumny oraz wartość, której dotyczyy pytanie.
class Question:
def __init__(self, columnLabel, column, value):
self.columnLabel = columnLabel
self.column = column
self.value = value
def match(self, example):
val = example[self.column]
if is_numeric(val):
return val <= self.value
else:
return val == self.value
Metoda match decyduje jak na pytanie odpowiada zadany przykład.
Jądro algorytmu znajduje się w klasie suggestionTree w pilku suggestionDecisionTree. Metoda readTrainingData służy do odczytania zbioru uczącego z pliku trainingData.csv i zwrócenia przechowywanych w nim danych wraz z ich etykietami.
def readTrainingData(path):
with open(path) as csv_file:
csvReader = csv.reader(csv_file, delimiter=',')
lineCount = 0
trainingData = []
labels = []
for row in csvReader:
example = []
for column in row:
if lineCount == 0:
labels.append(column)
else:
if column.isdigit():
example.append(int(column))
else:
example.append(column)
if lineCount > 0:
trainingData.append(example)
lineCount += 1
print('Processed lines: ', lineCount)
return trainingData, labels
Jest wykorzystana w inicjalizacji zmiennej globalnej trainigData oraz labels, z których korzystają niektóre z metod tej klasy.
trainingData, labels = SuggestionTree.readTrainingData(
'src/SubprojectMarcinDobrowolski/Data/trainingData.csv')
Metoda gini oblicza tzw. gini impurity. Jest to miara, która mówi jak duże jest prawdopodobieństwo by losowo wybrany element ze zbioru został błędnie oznaczony.
def gini(rows):
counts = classCount(rows)
impurity = 1
for lbl in counts:
prob_of_lbl = counts[lbl] / float(len(rows))
impurity -= prob_of_lbl**2
return impurity
def infoGain(left, right, currentUncertainty):
p = float(len(left)) / (len(left) + len(right))
return currentUncertainty - p * SuggestionTree.gini(left) - (1 - p) * SuggestionTree.gini(right)
Natomiast metoda infoGain oblicza w jakim stopniu zmniejsza się wskaźnik gini impurity ze względu na wybór lewego i prawego podziału zbioru.
Metoda findBestSplit znajduje pytanie, które w danej chwili spowoduje jak największy przyrost informacji.
def findBestSplit(rows):
bestGain = 0
bestQuestion = None
currentUncertainty = SuggestionTree.gini(rows)
nFeatures = len(labels) - 1
for column in range(nFeatures):
values = set([row[column] for row in rows])
for value in values:
question = Question(labels[column], column, value)
trueRows, falseRows = partition(rows, question)
if len(trueRows) == 0 or len(falseRows) == 0:
continue
gain = SuggestionTree.infoGain(
trueRows, falseRows, currentUncertainty)
if gain > bestGain:
bestGain, bestQuestion = gain, question
return bestGain, bestQuestion
Klasa Leaf jest reprezentacją liścia drzewa czyli najbardziej zewnętrznego wierzchołka. W atrybucie predictions przechowuje oszacowanie etykiet/y, które może przyjąć przykład, który doń trafi. Metoda printLeaf służy do eleganckiego wyświetlania oszacowania.
class Leaf:
def __init__(self, rows):
self.predictions = classCount(rows)
def printLeaf(self):
total = sum(self.predictions.values()) * 1.0
probs = {}
for label in self.predictions.keys():
probs[label] = str(
int(self.predictions[label] / total * 100)) + '%'
return probs
Klasa DecisionNode reprezentuje wierzchołek drzewa, w którym następuje zadanie pytania oraz podział danych. Atrybuty trueBranch oraz falseBranch reprezentują odpowiednio lewego i prawego potomka.
class DecisionNode:
def __init__(self,
question,
trueBranch,
falseBranch):
self.question = question
self.trueBranch = trueBranch
self.falseBranch = falseBranch
Metoda buildTree odpowiada za rekurencyjne skonstruowanie drzewa na bazie przekazanego zbioru danych. Zwraca korzeń drzewa.
def buildTree(rows):
gain, question = SuggestionTree.findBestSplit(rows)
if gain == 0:
return Leaf(rows)
trueRows, falseRows = partition(rows, question)
trueBranch = SuggestionTree.buildTree(trueRows)
falseBranch = SuggestionTree.buildTree(falseRows)
return DecisionNode(question, trueBranch, falseBranch)
Metoda printTree wyświetla strukturę drzewa w postaci tekstu w wierszu poleceń.
def printTree(node, spacing=' '):
if isinstance(node, Leaf):
print(spacing + 'Predict', node.predictions)
return
print(spacing + str(node.question))
print(spacing + '--> True:')
SuggestionTree.printTree(node.trueBranch, spacing + ' ')
print(spacing + '--> False:')
SuggestionTree.printTree(node.falseBranch, spacing + ' ')
def classify(row, node):
if isinstance(node, Leaf):
return node
if node.question.match(row):
return SuggestionTree.classify(row, node.trueBranch)
else:
return SuggestionTree.classify(row, node.falseBranch)
Metoda classify przyporządkowuje danemu przykładowi odpowiadającą mu etykiete w drzewie.
Działanie
Drzewo jest budowane na podstawie zbioru uczącego. Algorytm podejmuje decyzje o tym jak podzielić zbiór na podstawie przyrostu informacji wynikającego z potencjalnego podziału. Przerywa działanie gdy przyrost wynosi 0, co oznacza, że dotarł do liścia drzewa i podział nie jest możliwy. Kończąc swoje działanie zwraca korzeń klasy DecisionNode, który zawiera w atrybutach potomków co umożliwia rekurencyjne schodzenie po drzewie.
W projekcie można je wykorzystać poprzez wciśnięcie, któregoś z klawiszy:
- 0 - sprawdzenie poprawności algorytmu na losowo wygenerowanym zbiorze testowym. =
- 1 - sprawdzenie poprawności algorytmu na losowo wygenerowanym przykładzie.
- 2 - wyświetlenie struktury drzewa
Rezultaty zostają wypisane w konsoli.
module: pygames
Python: 3.7.7
macOS / Linux
pygames
python main.py
Windows
pygames
python main.py