DSIC-Bayes-continuous/Bayes.ipynb

428 KiB
Raw Blame History

Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe)

Skład grupy:

  • Nowak Ania,
  • Łaźna Patrycja,
  • Bregier Damian
#!pip install pandas==1.2.4
#!pip install numpy==1.20.3
#!pip install sklearn==0.0

from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import typing
import os, pickle
from sklearn.metrics import confusion_matrix, accuracy_score
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

0. Podstawowe informacje o zbiorze danych

W projekcie wykorzystany został GTZAN Dataset poruszający problem wieloklasowej klasyfikacji danych na przykładzie gatunków muzycznych. Zbiór ten składa się z 10 gatunków obejmujacych: blues, muzykę klasyczną, country, disco, hip-hop, jazz, pop, reggae oraz rock. Każdy ze wspomnianych gatunków jest reprezentowany przez 100 plików audio o długości 30 sekund, a same próbki były zbierane w latach 2000-2001 ze zdyfersyfikowanych źródeł obejmujących: stacje radiowe, prywatne płyty CD oraz nagrania własne.

Zbiór danych jest niezwykle bogaty i rozbudowany, ponieważ do każdego utworu zostało przypisanych 60 unikalnych parametrów. Parametry te obejmują takie dane jak: długość utworu, etykietę z nazwą gatunku, tempo, harmoniczność, variancję czy częstotliwość melodyczną (MFCC).

Dokładne dane na temat tego zbioru danych można znaleźć pod adresem: https://www.kaggle.com/andradaolteanu/gtzan-dataset-music-genre-classification

1. Wczytywanie i normalizacja danych

# Słownik zawierający 10 gatunków muzycznych, które zostały sparowane z
# odpowiadającymi im wartościami numerycznymi
genre_dict = {
    "blues" : 1,
    "classical" : 2,
    "country" : 3,
    "disco" : 4,
    "hiphop" : 5,
    "jazz" : 6,
    "metal" : 7,
    "pop" : 8,
    "reggae" : 9,
    "rock" : 10
}
# nazwa pliku w którym umieszczane są parametry po wstępnym przetworzeniu
filename = 'music_genre.csv'
model_path = 'model.model'
# skrypt ten realizuje dwie podstawowe funkcje
# 1) sprawdza czy plik music_genre.csv istnieje i jeżeli tak to wczytuje go
# 2) w przeciwnym przypadku dokonuje preprocessingu danych w ramach którego
#    gatunki zamieniane są na wartości licznowe, a wartości takie jak nazwa 
#    pliku, etykieta czy długość są usuwane
 
if os.path.isfile(filename):
    print("Loading prepared data...")
    data = pd.read_csv(filename)
else:
    print("Preparing data...")
    data = pd.read_csv('music_genre_raw.csv')
    column = data["label"].apply(lambda x: genre_dict[x])
    data.insert(0, 'genre', column, 'int')
    data = data.drop(columns=['filename', 'length'])
    data.to_csv(filename, index=False)
display(data.head(10))

