algorytm id3
This commit is contained in:
parent
1c26165161
commit
8bcb2066e1
148
drzewo_decyzyjne/drzewo_decyzyjne.py
Normal file
148
drzewo_decyzyjne/drzewo_decyzyjne.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Wczytywanie danych
|
||||||
|
nasze_dane = pd.read_csv("data")
|
||||||
|
|
||||||
|
|
||||||
|
# Obliczanie entropii dla całego zbioru danych
|
||||||
|
def oblicz_calkowita_entropie(dane_treningowe, etykieta, lista_klas):
|
||||||
|
liczba_wierszy = dane_treningowe.shape[0]
|
||||||
|
calkowita_entropia = 0
|
||||||
|
|
||||||
|
for klasa in lista_klas:
|
||||||
|
liczba_wystapien_klasy = dane_treningowe[dane_treningowe[etykieta] == klasa].shape[0]
|
||||||
|
entropia_klasy = - (liczba_wystapien_klasy / liczba_wierszy) * np.log2(liczba_wystapien_klasy / liczba_wierszy)
|
||||||
|
calkowita_entropia += entropia_klasy
|
||||||
|
|
||||||
|
return calkowita_entropia
|
||||||
|
|
||||||
|
|
||||||
|
# Obliczanie entropii dla przefiltrowanego zbioru danych
|
||||||
|
def oblicz_entropie(dane_wartosci_cechy, etykieta, lista_klas):
|
||||||
|
liczba_wystapien_cechy = dane_wartosci_cechy.shape[0]
|
||||||
|
entropia = 0
|
||||||
|
|
||||||
|
for klasa in lista_klas:
|
||||||
|
liczba_wystapien_klasy = dane_wartosci_cechy[dane_wartosci_cechy[etykieta] == klasa].shape[0]
|
||||||
|
entropia_klasy = 0
|
||||||
|
|
||||||
|
if liczba_wystapien_klasy != 0:
|
||||||
|
prawdopodobienstwo_klasy = liczba_wystapien_klasy / liczba_wystapien_cechy
|
||||||
|
entropia_klasy = - prawdopodobienstwo_klasy * np.log2(prawdopodobienstwo_klasy)
|
||||||
|
|
||||||
|
entropia += entropia_klasy
|
||||||
|
|
||||||
|
return entropia
|
||||||
|
|
||||||
|
|
||||||
|
# Obliczanie przyrostu informacji dla danej cechy
|
||||||
|
def oblicz_przyrost_informacji(nazwa_cechy, dane_treningowe, etykieta, lista_klas):
|
||||||
|
unikalne_wartosci_cechy = dane_treningowe[nazwa_cechy].unique()
|
||||||
|
liczba_wierszy = dane_treningowe.shape[0]
|
||||||
|
informacja_cechy = 0.0
|
||||||
|
|
||||||
|
for wartosc_cechy in unikalne_wartosci_cechy:
|
||||||
|
dane_wartosci_cechy = dane_treningowe[dane_treningowe[nazwa_cechy] == wartosc_cechy]
|
||||||
|
liczba_wystapien_wartosci_cechy = dane_wartosci_cechy.shape[0]
|
||||||
|
entropia_wartosci_cechy = oblicz_entropie(dane_wartosci_cechy, etykieta, lista_klas)
|
||||||
|
prawdopodobienstwo_wartosci_cechy = liczba_wystapien_wartosci_cechy / liczba_wierszy
|
||||||
|
informacja_cechy += prawdopodobienstwo_wartosci_cechy * entropia_wartosci_cechy
|
||||||
|
|
||||||
|
return oblicz_calkowita_entropie(dane_treningowe, etykieta, lista_klas) - informacja_cechy
|
||||||
|
|
||||||
|
|
||||||
|
# Znajdowanie najbardziej informatywnej cechy (cechy o najwyższym przyroście informacji)
|
||||||
|
def znajdz_najbardziej_informatywna_ceche(dane_treningowe, etykieta, lista_klas):
|
||||||
|
lista_cech = dane_treningowe.columns.drop(etykieta)
|
||||||
|
# Etykieta nie jest cechą, więc ją usuwamy
|
||||||
|
max_przyrost_informacji = -1
|
||||||
|
najbardziej_informatywna_cecha = None
|
||||||
|
|
||||||
|
for cecha in lista_cech:
|
||||||
|
przyrost_informacji_cechy = oblicz_przyrost_informacji(cecha, dane_treningowe, etykieta, lista_klas)
|
||||||
|
|
||||||
|
if max_przyrost_informacji < przyrost_informacji_cechy:
|
||||||
|
max_przyrost_informacji = przyrost_informacji_cechy
|
||||||
|
najbardziej_informatywna_cecha = cecha
|
||||||
|
|
||||||
|
return najbardziej_informatywna_cecha
|
||||||
|
|
||||||
|
|
||||||
|
# Dodawanie węzła do drzewa
|
||||||
|
def generuj_poddrzewo(nazwa_cechy, dane_treningowe, etykieta, lista_klas):
|
||||||
|
slownik_licznosci_wartosci_cechy = dane_treningowe[nazwa_cechy].value_counts(sort=False)
|
||||||
|
drzewo = {}
|
||||||
|
|
||||||
|
for wartosc_cechy, liczba in slownik_licznosci_wartosci_cechy.items():
|
||||||
|
dane_wartosci_cechy = dane_treningowe[dane_treningowe[nazwa_cechy] == wartosc_cechy]
|
||||||
|
|
||||||
|
przypisany_do_wezla = False
|
||||||
|
for klasa in lista_klas:
|
||||||
|
liczba_klasy = dane_wartosci_cechy[dane_wartosci_cechy[etykieta] == klasa].shape[0]
|
||||||
|
|
||||||
|
if liczba_klasy == liczba:
|
||||||
|
drzewo[wartosc_cechy] = klasa
|
||||||
|
dane_treningowe = dane_treningowe[dane_treningowe[nazwa_cechy] != wartosc_cechy]
|
||||||
|
przypisany_do_wezla = True
|
||||||
|
if not przypisany_do_wezla:
|
||||||
|
drzewo[wartosc_cechy] = "?"
|
||||||
|
|
||||||
|
return drzewo, dane_treningowe
|
||||||
|
|
||||||
|
|
||||||
|
# Wykonywanie algorytmu ID3 i generowanie drzewa
|
||||||
|
def generuj_drzewo(korzen, poprzednia_wartosc_cechy, dane_treningowe, etykieta, lista_klas):
|
||||||
|
if dane_treningowe.shape[0] != 0:
|
||||||
|
najbardziej_informatywna_cecha = znajdz_najbardziej_informatywna_ceche(dane_treningowe, etykieta, lista_klas)
|
||||||
|
drzewo, dane_treningowe = generuj_poddrzewo(najbardziej_informatywna_cecha, dane_treningowe, etykieta, lista_klas)
|
||||||
|
nastepny_korzen = None
|
||||||
|
|
||||||
|
if poprzednia_wartosc_cechy is not None:
|
||||||
|
korzen[poprzednia_wartosc_cechy] = dict()
|
||||||
|
korzen[poprzednia_wartosc_cechy][najbardziej_informatywna_cecha] = drzewo
|
||||||
|
nastepny_korzen = korzen[poprzednia_wartosc_cechy][najbardziej_informatywna_cecha]
|
||||||
|
else:
|
||||||
|
korzen[najbardziej_informatywna_cecha] = drzewo
|
||||||
|
nastepny_korzen = korzen[najbardziej_informatywna_cecha]
|
||||||
|
|
||||||
|
for wezel, galezie in list(nastepny_korzen.items()):
|
||||||
|
if galezie == "?":
|
||||||
|
dane_wartosci_cechy = dane_treningowe[dane_treningowe[najbardziej_informatywna_cecha] == wezel]
|
||||||
|
generuj_drzewo(nastepny_korzen, wezel, dane_wartosci_cechy, etykieta, lista_klas)
|
||||||
|
|
||||||
|
|
||||||
|
# Znajdowanie unikalnych klas etykiety i rozpoczęcie algorytmu
|
||||||
|
def id3(nasze_dane, etykieta):
|
||||||
|
dane_treningowe = nasze_dane.copy()
|
||||||
|
drzewo = {}
|
||||||
|
lista_klas = dane_treningowe[etykieta].unique()
|
||||||
|
generuj_drzewo(drzewo, None, dane_treningowe, etykieta, lista_klas)
|
||||||
|
return drzewo
|
||||||
|
|
||||||
|
|
||||||
|
# Przewidywanie na podstawie drzewa
|
||||||
|
def przewiduj(drzewo, instancja):
|
||||||
|
if not isinstance(drzewo, dict):
|
||||||
|
return drzewo
|
||||||
|
else:
|
||||||
|
korzen = next(iter(drzewo))
|
||||||
|
wartosc_cechy = instancja[korzen]
|
||||||
|
if wartosc_cechy in drzewo[korzen]:
|
||||||
|
return przewiduj(drzewo[korzen][wartosc_cechy], instancja)
|
||||||
|
else:
|
||||||
|
return 'walcz'
|
||||||
|
|
||||||
|
|
||||||
|
drzewo = id3(nasze_dane, 'akcja')
|
||||||
|
|
||||||
|
przyklad = {'zdrowie_bohatera': '100',
|
||||||
|
'moc_bohatera': 'nie',
|
||||||
|
'moc_moba': 'nie',
|
||||||
|
'lvl_wiekszy_bohater': 'tak',
|
||||||
|
'mob_jest_strzelcem': 'nie',
|
||||||
|
'zdrowie_moba': '1',
|
||||||
|
'artefakt': 'tak'}
|
||||||
|
|
||||||
|
print(przewiduj(drzewo, przyklad))
|
||||||
|
print(drzewo)
|
Loading…
Reference in New Issue
Block a user