5.1 KiB
Agata Lenz - drzewa decyzyjne, algorytm ID3
Wymagania
Importowane biblioteki:
- numpy
- pandas
Opis podprojektu
Podprojekt implementuje algorytm ID3 pozwalający wyznaczyć drzewo decyzyjne, przy użyciu którego agent (traktor) podejmie decyzję co posadzić w danym miejscu, na podstawie:
- previous - gatunku rośliny, która poprzednio rosła w danym miejscu
- pH - pH gleby
- dry_level - suchości gleby
Uczenie modelu
Dane
Dane uczące jak i testowe znajdują się w pliku data.py. Dane uczące są zapisane w postaci listy, której elementy to przykładowe dane w formacie ['previous', 'pH', 'dry_level', 'label'], gdzie label oznacza posadzoną w wymienionych warunkach roślinę. Łącznie zestaw uczący zawiera 47 elementów, a zestaw testowy - 15.
Przykładowe elementy zestawu uczącego:
['pumpkin', 'neutral', 'dry', 'beetroot'],
['none', 'neutral', 'wet', 'pumpkin'],
['cabbage', 'alkaline', 'soaking wet', 'none']
Implementacja algorytmu ID3
Za budowę drzewa decyzyjnego odpowiada rekurencyjna funkcja ID3(data, original_data, features, target_attribute_name="label", parent_node_class=None):
def ID3(data, original_data, features, target_attribute_name="label", parent_node_class=None):
"""
Algorytm ID3
parametry:
data zbiór danych, dla którego poszukujemy drzewa decyzyjnego
original_data oryginalny zbiór danych (zwracany gdy data == None)
features lista atrybutów wejściowego zbioru
target_attribute_name docelowy atrybut, który chcemy przewidzieć
parent_node_class nadrzędna wartość
"""
# Jeżeli wszystkie atrybuty są takie same, zwracamy liść z pierwszą napotkaną wartością
if len(np.unique(data[target_attribute_name])) <= 1:
return np.unique(data[target_attribute_name])[0]
elif len(data) == 0:
return np.unique(original_data[target_attribute_name])[
np.argmax(np.unique(original_data[target_attribute_name], return_counts=True)[1])]
elif len(features) == 0:
return parent_node_class
else:
# Aktualizacja nadrzędnej wartości
parent_node_class = np.unique(data[target_attribute_name])[
np.argmax(np.unique(data[target_attribute_name], return_counts=True)[1])]
# Obliczenie przyrostu informacji dla każdego potencjalnego atrybutu,
# według którego nastąpi podział zbioru
item_values = [info_gain(data, feature, target_attribute_name) for feature in
features]
# Najlepszym atrybutem jest ten o największym przyroście informacji
best_feature_index = np.argmax(item_values)
best_feature = features[best_feature_index]
# Struktura drzewa
tree = {best_feature: {}}
# Aktualizacja zbioru atrybutów
features = [i for i in features if i != best_feature]
# Dla każdej wartości wybranego atrybutu budujemy kolejne poddrzewo
for value in np.unique(data[best_feature]):
sub_data = data.where(data[best_feature] == value).dropna()
subtree = ID3(sub_data, data, features, target_attribute_name, parent_node_class)
tree[best_feature][value] = subtree
return (tree)
Do obliczenia przyrostu informacji służy funkcja info_gain(data, split_attribute_name, target_name="label"), która dla wejściowego zestawu danych (data), oblicza jego entropię oraz średnią ważoną entropii każdego podzestawu (wyznaczanego przez unikalne wartości atrybutu split_attribute_name):
def info_gain(data, split_attribute_name, target_name="label"):
total_entropy = entropy(data[target_name])
vals, counts = np.unique(data[split_attribute_name], return_counts=True)
weighted_entropy = np.sum(
[(counts[i] / np.sum(counts)) * entropy(data.where(data[split_attribute_name] == vals[i]).dropna()[target_name])
for i in range(len(vals))])
information_gain = total_entropy - weighted_entropy
return information_gain
Implementacja w projekcie głównym
Funkcja, z której będzie korzystał traktor, aby podjąć decyzję o posadzeniu rośliny, to def decide_to_plant(soil)
def decide_to_plant(soil):
if soil.have_plant():
plant = soil.get_plant()
if plant.collect() == 'True':
info = get_info(soil)
plant.leave_soil()
else:
return [['none']]
else:
info = get_info(soil)
data = []
data.append(info)
predicted = predict_data(data)
grow_a_plant(soil,predicted[0][0])
return predicted
Pierwszym krokiem jest sprawdzenie, czy w danej ziemi już coś rośnie - jeżeli jest to roślina dojrzała, zostaje ona zebrana, a jej nazwa staje się wartością atrybutu previous. Jeżeli nie, nie przewiduje się sadzenia w tym miejscu żadnej rośliny w danym momencie. Funkcja get_info(soil) zwraca listę parametrów obiektu Soil, potrzebną do poszukiwań w drzewie decyzyjnym. W pliku dataset.py znajduje się funkcja create_data_soil() pozwalająca przetestować działanie algorytmu na obiektach typu Soil