data.columns
Preparing data...
genre chroma_stft_mean chroma_stft_var rms_mean rms_var spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean spectral_bandwidth_var rolloff_mean ... mfcc16_var mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var mfcc19_mean mfcc19_var mfcc20_mean mfcc20_var label
0 1 0.350088 0.088757 0.130228 0.002827 1784.165850 1.297741e+05 2002.449060 85882.761315 3805.839606 ... 52.420910 -1.690215 36.524071 -0.408979 41.597103 -2.303523 55.062923 1.221291 46.936035 blues
1 1 0.340914 0.094980 0.095948 0.002373 1530.176679 3.758501e+05 2039.036516 213843.755497 3550.522098 ... 55.356403 -0.731125 60.314529 0.295073 48.120598 -0.283518 51.106190 0.531217 45.786282 blues
2 1 0.363637 0.085275 0.175570 0.002746 1552.811865 1.564676e+05 1747.702312 76254.192257 3042.260232 ... 40.598766 -7.729093 47.639427 -1.816407 52.382141 -3.439720 46.639660 -2.231258 30.573025 blues
3 1 0.404785 0.093999 0.141093 0.006346 1070.106615 1.843559e+05 1596.412872 166441.494769 2184.745799 ... 44.427753 -3.319597 50.206673 0.636965 37.319130 -0.619121 37.259739 -3.407448 31.949339 blues
4 1 0.308526 0.087841 0.091529 0.002303 1835.004266 3.433999e+05 1748.172116 88445.209036 3579.757627 ... 86.099236 -5.454034 75.269707 -0.916874 53.613918 -4.404827 62.910812 -11.703234 55.195160 blues
5 1 0.302456 0.087532 0.103494 0.003981 1831.993940 1.030482e+06 1729.653287 201910.508633 3481.517592 ... 72.549225 -1.838263 68.702026 -2.783800 42.447453 -3.047909 39.808784 -8.109991 46.311005 blues
6 1 0.291328 0.093981 0.141874 0.008803 1459.366472 4.378594e+05 1389.009131 185023.239545 2795.610963 ... 83.248245 -10.913176 56.902153 -6.971336 38.231800 -3.436505 48.235741 -6.483466 70.170364 blues
7 1 0.307955 0.092903 0.131822 0.005531 1451.667066 4.495682e+05 1577.270941 168211.938804 2954.836760 ... 70.438438 -10.568935 52.090893 -10.784515 60.461330 -4.690678 65.547516 -8.630722 56.401436 blues
8 1 0.408879 0.086512 0.142416 0.001507 1719.368948 1.632828e+05 2031.740381 105542.718193 3782.316288 ... 50.563751 -7.041824 28.894934 2.695248 36.889568 3.412305 33.698597 -2.715692 36.418430 blues
9 1 0.273950 0.092316 0.081314 0.004347 1817.150863 2.982361e+05 1973.773306 114070.112591 3943.490565 ... 59.314602 -1.916804 58.418438 -2.292661 83.205231 2.881967 77.082222 -4.235203 91.468811 blues

10 rows × 59 columns

Index(['genre', 'chroma_stft_mean', 'chroma_stft_var', 'rms_mean', 'rms_var',
       'spectral_centroid_mean', 'spectral_centroid_var',
       'spectral_bandwidth_mean', 'spectral_bandwidth_var', 'rolloff_mean',
       'rolloff_var', 'zero_crossing_rate_mean', 'zero_crossing_rate_var',
       'harmony_mean', 'harmony_var', 'perceptr_mean', 'perceptr_var', 'tempo',
       'mfcc1_mean', 'mfcc1_var', 'mfcc2_mean', 'mfcc2_var', 'mfcc3_mean',
       'mfcc3_var', 'mfcc4_mean', 'mfcc4_var', 'mfcc5_mean', 'mfcc5_var',
       'mfcc6_mean', 'mfcc6_var', 'mfcc7_mean', 'mfcc7_var', 'mfcc8_mean',
       'mfcc8_var', 'mfcc9_mean', 'mfcc9_var', 'mfcc10_mean', 'mfcc10_var',
       'mfcc11_mean', 'mfcc11_var', 'mfcc12_mean', 'mfcc12_var', 'mfcc13_mean',
       'mfcc13_var', 'mfcc14_mean', 'mfcc14_var', 'mfcc15_mean', 'mfcc15_var',
       'mfcc16_mean', 'mfcc16_var', 'mfcc17_mean', 'mfcc17_var', 'mfcc18_mean',
       'mfcc18_var', 'mfcc19_mean', 'mfcc19_var', 'mfcc20_mean', 'mfcc20_var',
       'label'],
      dtype='object')

2. Podział danych na zbiory: uczący i testowy

# Podział ten jest dokonywany w proporcji 80:20, gdzie 80% danych trafia do zbioru uczącego, a 20%
# do zbioru testowego, podejście to jest standardową praktyką w dziedzinie uczenia maszynwego

