401 KiB
Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe)
Skład grupy:
- Nowak Ania,
- Łaźna Patrycja,
- Bregier Damian
#!pip install pandas==1.2.4
#!pip install numpy==1.20.3
#!pip install sklearn==0.0
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import typing
import os, pickle
from sklearn.metrics import confusion_matrix, accuracy_score
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
0. Podstawowe informacje o zbiorze danych
W projekcie wykorzystany został GTZAN Dataset poruszający problem wieloklasowej klasyfikacji danych na przykładzie gatunków muzycznych. Zbiór ten składa się z 10 gatunków obejmujacych: blues, muzykę klasyczną, country, disco, hip-hop, jazz, pop, reggae oraz rock. Każdy ze wspomnianych gatunków jest reprezentowany przez 100 plików audio o długości 30 sekund, a same próbki były zbierane w latach 2000-2001 ze zdyfersyfikowanych źródeł obejmujących: stacje radiowe, prywatne płyty CD oraz nagrania własne.
Zbiór danych jest niezwykle bogaty i rozbudowany, ponieważ do każdego utworu zostało przypisanych 60 unikalnych parametrów. Parametry te obejmują takie dane jak: długość utworu, etykietę z nazwą gatunku, tempo, harmoniczność, variancję czy częstotliwość melodyczną (MFCC).
Dokładne dane na temat tego zbioru danych można znaleźć pod adresem: https://www.kaggle.com/andradaolteanu/gtzan-dataset-music-genre-classification
1. Wczytywanie i normalizacja danych
# Słownik zawierający 10 gatunków muzycznych, które zostały sparowane z
# odpowiadającymi im wartościami numerycznymi
genre_dict = {
"blues" : 1,
"classical" : 2,
"country" : 3,
"disco" : 4,
"hiphop" : 5,
"jazz" : 6,
"metal" : 7,
"pop" : 8,
"reggae" : 9,
"rock" : 10
}
# nazwa pliku w którym umieszczane są parametry po wstępnym przetworzeniu
filename = 'music_genre.csv'
model_path = 'model.model'
# skrypt ten realizuje dwie podstawowe funkcje
# 1) sprawdza czy plik music_genre.csv istnieje i jeżeli tak to wczytuje go
# 2) w przeciwnym przypadku dokonuje preprocessingu danych w ramach którego
# gatunki zamieniane są na wartości licznowe, a wartości takie jak nazwa
# pliku, etykieta czy długość są usuwane
if os.path.isfile(filename):
print("Loading prepared data...")
data = pd.read_csv(filename)
else:
print("Preparing data...")
data = pd.read_csv('music_genre_raw.csv')
column = data["label"].apply(lambda x: genre_dict[x])
data.insert(0, 'genre', column, 'int')
data = data.drop(columns=['filename', 'length'])
data.to_csv(filename, index=False)
display(data.head(10))
data.columns
Loading prepared data...
genre | chroma_stft_mean | chroma_stft_var | rms_mean | rms_var | spectral_centroid_mean | spectral_centroid_var | spectral_bandwidth_mean | spectral_bandwidth_var | rolloff_mean | ... | mfcc16_var | mfcc17_mean | mfcc17_var | mfcc18_mean | mfcc18_var | mfcc19_mean | mfcc19_var | mfcc20_mean | mfcc20_var | label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0.350088 | 0.088757 | 0.130228 | 0.002827 | 1784.165850 | 1.297741e+05 | 2002.449060 | 85882.761315 | 3805.839606 | ... | 52.420910 | -1.690215 | 36.524071 | -0.408979 | 41.597103 | -2.303523 | 55.062923 | 1.221291 | 46.936035 | blues |
1 | 1 | 0.340914 | 0.094980 | 0.095948 | 0.002373 | 1530.176679 | 3.758501e+05 | 2039.036516 | 213843.755497 | 3550.522098 | ... | 55.356403 | -0.731125 | 60.314529 | 0.295073 | 48.120598 | -0.283518 | 51.106190 | 0.531217 | 45.786282 | blues |
2 | 1 | 0.363637 | 0.085275 | 0.175570 | 0.002746 | 1552.811865 | 1.564676e+05 | 1747.702312 | 76254.192257 | 3042.260232 | ... | 40.598766 | -7.729093 | 47.639427 | -1.816407 | 52.382141 | -3.439720 | 46.639660 | -2.231258 | 30.573025 | blues |
3 | 1 | 0.404785 | 0.093999 | 0.141093 | 0.006346 | 1070.106615 | 1.843559e+05 | 1596.412872 | 166441.494769 | 2184.745799 | ... | 44.427753 | -3.319597 | 50.206673 | 0.636965 | 37.319130 | -0.619121 | 37.259739 | -3.407448 | 31.949339 | blues |
4 | 1 | 0.308526 | 0.087841 | 0.091529 | 0.002303 | 1835.004266 | 3.433999e+05 | 1748.172116 | 88445.209036 | 3579.757627 | ... | 86.099236 | -5.454034 | 75.269707 | -0.916874 | 53.613918 | -4.404827 | 62.910812 | -11.703234 | 55.195160 | blues |
5 | 1 | 0.302456 | 0.087532 | 0.103494 | 0.003981 | 1831.993940 | 1.030482e+06 | 1729.653287 | 201910.508633 | 3481.517592 | ... | 72.549225 | -1.838263 | 68.702026 | -2.783800 | 42.447453 | -3.047909 | 39.808784 | -8.109991 | 46.311005 | blues |
6 | 1 | 0.291328 | 0.093981 | 0.141874 | 0.008803 | 1459.366472 | 4.378594e+05 | 1389.009131 | 185023.239545 | 2795.610963 | ... | 83.248245 | -10.913176 | 56.902153 | -6.971336 | 38.231800 | -3.436505 | 48.235741 | -6.483466 | 70.170364 | blues |
7 | 1 | 0.307955 | 0.092903 | 0.131822 | 0.005531 | 1451.667066 | 4.495682e+05 | 1577.270941 | 168211.938804 | 2954.836760 | ... | 70.438438 | -10.568935 | 52.090893 | -10.784515 | 60.461330 | -4.690678 | 65.547516 | -8.630722 | 56.401436 | blues |
8 | 1 | 0.408879 | 0.086512 | 0.142416 | 0.001507 | 1719.368948 | 1.632828e+05 | 2031.740381 | 105542.718193 | 3782.316288 | ... | 50.563751 | -7.041824 | 28.894934 | 2.695248 | 36.889568 | 3.412305 | 33.698597 | -2.715692 | 36.418430 | blues |
9 | 1 | 0.273950 | 0.092316 | 0.081314 | 0.004347 | 1817.150863 | 2.982361e+05 | 1973.773306 | 114070.112591 | 3943.490565 | ... | 59.314602 | -1.916804 | 58.418438 | -2.292661 | 83.205231 | 2.881967 | 77.082222 | -4.235203 | 91.468811 | blues |
10 rows × 59 columns
Index(['genre', 'chroma_stft_mean', 'chroma_stft_var', 'rms_mean', 'rms_var', 'spectral_centroid_mean', 'spectral_centroid_var', 'spectral_bandwidth_mean', 'spectral_bandwidth_var', 'rolloff_mean', 'rolloff_var', 'zero_crossing_rate_mean', 'zero_crossing_rate_var', 'harmony_mean', 'harmony_var', 'perceptr_mean', 'perceptr_var', 'tempo', 'mfcc1_mean', 'mfcc1_var', 'mfcc2_mean', 'mfcc2_var', 'mfcc3_mean', 'mfcc3_var', 'mfcc4_mean', 'mfcc4_var', 'mfcc5_mean', 'mfcc5_var', 'mfcc6_mean', 'mfcc6_var', 'mfcc7_mean', 'mfcc7_var', 'mfcc8_mean', 'mfcc8_var', 'mfcc9_mean', 'mfcc9_var', 'mfcc10_mean', 'mfcc10_var', 'mfcc11_mean', 'mfcc11_var', 'mfcc12_mean', 'mfcc12_var', 'mfcc13_mean', 'mfcc13_var', 'mfcc14_mean', 'mfcc14_var', 'mfcc15_mean', 'mfcc15_var', 'mfcc16_mean', 'mfcc16_var', 'mfcc17_mean', 'mfcc17_var', 'mfcc18_mean', 'mfcc18_var', 'mfcc19_mean', 'mfcc19_var', 'mfcc20_mean', 'mfcc20_var', 'label'], dtype='object')
2. Podział danych na zbiory: uczący i testowy
# Podział ten jest dokonywany w proporcji 80:20, gdzie 80% danych trafia do zbioru uczącego, a 20%
# do zbioru testowego, podejście to jest standardową praktyką w dziedzinie uczenia maszynwego
# wartość X reprezentuje 57 parametrów opisujących poszczególne utwory
X = data.drop(["genre"], axis=1)
# wartość Y zawiera kolumnę gatunków wyrażonych przy pomocy wartości liczbowych od 1 do 10
Y = data["genre"]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.20, random_state = False)
display(X_train.head(10))
chroma_stft_mean | chroma_stft_var | rms_mean | rms_var | spectral_centroid_mean | spectral_centroid_var | spectral_bandwidth_mean | spectral_bandwidth_var | rolloff_mean | rolloff_var | ... | mfcc16_var | mfcc17_mean | mfcc17_var | mfcc18_mean | mfcc18_var | mfcc19_mean | mfcc19_var | mfcc20_mean | mfcc20_var | label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
687 | 0.516547 | 0.072241 | 0.267380 | 0.001175 | 3338.581900 | 172002.893292 | 2697.128636 | 45771.294278 | 6670.863091 | 3.556853e+05 | ... | 37.339474 | -8.121326 | 33.968277 | 4.910113 | 42.063385 | -2.474697 | 35.162354 | 3.192656 | 36.478157 | metal |
500 | 0.344511 | 0.085002 | 0.046747 | 0.001542 | 1503.869486 | 554576.511533 | 1754.216082 | 283554.933422 | 2799.283099 | 2.685679e+06 | ... | 50.311016 | -1.503434 | 41.141155 | 0.221949 | 55.707256 | -1.991485 | 50.006485 | -3.353825 | 49.906403 | jazz |
332 | 0.368345 | 0.090390 | 0.111073 | 0.004402 | 2446.919077 | 490397.099115 | 2449.159840 | 215375.540632 | 4958.057490 | 2.650020e+06 | ... | 78.892769 | -1.054999 | 79.877068 | 4.496278 | 112.834435 | -0.978958 | 75.059898 | -5.256925 | 120.275269 | disco |
979 | 0.360042 | 0.083953 | 0.116724 | 0.000789 | 2148.410463 | 253618.158995 | 2107.165355 | 72155.551685 | 4479.264304 | 9.787046e+05 | ... | 37.060532 | -13.479134 | 50.848667 | 3.308529 | 47.726006 | -3.704957 | 56.781952 | 1.085497 | 54.243389 | rock |
817 | 0.425788 | 0.091852 | 0.139799 | 0.003601 | 1803.774378 | 659241.158049 | 1973.418903 | 201432.199120 | 3777.969679 | 2.632339e+06 | ... | 64.068756 | -2.219202 | 99.249870 | 5.304260 | 64.088127 | -6.597187 | 62.661850 | -2.923168 | 67.490440 | reggae |
620 | 0.495959 | 0.072854 | 0.117362 | 0.000867 | 2657.912854 | 189139.438926 | 2345.662472 | 32730.579626 | 5358.261979 | 5.918222e+05 | ... | 27.937113 | -10.676390 | 26.519361 | 3.875155 | 25.613684 | -4.943561 | 24.334734 | 3.255899 | 25.199259 | metal |
814 | 0.395137 | 0.093939 | 0.114246 | 0.004025 | 1716.249594 | 920189.339374 | 2062.885827 | 358557.016423 | 3790.901258 | 4.734865e+06 | ... | 66.090370 | -4.590122 | 72.595345 | 4.261040 | 63.185764 | -2.127876 | 50.693245 | -3.665569 | 89.750290 | reggae |
516 | 0.249535 | 0.087563 | 0.060560 | 0.001276 | 1465.857446 | 143302.098295 | 1738.858902 | 58868.399307 | 2822.406728 | 7.392007e+05 | ... | 109.811813 | -0.027696 | 113.660950 | 2.098475 | 160.025497 | 1.109709 | 136.810165 | 2.935807 | 95.914490 | jazz |
518 | 0.353474 | 0.087755 | 0.052264 | 0.000316 | 1993.352766 | 64753.479332 | 2127.165109 | 36027.039069 | 4248.194549 | 3.987029e+05 | ... | 57.230133 | -1.110214 | 48.080849 | -0.784249 | 57.033504 | -2.984207 | 55.737625 | 0.350456 | 64.126846 | jazz |
940 | 0.416089 | 0.087772 | 0.142935 | 0.003150 | 3009.958707 | 435134.775688 | 2778.049758 | 135548.871316 | 6131.200719 | 1.788624e+06 | ... | 42.315434 | -3.953057 | 48.761936 | -3.092345 | 49.514446 | -2.731183 | 58.219994 | -0.909785 | 63.111858 | rock |
10 rows × 58 columns
Ilość krotek dla poszczególnych gatunków z podziałem na zbiory: uczący i testowy
# skrypt odpowiadający za przeiterowanie po słowniku i zliczenie liczebności poszczególnych gatunków
# w ramach podziału na zbiory: uczący i testowy
for key in genre_dict.keys():
count = len(data[data["genre"]==genre_dict[key]])
count_train = len(X_train[Y_train==genre_dict[key]])
count_test = len(X_test[Y_test==genre_dict[key]])
print(f"{key}\ttest: {count_test}\ttrain: {count_train}\tall: {count}")
blues test: 15 train: 85 all: 100 classical test: 11 train: 89 all: 100 country test: 27 train: 73 all: 100 disco test: 22 train: 78 all: 100 hiphop test: 23 train: 77 all: 100 jazz test: 18 train: 82 all: 100 metal test: 20 train: 80 all: 100 pop test: 24 train: 76 all: 100 reggae test: 15 train: 85 all: 100 rock test: 25 train: 75 all: 100
3. Wizualizacja danych
Boxplot dla tempa gatunków
Jedną z najciekawszych i najbardziej intuicyjnych wartości mierzalnych dla poszczególnych utworów jest tempo. Parametr ten został przedstawiony przy pomocy wykresu pudełkowego w odniesieniu do wspomnianych wcześniej 10 gatunków muzycznych.
Ze zgromadzonych danych jednoznacznie wynika, że najwyższą średnią wartość dla tempa mają utwory z gatunku Reggee, zaś na drugim i trzecim miejscu znajdują się odpowiednio muzyka klasyczna oraz blues. Podczas gdy najniższe wartości mają gatunki hip-hop oraz pop.
Z kolei największe rozbieżności pomiędzy wartościami zauważalne są w przypadku muzyki klasycznej, country i metalu, chociaż najwięcej obserwacji odstających pojawia się w przypadku hiphopu oraz popu.
f, ax = plt.subplots(figsize=(16, 9));
sns.boxplot(x = "label", y = "tempo", data = data[["label", "tempo"]], palette = 'pastel');
plt.title('Zależność pomiędzy tempem a gatunkiem', fontsize = 25)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 10);
plt.xlabel("Genre", fontsize = 15)
plt.ylabel("Tempo", fontsize = 15);
Boxplot dla średnich melowych współczynników cepstralnych sygnału dla poszczególnych gatunków
Interesujące wyniki pojawiły się także na wykresie pokazującym zależność pomiędzy MFCC mean, czyli średnimi wartościami dla melowych współczynników cepstralnych sygnału a gatunkami muzycznimi.
Najwyższe wartości MFCC_mean dotyczą metalu oraz bluesa, podczas gdy najniższe wartości uzyskiwane są w przypadu popu i muzyki klasycznej. Z kolei najwięcej obserwacji odstających pojawia się w przypadku reggae.
f, ax = plt.subplots(figsize=(16, 9));
sns.boxplot(x = "label", y = "mfcc4_mean", data = data[["label", "mfcc4_mean"]], palette = 'pastel');
plt.title('Zależność między MFCC a gatunkiem', fontsize = 25)
plt.xticks(fontsize = 14)
plt.yticks(fontsize = 10);
plt.xlabel("Genre", fontsize = 15)
plt.ylabel("mfcc4_mean4", fontsize = 15);
Korelacja między cechami średnimi
W procesie badania zależności pomiędzy dostępnymi cechami wykorzystana została mapa ciepła, która jednak pokazała, że w wielu przypadkach korelacje nie zachodzą, co jest szczególnie widoczne w przypadku średniej częstotliwości melodycznej cepstrum2 (mfcc2_mean), a jeżeli takowe korelacje zachodza to mają stosunkowo niewielkie wartości.Występowanie zależności widać w górnej oraz środkowej częsci mapy.
Należy jednak podkreślić, ze takie wartości nie są dobre, ponieważ niewielkie zrożnicowanie odchyleń może mieć istotny wpływ na znajdowanie korelacji, a co za tym idzie skuteczność modelu!
mean_cols = [col for col in data.columns if 'mean' in col]
mean_correlation = data[mean_cols].corr()
mask = np.triu(np.ones_like(mean_correlation, dtype=bool))
f, ax = plt.subplots(figsize=(16, 11))
cmap = sns.diverging_palette(150, 275, as_cmap=True, s = 90, l = 45, n = 5)
sns.heatmap(mean_correlation, mask=mask, cmap=cmap, vmax=.3, center=0,
square=True, linewidths=.5)
plt.title('Correlation Heatmap (means)', fontsize = 20)
plt.xticks(fontsize = 10)
plt.yticks(fontsize = 10);
Korelacja między cechami wariancji
Odwrotna sytuacja ma miejsce w przypadku mapy ciepła dla cech wariancji, w przypadku ktorych korelacja nie zachodzi wyłącznie dla dwóch parametrów czyli harmony i perceptr w środkowej cześci wykresu. Z kolei stosunkow wysokie wartości korelacji można zaobserwować dla parametrów "skrajnych", czyli pierwszych i ostatnich na liście parametrów.
var_cols = [col for col in data.columns if 'var' in col]
var_correlation = data[var_cols].corr()
mask = np.triu(np.ones_like(var_correlation, dtype=bool))
f, ax = plt.subplots(figsize=(16, 11))
cmap = sns.diverging_palette(240, 10, as_cmap=True, s = 90, l = 45, n = 5)
sns.heatmap(var_correlation, mask=mask, cmap=cmap, vmax=.3, center=0,
square=True, linewidths=.5)
plt.title('Correlation Heatmap (vars)', fontsize = 20)
plt.xticks(fontsize = 10)
plt.yticks(fontsize = 10);
Cechy średnie dają lepsze rezultaty niż cechy wariancji ze względu na mniejszą korelację pomiędzy poszczególnymi parametrami.(https://datascience.stackexchange.com/questions/9087/correlation-and-naive-bayes).
W toku przeprowadzanych testów okazało sie, że dokładność stworzonego i wytrenowanego modelu zależy od rodzaju cech. W przypadku cech wariancji dokładność jest niższa niż w przypadku cech średnich. Z kolei najwyższą dokładność udało się uzyskać poprzez wykorzystanie kombinacji 8 różnych kolumn.
- dla var_cols accuracy = 0.3875,
- dla mean_cols accuracy = 0.4375,
- dla ['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean'] accuracy = 0.56125
Równocześnie uzyskane wyniki mogłyby mieć zdecydowanie wyższą dokładność jednak ograniczeniem okazał się specyfika samego datasetu, który posiada niewielkie zróżnicowanie wartości cech i niewielką korelację pomiędzy poszczególnymi cechami!
Wykres punktowy przedstawiający zależność pomiędzy krótkotrwałą transformatą Fouriera a częstotliwością melodyczną cepstrum
Wykres ten jednoznacznie pokazuje, że chociaż pojawia się pewna pula obserwacji odstających to jednak wraz ze wzrosetem wartości melowego współczynnika cepstralnego sygnału rosną wartości krótkotrwałej transformaty Fouriera. Tym samym zależność pomiędzy tymi dwoma warościami, w ogólność, ma charakter liniowy.
fig = plt.figure(figsize=(10,10))
chart = fig.add_subplot()
ax = fig.add_subplot()
colors = ['red', 'green', 'blue', 'brown','purple', 'gray', 'pink', 'black', 'yellow', 'orange']
for genre in genre_dict:
genre_data = data[data["genre"]==genre_dict[genre]]
ax.scatter(genre_data['chroma_stft_mean'],genre_data['mfcc12_mean'], c=colors[genre_dict[genre]-1])
plt.show()
<ipython-input-10-3dd1710c0643>:3: MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance. In a future version, a new instance will always be created and returned. Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance. ax = fig.add_subplot()
4. Wykorzystanie algorytmu Bayesa
class NaiveBayesContinues:
def __init__(self, X, Y):
self.classes = Y.unique()
self.priors = [] # prawdopodobieństwo każdej z klas
self.stds = [] #lista odchyleń standardowych każdej z cech dla każdej z klas
self.means = [] #lista średnich dla każdej z cech dla każdej z klas
for c in self.classes:
x_with_c_class = X[c == Y]
self.priors.append(len(x_with_c_class) / len(X))
self.means.append(x_with_c_class.mean(axis=0))
self.stds.append(x_with_c_class.std(axis=0))
def predict(self, X, display_results=False):
y_preds = []
for x in X:
posteriors = []
for i, c in enumerate(self.classes):
prior = self.priors[i] # prawdopodobieństwo dla rozpatrywanej klasy
mean = self.means[i] # średnia cech dla rozpatrywanej klasy
std = self.stds[i] # odchylenie standardowe cech dla rozpatrywanej klasy
posterior = 1 #P(X1|Yi)*P(X2|Yi)*P(X3|Yi)...
for j, feature in (enumerate(x)):
P_X_yi = np.exp((-(feature - mean[j]) ** 2) / (2 * std[j] ** 2)) / np.sqrt(2 * np.pi * std[j] ** 2) #P(Xj|Yi)
posterior *= P_X_yi
posterior = (posterior * prior) #P(Yi)P(X1|Yi)*P(X2|Yi)*P(X3|Yi)...
posteriors.append(posterior)
if(display_results):
print("posteriors")
print(posteriors)
print(np.argmax(posteriors))
y_pred = self.classes[np.argmax(posteriors)] # Wzięcie klasy z największym prawdopodobieństem
y_preds.append(y_pred)
return y_preds
Gausowski Naiwny Bayes Stosowany w przypadku pracy na danych o charakterze ciągłym.
W procesie trenowania i testowania modelu wykorzystany został skrypt losujący kolumny i zapisujący uzyskiwane wartości accuracy w celu znalezienia najbardziej efektywnej kombinacji cech. W ten sposób wybranych zostało 8 cech, w tym sześć cech należących do kategorii średnich i dwie do wariancji.
X_train_np = X_train[['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean']].to_numpy()
X_test_np = X_test[['mfcc4_mean', 'mfcc12_mean', 'mfcc9_var', 'mfcc1_mean', 'rms_mean', 'chroma_stft_mean', 'mfcc6_var', 'mfcc9_mean']].to_numpy()
model = NaiveBayesContinues(X_train_np, Y_train)
Y_train_predicted = model.predict(X_train_np[:1])
Ewaluacja
Zbiór trenujący
Y_train_predicted = model.predict(X_train_np)
cm = confusion_matrix(Y_train, Y_train_predicted)
ac = accuracy_score(Y_train, Y_train_predicted)
print("(Train data) Confusion matrix:")
display(cm)
print("(Train data) Accuracy:")
print(ac)
(Train data) Confusion matrix:
array([[27, 3, 11, 3, 1, 8, 18, 0, 10, 4], [ 1, 75, 2, 0, 0, 10, 0, 0, 1, 0], [10, 1, 29, 9, 1, 4, 2, 8, 6, 3], [ 2, 0, 2, 39, 1, 1, 12, 12, 5, 4], [ 0, 0, 0, 8, 36, 0, 11, 13, 8, 1], [ 8, 20, 2, 0, 0, 45, 1, 0, 0, 6], [ 1, 0, 0, 3, 2, 0, 71, 0, 0, 3], [ 1, 1, 2, 2, 5, 0, 0, 63, 2, 0], [ 1, 0, 8, 8, 6, 3, 3, 5, 48, 3], [ 4, 0, 10, 15, 0, 6, 14, 5, 5, 16]], dtype=int64)
(Train data) Accuracy: 0.56125
Zbiór testowy
Y_test_predicted = model.predict(X_test_np)
cm = confusion_matrix(Y_test, Y_test_predicted)
ac = accuracy_score(Y_test, Y_test_predicted)
print("Confusion matrix:")
display(cm)
print("Accuracy:")
print(ac)
Confusion matrix:
array([[ 5, 0, 2, 0, 0, 2, 4, 0, 1, 1], [ 0, 8, 0, 0, 0, 3, 0, 0, 0, 0], [ 7, 0, 6, 6, 0, 4, 1, 0, 1, 2], [ 0, 0, 2, 7, 1, 0, 2, 4, 5, 1], [ 0, 0, 0, 2, 10, 0, 6, 1, 4, 0], [ 0, 3, 0, 1, 0, 14, 0, 0, 0, 0], [ 0, 0, 0, 0, 1, 0, 17, 0, 2, 0], [ 1, 0, 0, 2, 1, 0, 0, 18, 1, 1], [ 0, 1, 0, 0, 3, 1, 0, 4, 6, 0], [ 5, 0, 3, 5, 1, 1, 5, 1, 0, 4]], dtype=int64)
Accuracy: 0.475
Przykładowe porównania
for i in range(10):
print(f"Y: {Y_test.to_numpy()[i]}\tPredicted: {Y_test_predicted[i]}")
Y: 10 Predicted: 10 Y: 9 Predicted: 8 Y: 3 Predicted: 1 Y: 6 Predicted: 6 Y: 7 Predicted: 7 Y: 10 Predicted: 7 Y: 1 Predicted: 1 Y: 3 Predicted: 6 Y: 4 Predicted: 4 Y: 8 Predicted: 10
Bayes z wykorzystaniem gotowej biblioteki
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas as pd
import numpy as np
import pickle, os
import typing
class Bayes:
def __init__(self):
self.classifier = GaussianNB()
def train(self, X: pd.DataFrame, Y: pd.Series) -> None:
self.classifier.fit(X, Y)
def predict(self, X: pd.DataFrame) -> np.ndarray:
predictions = self.classifier.predict(X)
return predictions
def eval(self, Y: pd.Series, Y_pred: np.ndarray) -> typing.Tuple[np.ndarray, np.float64]:
cm = confusion_matrix(Y, Y_pred)
ac = accuracy_score(Y, Y_pred)
return (cm, ac)
bayes = Bayes()
bayes.train(X_train_np, Y_train)
Y_predicted = bayes.predict(X_train)
eval_result = bayes.eval(Y_train, Y_predicted)
print("Train:")
print(eval_result[1])
Y_predicted = bayes.predict(X_test)
eval_result = bayes.eval(Y_test, Y_predicted)
print("Test:")
print(eval_result[1])
[1;31m---------------------------------------------------------------------------[0m [1;31mValueError[0m Traceback (most recent call last) [1;32m<ipython-input-18-86955df8a2b3>[0m in [0;36m<module>[1;34m[0m [0;32m 2[0m [0mbayes[0m[1;33m.[0m[0mtrain[0m[1;33m([0m[0mX_train_np[0m[1;33m,[0m [0mY_train[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 3[0m [1;33m[0m[0m [1;32m----> 4[1;33m [0mY_predicted[0m [1;33m=[0m [0mbayes[0m[1;33m.[0m[0mpredict[0m[1;33m([0m[0mX_train[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 5[0m [0meval_result[0m [1;33m=[0m [0mbayes[0m[1;33m.[0m[0meval[0m[1;33m([0m[0mY_train[0m[1;33m,[0m [0mY_predicted[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 6[0m [0mprint[0m[1;33m([0m[1;34m"Train:"[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m<ipython-input-17-d1179e864a3a>[0m in [0;36mpredict[1;34m(self, X)[0m [0;32m 16[0m [1;33m[0m[0m [0;32m 17[0m [1;32mdef[0m [0mpredict[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mX[0m[1;33m:[0m [0mpd[0m[1;33m.[0m[0mDataFrame[0m[1;33m)[0m [1;33m->[0m [0mnp[0m[1;33m.[0m[0mndarray[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 18[1;33m [0mpredictions[0m [1;33m=[0m [0mself[0m[1;33m.[0m[0mclassifier[0m[1;33m.[0m[0mpredict[0m[1;33m([0m[0mX[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 19[0m [1;32mreturn[0m [0mpredictions[0m[1;33m[0m[1;33m[0m[0m [0;32m 20[0m [1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\sklearn\naive_bayes.py[0m in [0;36mpredict[1;34m(self, X)[0m [0;32m 75[0m """ [0;32m 76[0m [0mcheck_is_fitted[0m[1;33m([0m[0mself[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 77[1;33m [0mX[0m [1;33m=[0m [0mself[0m[1;33m.[0m[0m_check_X[0m[1;33m([0m[0mX[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 78[0m [0mjll[0m [1;33m=[0m [0mself[0m[1;33m.[0m[0m_joint_log_likelihood[0m[1;33m([0m[0mX[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 79[0m [1;32mreturn[0m [0mself[0m[1;33m.[0m[0mclasses_[0m[1;33m[[0m[0mnp[0m[1;33m.[0m[0margmax[0m[1;33m([0m[0mjll[0m[1;33m,[0m [0maxis[0m[1;33m=[0m[1;36m1[0m[1;33m)[0m[1;33m][0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\sklearn\naive_bayes.py[0m in [0;36m_check_X[1;34m(self, X)[0m [0;32m 214[0m [1;33m[0m[0m [0;32m 215[0m [1;32mdef[0m [0m_check_X[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mX[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 216[1;33m [1;32mreturn[0m [0mcheck_array[0m[1;33m([0m[0mX[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 217[0m [1;33m[0m[0m [0;32m 218[0m [1;33m@[0m[0mstaticmethod[0m[1;33m[0m[1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\sklearn\utils\validation.py[0m in [0;36minner_f[1;34m(*args, **kwargs)[0m [0;32m 70[0m FutureWarning) [0;32m 71[0m [0mkwargs[0m[1;33m.[0m[0mupdate[0m[1;33m([0m[1;33m{[0m[0mk[0m[1;33m:[0m [0marg[0m [1;32mfor[0m [0mk[0m[1;33m,[0m [0marg[0m [1;32min[0m [0mzip[0m[1;33m([0m[0msig[0m[1;33m.[0m[0mparameters[0m[1;33m,[0m [0margs[0m[1;33m)[0m[1;33m}[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [1;32m---> 72[1;33m [1;32mreturn[0m [0mf[0m[1;33m([0m[1;33m**[0m[0mkwargs[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 73[0m [1;32mreturn[0m [0minner_f[0m[1;33m[0m[1;33m[0m[0m [0;32m 74[0m [1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\sklearn\utils\validation.py[0m in [0;36mcheck_array[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)[0m [0;32m 596[0m [0marray[0m [1;33m=[0m [0marray[0m[1;33m.[0m[0mastype[0m[1;33m([0m[0mdtype[0m[1;33m,[0m [0mcasting[0m[1;33m=[0m[1;34m"unsafe"[0m[1;33m,[0m [0mcopy[0m[1;33m=[0m[1;32mFalse[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 597[0m [1;32melse[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m--> 598[1;33m [0marray[0m [1;33m=[0m [0mnp[0m[1;33m.[0m[0masarray[0m[1;33m([0m[0marray[0m[1;33m,[0m [0morder[0m[1;33m=[0m[0morder[0m[1;33m,[0m [0mdtype[0m[1;33m=[0m[0mdtype[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 599[0m [1;32mexcept[0m [0mComplexWarning[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [0;32m 600[0m raise ValueError("Complex data not supported\n" [1;32m~\anaconda3\lib\site-packages\numpy\core\_asarray.py[0m in [0;36masarray[1;34m(a, dtype, order, like)[0m [0;32m 100[0m [1;32mreturn[0m [0m_asarray_with_like[0m[1;33m([0m[0ma[0m[1;33m,[0m [0mdtype[0m[1;33m=[0m[0mdtype[0m[1;33m,[0m [0morder[0m[1;33m=[0m[0morder[0m[1;33m,[0m [0mlike[0m[1;33m=[0m[0mlike[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 101[0m [1;33m[0m[0m [1;32m--> 102[1;33m [1;32mreturn[0m [0marray[0m[1;33m([0m[0ma[0m[1;33m,[0m [0mdtype[0m[1;33m,[0m [0mcopy[0m[1;33m=[0m[1;32mFalse[0m[1;33m,[0m [0morder[0m[1;33m=[0m[0morder[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 103[0m [1;33m[0m[0m [0;32m 104[0m [1;33m[0m[0m [1;32m~\anaconda3\lib\site-packages\pandas\core\generic.py[0m in [0;36m__array__[1;34m(self, dtype)[0m [0;32m 1897[0m [1;33m[0m[0m [0;32m 1898[0m [1;32mdef[0m [0m__array__[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mdtype[0m[1;33m=[0m[1;32mNone[0m[1;33m)[0m [1;33m->[0m [0mnp[0m[1;33m.[0m[0mndarray[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m [1;32m-> 1899[1;33m [1;32mreturn[0m [0mnp[0m[1;33m.[0m[0masarray[0m[1;33m([0m[0mself[0m[1;33m.[0m[0m_values[0m[1;33m,[0m [0mdtype[0m[1;33m=[0m[0mdtype[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 1900[0m [1;33m[0m[0m [0;32m 1901[0m def __array_wrap__( [1;32m~\anaconda3\lib\site-packages\numpy\core\_asarray.py[0m in [0;36masarray[1;34m(a, dtype, order, like)[0m [0;32m 100[0m [1;32mreturn[0m [0m_asarray_with_like[0m[1;33m([0m[0ma[0m[1;33m,[0m [0mdtype[0m[1;33m=[0m[0mdtype[0m[1;33m,[0m [0morder[0m[1;33m=[0m[0morder[0m[1;33m,[0m [0mlike[0m[1;33m=[0m[0mlike[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0;32m 101[0m [1;33m[0m[0m [1;32m--> 102[1;33m [1;32mreturn[0m [0marray[0m[1;33m([0m[0ma[0m[1;33m,[0m [0mdtype[0m[1;33m,[0m [0mcopy[0m[1;33m=[0m[1;32mFalse[0m[1;33m,[0m [0morder[0m[1;33m=[0m[0morder[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m [0m[0;32m 103[0m [1;33m[0m[0m [0;32m 104[0m [1;33m[0m[0m [1;31mValueError[0m: could not convert string to float: 'metal'
# skrypt losujacy kolumny ze zbioru i sprawdzajacy accuracy na zbiorze trenujacym
for i in range(100):
X = data.drop(["genre"], axis=1)
X_rand = X.sample(n=10, axis='columns')
Y = data["genre"]
X_train, X_test, Y_train, Y_test = train_test_split(X_rand, Y, test_size = 0.20, random_state = False)
model = GaussianNB()
model.fit(X_train, Y_train)
Y_train_predicted = model.predict(X_train)
ac = accuracy_score(Y_train, Y_train_predicted)
filename = 'accuracy.txt'
if os.path.exists(filename):
append_write = 'a'
else:
append_write = 'w'
acc_random = open(filename, append_write)
acc_random.write(str(ac) + " " + str(list(X_rand.columns)) + '\n')
acc_random.close()
!sort -k1,1nr -k2,2 accuracy.txt