Gra-SI/drzewo_decyzyjne.py
2023-06-16 12:46:47 +02:00

151 lines
6.1 KiB
Python

import pandas as pd
import numpy as np
class Tree():
# Obliczanie entropii dla całego zbioru danych
def oblicz_calkowita_entropie(self,dane_treningowe, etykieta, lista_klas):
liczba_wierszy = dane_treningowe.shape[0]
calkowita_entropia = 0
for klasa in lista_klas:
liczba_wystapien_klasy = dane_treningowe[dane_treningowe[etykieta] == klasa].shape[0]
entropia_klasy = - (liczba_wystapien_klasy / liczba_wierszy) * np.log2(liczba_wystapien_klasy / liczba_wierszy)
calkowita_entropia += entropia_klasy
return calkowita_entropia
# Obliczanie entropii dla przefiltrowanego zbioru danych
def oblicz_entropie(self,dane_wartosci_cechy, etykieta, lista_klas):
liczba_wystapien_cechy = dane_wartosci_cechy.shape[0]
entropia = 0
for klasa in lista_klas:
liczba_wystapien_klasy = dane_wartosci_cechy[dane_wartosci_cechy[etykieta] == klasa].shape[0]
entropia_klasy = 0
if liczba_wystapien_klasy != 0:
prawdopodobienstwo_klasy = liczba_wystapien_klasy / liczba_wystapien_cechy
entropia_klasy = - prawdopodobienstwo_klasy * np.log2(prawdopodobienstwo_klasy)
entropia += entropia_klasy
return entropia
# Obliczanie przyrostu informacji dla danej cechy
def oblicz_przyrost_informacji(self,nazwa_cechy, dane_treningowe, etykieta, lista_klas):
unikalne_wartosci_cechy = dane_treningowe[nazwa_cechy].unique()
liczba_wierszy = dane_treningowe.shape[0]
informacja_cechy = 0.0
for wartosc_cechy in unikalne_wartosci_cechy:
dane_wartosci_cechy = dane_treningowe[dane_treningowe[nazwa_cechy] == wartosc_cechy]
liczba_wystapien_wartosci_cechy = dane_wartosci_cechy.shape[0]
entropia_wartosci_cechy = self.oblicz_entropie(dane_wartosci_cechy, etykieta, lista_klas)
prawdopodobienstwo_wartosci_cechy = liczba_wystapien_wartosci_cechy / liczba_wierszy
informacja_cechy += prawdopodobienstwo_wartosci_cechy * entropia_wartosci_cechy
return self.oblicz_calkowita_entropie(dane_treningowe, etykieta, lista_klas) - informacja_cechy
# Znajdowanie najbardziej informatywnej cechy (cechy o najwyższym przyroście informacji)
def znajdz_najbardziej_informatywna_ceche(self,dane_treningowe, etykieta, lista_klas):
lista_cech = dane_treningowe.columns.drop(etykieta)
# Etykieta nie jest cechą, więc ją usuwamy
max_przyrost_informacji = -1
najbardziej_informatywna_cecha = None
for cecha in lista_cech:
przyrost_informacji_cechy = self.oblicz_przyrost_informacji(cecha, dane_treningowe, etykieta, lista_klas)
if max_przyrost_informacji < przyrost_informacji_cechy:
max_przyrost_informacji = przyrost_informacji_cechy
najbardziej_informatywna_cecha = cecha
return najbardziej_informatywna_cecha
# Dodawanie węzła do drzewa
def generuj_poddrzewo(self,nazwa_cechy, dane_treningowe, etykieta, lista_klas):
slownik_licznosci_wartosci_cechy = dane_treningowe[nazwa_cechy].value_counts(sort=False)
drzewo = {}
for wartosc_cechy, liczba in slownik_licznosci_wartosci_cechy.items():
dane_wartosci_cechy = dane_treningowe[dane_treningowe[nazwa_cechy] == wartosc_cechy]
przypisany_do_wezla = False
for klasa in lista_klas:
liczba_klasy = dane_wartosci_cechy[dane_wartosci_cechy[etykieta] == klasa].shape[0]
if liczba_klasy == liczba:
drzewo[wartosc_cechy] = klasa
dane_treningowe = dane_treningowe[dane_treningowe[nazwa_cechy] != wartosc_cechy]
przypisany_do_wezla = True
if not przypisany_do_wezla:
drzewo[wartosc_cechy] = "?"
return drzewo, dane_treningowe
# Wykonywanie algorytmu ID3 i generowanie drzewa
def generuj_drzewo(self,korzen, poprzednia_wartosc_cechy, dane_treningowe, etykieta, lista_klas):
if dane_treningowe.shape[0] != 0:
najbardziej_informatywna_cecha = self.znajdz_najbardziej_informatywna_ceche(dane_treningowe, etykieta, lista_klas)
drzewo, dane_treningowe = self.generuj_poddrzewo(najbardziej_informatywna_cecha, dane_treningowe, etykieta, lista_klas)
nastepny_korzen = None
if poprzednia_wartosc_cechy is not None:
korzen[poprzednia_wartosc_cechy] = dict()
korzen[poprzednia_wartosc_cechy][najbardziej_informatywna_cecha] = drzewo
nastepny_korzen = korzen[poprzednia_wartosc_cechy][najbardziej_informatywna_cecha]
else:
korzen[najbardziej_informatywna_cecha] = drzewo
nastepny_korzen = korzen[najbardziej_informatywna_cecha]
for wezel, galezie in list(nastepny_korzen.items()):
if galezie == "?":
dane_wartosci_cechy = dane_treningowe[dane_treningowe[najbardziej_informatywna_cecha] == wezel]
self.generuj_drzewo(nastepny_korzen, wezel, dane_wartosci_cechy, etykieta, lista_klas)
# Znajdowanie unikalnych klas etykiety i rozpoczęcie algorytmu
def id3(self,nasze_dane, etykieta):
dane_treningowe = nasze_dane.copy()
drzewo = {}
lista_klas = dane_treningowe[etykieta].unique()
self.generuj_drzewo(drzewo, None, dane_treningowe, etykieta, lista_klas)
return drzewo
# Przewidywanie na podstawie drzewa
def przewiduj(self,drzewo, instancja):
if not isinstance(drzewo, dict):
return drzewo
else:
korzen = next(iter(drzewo))
wartosc_cechy = instancja[korzen]
if wartosc_cechy in drzewo[korzen]:
return self.przewiduj(drzewo[korzen][wartosc_cechy], instancja)
else:
return 'walcz'
def tree(self,przyklad):
# Wczytywanie danych
nasze_dane = pd.read_csv("data")
drzewo = self.id3(nasze_dane, 'akcja')
return self.przewiduj(drzewo, przyklad)
#print(przewiduj(drzewo, przyklad))
#print(drzewo)