# wartość X reprezentuje 57 parametrów opisujących poszczególne utwory
X = data.drop(["genre"], axis=1)
# wartość Y zawiera kolumnę gatunków wyrażonych przy pomocy wartości liczbowych od 1 do 10
Y = data["genre"]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.20, random_state = False)
display(X_train.head(10))
chroma_stft_mean chroma_stft_var rms_mean rms_var spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean spectral_bandwidth_var rolloff_mean rolloff_var ... mfcc16_var mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var mfcc19_mean mfcc19_var mfcc20_mean mfcc20_var label
687 0.516547 0.072241 0.267380 0.001175 3338.581900 172002.893292 2697.128636 45771.294278 6670.863091 3.556853e+05 ... 37.339474 -8.121326 33.968277 4.910113 42.063385 -2.474697 35.162354 3.192656 36.478157 metal
500 0.344511 0.085002 0.046747 0.001542 1503.869486 554576.511533 1754.216082 283554.933422 2799.283099 2.685679e+06 ... 50.311016 -1.503434 41.141155 0.221949 55.707256 -1.991485 50.006485 -3.353825 49.906403 jazz
332 0.368345 0.090390 0.111073 0.004402 2446.919077 490397.099115 2449.159840 215375.540632 4958.057490 2.650020e+06 ... 78.892769 -1.054999 79.877068 4.496278 112.834435 -0.978958 75.059898 -5.256925 120.275269 disco
979 0.360042 0.083953 0.116724 0.000789 2148.410463 253618.158995 2107.165355 72155.551685 4479.264304 9.787046e+05 ... 37.060532 -13.479134 50.848667 3.308529 47.726006 -3.704957 56.781952 1.085497 54.243389 rock
817 0.425788 0.091852 0.139799 0.003601 1803.774378 659241.158049 1973.418903 201432.199120 3777.969679 2.632339e+06 ... 64.068756 -2.219202 99.249870 5.304260 64.088127 -6.597187 62.661850 -2.923168 67.490440 reggae
620 0.495959 0.072854 0.117362 0.000867 2657.912854 189139.438926 2345.662472 32730.579626 5358.261979 5.918222e+05 ... 27.937113 -10.676390 26.519361 3.875155 25.613684 -4.943561 24.334734 3.255899 25.199259 metal
814 0.395137 0.093939 0.114246 0.004025 1716.249594 920189.339374 2062.885827 358557.016423 3790.901258 4.734865e+06 ... 66.090370 -4.590122 72.595345 4.261040 63.185764 -2.127876 50.693245 -3.665569 89.750290 reggae
516 0.249535 0.087563 0.060560 0.001276 1465.857446 143302.098295 1738.858902 58868.399307 2822.406728 7.392007e+05 ... 109.811813 -0.027696 113.660950 2.098475 160.025497 1.109709 136.810165 2.935807 95.914490 jazz
518 0.353474 0.087755 0.052264 0.000316 1993.352766 64753.479332 2127.165109 36027.039069 4248.194549 3.987029e+05 ... 57.230133 -1.110214 48.080849 -0.784249 57.033504 -2.984207 55.737625 0.350456 64.126846 jazz
940 0.416089 0.087772 0.142935 0.003150 3009.958707 435134.775688 2778.049758 135548.871316 6131.200719 1.788624e+06 ... 42.315434 -3.953057 48.761936 -3.092345 49.514446 -2.731183 58.219994 -0.909785 63.111858 rock

10 rows × 58 columns

Ilość krotek dla poszczególnych gatunków z podziałem na zbiory: uczący i testowy

# skrypt odpowiadający za przeiterowanie po słowniku i zliczenie liczebności poszczególnych gatunków
# w ramach podziału na zbiory: uczący i testowy

for key in genre_dict.keys():
    count = len(data[data["genre"]==genre_dict[key]])
    count_train = len(X_train[Y_train==genre_dict[key]])
    count_test = len(X_test[Y_test==genre_dict[key]])
    print(f"{key}\ttest: {count_test}\ttrain: {count_train}\tall: {count}")
blues	test: 15	train: 85	all: 100
classical	test: 11	train: 89	all: 100
country	test: 27	train: 73	all: 100
disco	test: 22	train: 78	all: 100
hiphop	test: 23	train: 77	all: 100
jazz	test: 18	train: 82	all: 100
metal	test: 20	train: 80	all: 100
pop	test: 24	train: 76	all: 100
reggae	test: 15	train: 85	all: 100
rock	test: 25	train: 75	all: 100

3. Wizualizacja danych

Boxplot dla tempa gatunków

Jedną z najciekawszych i najbardziej intuicyjnych wartości mierzalnych dla poszczególnych utworów jest tempo. Parametr ten został przedstawiony przy pomocy wykresu pudełkowego w odniesieniu do wspomnianych wcześniej 10 gatunków muzycznych.

Ze zgromadzonych danych jednoznacznie wynika, że najwyższą medianę dla tempa mają utwory z gatunku Reggee, zaś na drugim i trzecim miejscu znajdują się odpowiednio muzyka klasyczna oraz blues. Podczas gdy najniższe wartości mają gatunki hip-hop oraz pop.

Z kolei największe rozbieżności pomiędzy wartościami zauważalne są w przypadku muzyki klasycznej, country i metalu, chociaż najwięcej obserwacji odstających pojawia się w przypadku hiphopu oraz popu.

f, ax = plt.subplots(figsize=(16, 9));

sns.boxplot(x = "label", y = "tempo", data = data[["label", "tempo"]], palette = 'pastel');

plt.title('Zależność pomiędzy tempem a gatunkiem', fontsize = 25)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 10);
plt.xlabel("Genre", fontsize = 15)
plt.ylabel("Tempo", fontsize = 15);

Boxplot dla średnich melowych współczynników cepstralnych sygnału dla poszczególnych gatunków

Interesujące wyniki pojawiły się także na wykresie pokazującym zależność pomiędzy MFCC mean, czyli średnimi wartościami dla melowych współczynników cepstralnych sygnału a gatunkami muzycznimi.

Najwyższe wartości MFCC_mean dotyczą metalu oraz bluesa, podczas gdy najniższe wartości uzyskiwane są w przypadu popu i muzyki klasycznej. Z kolei najwięcej obserwacji odstających pojawia się w przypadku reggae.

f, ax = plt.subplots(figsize=(16, 9));
sns.boxplot(x = "label", y = "mfcc4_mean", data = data[["label", "mfcc4_mean"]], palette = 'pastel');

plt.title('Zależność między MFCC a gatunkiem', fontsize = 25)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 10);
plt.xlabel("Genre", fontsize = 15)
plt.ylabel("mfcc4_mean4", fontsize = 15);

Korelacja między cechami średnimi

W procesie badania zależności pomiędzy dostępnymi cechami wykorzystana została mapa ciepła, która jednak pokazała, że w wielu przypadkach korelacje nie zachodzą, co jest szczególnie widoczne w przypadku średniej częstotliwości melodycznej cepstrum2 (mfcc2_mean), a jeżeli takowe korelacje zachodza to mają stosunkowo niewielkie wartości.Występowanie zależności widać w górnej oraz środkowej częsci mapy.

mean_cols = [col for col in data.columns if 'mean' in col]
mean_correlation = data[mean_cols].corr()


mask = np.triu(np.ones_like(mean_correlation, dtype=bool))
f, ax = plt.subplots(figsize=(16, 11))
cmap = sns.diverging_palette(150, 275, as_cmap=True, s = 90, l = 45, n = 5)

sns.heatmap(mean_correlation, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5)

plt.title('Correlation Heatmap (means)', fontsize = 20)
plt.xticks(fontsize = 10)
plt.yticks(fontsize = 10);

Korelacja między cechami wariancji

Odwrotna sytuacja ma miejsce w przypadku mapy ciepła dla cech wariancji, w przypadku ktorych korelacja nie zachodzi wyłącznie dla dwóch parametrów czyli harmony i perceptr w środkowej cześci wykresu. Z kolei stosunkow wysokie wartości korelacji można zaobserwować dla parametrów "skrajnych", czyli pierwszych i ostatnich na liście parametrów.

var_cols = [col for col in data.columns if 'var' in col]
var_correlation = data[var_cols].corr()


mask = np.triu(np.ones_like(var_correlation, dtype=bool))
f, ax = plt.subplots(figsize=(16, 11))
cmap = sns.diverging_palette(240, 10, as_cmap=True, s = 90, l = 45, n = 5)

sns.heatmap(var_correlation, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5)

plt.title('Correlation Heatmap (vars)', fontsize = 20)
plt.xticks(fontsize = 10)
plt.yticks(fontsize = 10);

Cechy średnie dają lepsze rezultaty niż cechy wariancji ze względu na mniejszą korelację pomiędzy poszczególnymi parametrami.(https://datascience.stackexchange.com/questions/9087/correlation-and-naive-bayes).

W toku przeprowadzanych testów okazało sie, że dokładność stworzonego i wytrenowanego modelu zależy od rodzaju cech. W przypadku cech wariancji dokładność jest niższa niż w przypadku cech średnich. Z kolei najwyższą dokładność udało się uzyskać poprzez wykorzystanie kombinacji 8 różnych kolumn.

  • dla var_cols accuracy = 0.3875,
  • dla mean_cols accuracy = 0.4375,
  • dla ['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean'] accuracy = 0.56125

Równocześnie uzyskane wyniki mogłyby mieć zdecydowanie wyższą dokładność jednak ograniczeniem okazał się specyfika samego datasetu, który posiada niewielkie zróżnicowanie wartości cech i niewielką korelację pomiędzy poszczególnymi cechami!

Wykres punktowy przedstawiający zależność pomiędzy chroma_stft_mean (wysokością dźwięku) a mfcc12_mean (melowym współczynnikiem cepstralnym )

Wykres ten pokazuje, że chociaż pojawia się pewna pula obserwacji odstających, to jednak wraz ze wzrostem wartości mfcc12_mean rosną wartości chroma_stft_mean. Tym samym zależność pomiędzy tymi dwiema wartościami, w ogólności, ma charakter liniowy, co potwierdza wynik uzyskany na heatmapie, gdzie korelacja wynosiła 0.2.

fig = plt.figure(figsize=(10,10))
chart = fig.add_subplot()
ax = fig.add_subplot()
colors = ['red', 'green', 'blue', 'brown','purple', 'gray', 'pink', 'black', 'yellow', 'orange']
for genre in genre_dict:
    genre_data = data[data["genre"]==genre_dict[genre]]
    ax.scatter(genre_data['chroma_stft_mean'],genre_data['mfcc12_mean'], c=colors[genre_dict[genre]-1])
plt.show()
<ipython-input-10-bcf947f4e0ac>:3: MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.
  ax = fig.add_subplot()

4. Wykorzystanie algorytmu Bayesa

class NaiveBayesContinues:
    def __init__(self, X, Y):
        self.classes = Y.unique()
        self.priors = [] # prawdopodobieństwo każdej z klas
        self.stds = [] #lista odchyleń standardowych każdej z cech dla każdej z klas
        self.means = [] #lista średnich dla każdej z cech dla każdej z klas
        for c in self.classes:
            x_with_c_class = X[c == Y]
            self.priors.append(len(x_with_c_class) / len(X))
            self.means.append(x_with_c_class.mean(axis=0))
            self.stds.append(x_with_c_class.std(axis=0))

            
    def predict(self, X, display_results=False):
        y_preds = []
        for x in X:
            posteriors = []
            for i, c in enumerate(self.classes):
                prior = self.priors[i] # prawdopodobieństwo dla rozpatrywanej klasy
                mean = self.means[i] # średnia cech dla rozpatrywanej klasy
                std = self.stds[i] # odchylenie standardowe cech dla rozpatrywanej klasy
                
                posterior = 1 #P(X1|Yi)*P(X2|Yi)*P(X3|Yi)...
                for j, feature in (enumerate(x)):
                    P_X_yi = np.exp((-(feature - mean[j]) ** 2) / (2 * std[j] ** 2)) / np.sqrt(2 * np.pi * std[j] ** 2) #P(Xj|Yi)
                    posterior *= P_X_yi
                
                posterior = (posterior * prior) #P(Yi)P(X1|Yi)*P(X2|Yi)*P(X3|Yi)...
                posteriors.append(posterior)
                
            if(display_results):
                print("posteriors")
                print(posteriors)
                print(np.argmax(posteriors))
                
            y_pred = self.classes[np.argmax(posteriors)] # Wzięcie klasy z największym prawdopodobieństem
            y_preds.append(y_pred)
        return y_preds

Wzór bayesa image-2.png

Gausowski Naiwny Bayes Stosowany w przypadku pracy na danych o charakterze ciągłym. image.png

W procesie trenowania i testowania modelu wykorzystany został skrypt losujący kolumny i zapisujący uzyskiwane wartości accuracy w celu znalezienia najbardziej efektywnej kombinacji cech. W ten sposób wybranych zostało 8 cech, w tym sześć cech należących do kategorii średnich i dwie do wariancji.

X_train_np = X_train[['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean']].to_numpy()
X_test_np = X_test[['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean']].to_numpy()

model = NaiveBayesContinues(X_train_np, Y_train)
Y_train_predicted = model.predict(X_train_np[:1])

Ewaluacja

Zbiór trenujący

Y_train_predicted = model.predict(X_train_np)
cm = confusion_matrix(Y_train, Y_train_predicted)
ac = accuracy_score(Y_train, Y_train_predicted)
print("(Train data) Confusion matrix:")
display(cm)
print("(Train data) Accuracy:")
print(ac)
(Train data) Confusion matrix:
array([[27,  3, 11,  3,  1,  8, 18,  0, 10,  4],
       [ 1, 75,  2,  0,  0, 10,  0,  0,  1,  0],
       [10,  1, 29,  9,  1,  4,  2,  8,  6,  3],
       [ 2,  0,  2, 39,  1,  1, 12, 12,  5,  4],
       [ 0,  0,  0,  8, 36,  0, 11, 13,  8,  1],
       [ 8, 20,  2,  0,  0, 45,  1,  0,  0,  6],
       [ 1,  0,  0,  3,  2,  0, 71,  0,  0,  3],
       [ 1,  1,  2,  2,  5,  0,  0, 63,  2,  0],
       [ 1,  0,  8,  8,  6,  3,  3,  5, 48,  3],
       [ 4,  0, 10, 15,  0,  6, 14,  5,  5, 16]], dtype=int64)
(Train data) Accuracy:
0.56125

Zbiór testowy

Y_test_predicted = model.predict(X_test_np)
cm = confusion_matrix(Y_test, Y_test_predicted)
ac = accuracy_score(Y_test, Y_test_predicted)
print("Confusion matrix:")
display(cm)
print("Accuracy:")
print(ac)
Confusion matrix:
array([[ 5,  0,  2,  0,  0,  2,  4,  0,  1,  1],
       [ 0,  8,  0,  0,  0,  3,  0,  0,  0,  0],
       [ 7,  0,  6,  6,  0,  4,  1,  0,  1,  2],
       [ 0,  0,  2,  7,  1,  0,  2,  4,  5,  1],
       [ 0,  0,  0,  2, 10,  0,  6,  1,  4,  0],
       [ 0,  3,  0,  1,  0, 14,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  1,  0, 17,  0,  2,  0],
       [ 1,  0,  0,  2,  1,  0,  0, 18,  1,  1],
       [ 0,  1,  0,  0,  3,  1,  0,  4,  6,  0],
       [ 5,  0,  3,  5,  1,  1,  5,  1,  0,  4]], dtype=int64)
Accuracy:
0.475

Przykładowe porównania

for i in range(10):
    print(f"Y: {Y_test.to_numpy()[i]}\tPredicted: {Y_test_predicted[i]}")
Y: 10	Predicted: 10
Y: 9	Predicted: 8
Y: 3	Predicted: 1
Y: 6	Predicted: 6
Y: 7	Predicted: 7
Y: 10	Predicted: 7
Y: 1	Predicted: 1
Y: 3	Predicted: 6
Y: 4	Predicted: 4
Y: 8	Predicted: 10

Bayes z wykorzystaniem gotowej biblioteki

from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas as pd
import numpy as np
import pickle, os
import typing

class Bayes:
    def __init__(self):
        self.classifier = GaussianNB()


    def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
        self.classifier.fit(X, Y)


    def predict(self, X: pd.DataFrame) -> np.ndarray:
        predictions = self.classifier.predict(X)
        return predictions


    def eval(self, Y: pd.Series, Y_pred: np.ndarray) -> typing.Tuple[np.ndarray, np.float64]:
        cm = confusion_matrix(Y, Y_pred)
        ac = accuracy_score(Y, Y_pred)
        return (cm, ac)
bayes = Bayes()
bayes.train(X_train_np, Y_train)

Y_predicted = bayes.predict(X_train_np)
eval_result = bayes.eval(Y_train, Y_predicted)
print("Train:")
print(eval_result[1])

Y_predicted = bayes.predict(X_test_np)
eval_result = bayes.eval(Y_test, Y_predicted)
print("Test:")
print(eval_result[1])
Train:
0.56125
Test:
0.475
# skrypt losujacy kolumny ze zbioru i sprawdzajacy accuracy na zbiorze trenujacym

for i in range(100):
    X = data.drop(["genre", "label", "tempo"], axis=1)
    X_rand = X.sample(n=10, axis='columns')
    Y = data["genre"] 
    
    X_train, X_test, Y_train, Y_test = train_test_split(X_rand, Y, test_size = 0.20, random_state = False)
    
    model = GaussianNB()
    model.fit(X_train, Y_train)
    Y_train_predicted = model.predict(X_train)
    ac = accuracy_score(Y_train, Y_train_predicted)
    filename = 'accuracy.txt'

    if os.path.exists(filename):
        append_write = 'a'
    else:
        append_write = 'w'

    acc_random = open(filename, append_write)
    acc_random.write(str(ac) + " " + str(list(X_rand.columns)) + '\n')
    acc_random.close()

#!sort -k1,1nr -k2,2 accuracy.txt