{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification with the naive Bayes method (continuous distributions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Team members:\n", "- Nowak Ania,\n", "- Łaźna Patrycja,\n", "- Bregier Damian" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Optional: install the library versions used in this project\n", "#!pip install pandas==1.2.4\n", "#!pip install numpy==1.20.3\n", "#!pip install scikit-learn\n", "\n", "from sklearn.model_selection import train_test_split\n", "import pandas as pd\n", "import numpy as np\n", "import typing\n", "import os, pickle\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns; sns.set()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 0. Basic information about the dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This project uses the GTZAN Dataset, which addresses multiclass classification using music genres as the example. The dataset covers 10 genres: blues, classical, country, disco, hip-hop, jazz, metal, pop, reggae and rock. Each genre is represented by 100 audio files, each 30 seconds long. The samples were collected in 2000-2001 from a variety of sources, including radio broadcasts, personal CDs and self-made recordings.\n", "\n", "The dataset is rich and detailed: each track is described by 60 attributes, including the track length, the genre label, tempo, harmony, the variances of the extracted features and the mel-frequency cepstral coefficients (MFCCs).\n", "\n", "Full details about the dataset are available at: https://www.kaggle.com/andradaolteanu/gtzan-dataset-music-genre-classification\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Loading and normalizing the data" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Dictionary mapping each of the 10 music genres\n", "# to its corresponding numeric value\n", "genre_dict = {\n", "    \"blues\": 1,\n", "    \"classical\": 2,\n", "    \"country\": 3,\n", "    \"disco\": 4,\n", "    \"hiphop\": 5,\n", "    \"jazz\": 6,\n", "    \"metal\": 7,\n", "    \"pop\": 8,\n", "    \"reggae\": 9,\n", "    \"rock\": 10\n", "}\n", "# name of the file that stores the features after preprocessing\n", "filename = 'music_genre.csv'\n", "# path of the model file\n", "model_path = 'model.model'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Preparing data...\n" ] }, { "data": { "text/html": [ "
|  | genre | chroma_stft_mean | chroma_stft_var | rms_mean | rms_var | spectral_centroid_mean | spectral_centroid_var | spectral_bandwidth_mean | spectral_bandwidth_var | rolloff_mean | ... | mfcc16_var | mfcc17_mean | mfcc17_var | mfcc18_mean | mfcc18_var | mfcc19_mean | mfcc19_var | mfcc20_mean | mfcc20_var | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.350088 | 0.088757 | 0.130228 | 0.002827 | 1784.165850 | 1.297741e+05 | 2002.449060 | 85882.761315 | 3805.839606 | ... | 52.420910 | -1.690215 | 36.524071 | -0.408979 | 41.597103 | -2.303523 | 55.062923 | 1.221291 | 46.936035 | blues |
| 1 | 1 | 0.340914 | 0.094980 | 0.095948 | 0.002373 | 1530.176679 | 3.758501e+05 | 2039.036516 | 213843.755497 | 3550.522098 | ... | 55.356403 | -0.731125 | 60.314529 | 0.295073 | 48.120598 | -0.283518 | 51.106190 | 0.531217 | 45.786282 | blues |
| 2 | 1 | 0.363637 | 0.085275 | 0.175570 | 0.002746 | 1552.811865 | 1.564676e+05 | 1747.702312 | 76254.192257 | 3042.260232 | ... | 40.598766 | -7.729093 | 47.639427 | -1.816407 | 52.382141 | -3.439720 | 46.639660 | -2.231258 | 30.573025 | blues |
| 3 | 1 | 0.404785 | 0.093999 | 0.141093 | 0.006346 | 1070.106615 | 1.843559e+05 | 1596.412872 | 166441.494769 | 2184.745799 | ... | 44.427753 | -3.319597 | 50.206673 | 0.636965 | 37.319130 | -0.619121 | 37.259739 | -3.407448 | 31.949339 | blues |
| 4 | 1 | 0.308526 | 0.087841 | 0.091529 | 0.002303 | 1835.004266 | 3.433999e+05 | 1748.172116 | 88445.209036 | 3579.757627 | ... | 86.099236 | -5.454034 | 75.269707 | -0.916874 | 53.613918 | -4.404827 | 62.910812 | -11.703234 | 55.195160 | blues |
| 5 | 1 | 0.302456 | 0.087532 | 0.103494 | 0.003981 | 1831.993940 | 1.030482e+06 | 1729.653287 | 201910.508633 | 3481.517592 | ... | 72.549225 | -1.838263 | 68.702026 | -2.783800 | 42.447453 | -3.047909 | 39.808784 | -8.109991 | 46.311005 | blues |
| 6 | 1 | 0.291328 | 0.093981 | 0.141874 | 0.008803 | 1459.366472 | 4.378594e+05 | 1389.009131 | 185023.239545 | 2795.610963 | ... | 83.248245 | -10.913176 | 56.902153 | -6.971336 | 38.231800 | -3.436505 | 48.235741 | -6.483466 | 70.170364 | blues |
| 7 | 1 | 0.307955 | 0.092903 | 0.131822 | 0.005531 | 1451.667066 | 4.495682e+05 | 1577.270941 | 168211.938804 | 2954.836760 | ... | 70.438438 | -10.568935 | 52.090893 | -10.784515 | 60.461330 | -4.690678 | 65.547516 | -8.630722 | 56.401436 | blues |
| 8 | 1 | 0.408879 | 0.086512 | 0.142416 | 0.001507 | 1719.368948 | 1.632828e+05 | 2031.740381 | 105542.718193 | 3782.316288 | ... | 50.563751 | -7.041824 | 28.894934 | 2.695248 | 36.889568 | 3.412305 | 33.698597 | -2.715692 | 36.418430 | blues |
| 9 | 1 | 0.273950 | 0.092316 | 0.081314 | 0.004347 | 1817.150863 | 2.982361e+05 | 1973.773306 | 114070.112591 | 3943.490565 | ... | 59.314602 | -1.916804 | 58.418438 | -2.292661 | 83.205231 | 2.881967 | 77.082222 | -4.235203 | 91.468811 | blues |

10 rows × 59 columns
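The first table pairs each `label` with the numeric `genre` value taken from `genre_dict`. The second shows a shuffled subset of rows without that column, which may come from a train/test split. The notebook's own preprocessing code is not shown here, so the following is only a minimal sketch under stated assumptions: a DataFrame with numeric feature columns plus a `label` column, and a hypothetical helper `encode_and_split` with an assumed 80/20 split and a fixed seed. It maps the labels through `genre_dict` and calls scikit-learn's `train_test_split` with `stratify`, so every genre keeps the same share in both parts.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Same genre-to-number mapping as genre_dict in the notebook
genre_dict = {
    'blues': 1, 'classical': 2, 'country': 3, 'disco': 4, 'hiphop': 5,
    'jazz': 6, 'metal': 7, 'pop': 8, 'reggae': 9, 'rock': 10,
}


def encode_and_split(df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
    # Hypothetical helper (not from the notebook): encode labels, then split.
    y = df['label'].map(genre_dict)  # genre name -> numeric class
    if y.isna().any():
        raise ValueError('label column contains a genre missing from genre_dict')
    X = df.drop(columns=['label'])  # remaining columns are the features
    # stratify=y keeps each genre's share identical in the train and test parts
    return train_test_split(X, y, test_size=test_size, random_state=seed, stratify=y)


# Synthetic demo frame: 10 genres x 5 tracks, random values under two real column names
rng = np.random.default_rng(0)
labels = np.repeat(list(genre_dict), 5)
demo = pd.DataFrame({
    'rms_mean': rng.random(labels.size),
    'tempo': rng.uniform(60, 180, labels.size),
    'label': labels,
})

X_train, X_test, y_train, y_test = encode_and_split(demo)
print(X_train.shape, X_test.shape)  # -> (40, 2) (10, 2)
```

On this 50-row synthetic frame the split yields 40 training rows and 10 test rows, with exactly one track of each genre in the test part.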
|  | chroma_stft_mean | chroma_stft_var | rms_mean | rms_var | spectral_centroid_mean | spectral_centroid_var | spectral_bandwidth_mean | spectral_bandwidth_var | rolloff_mean | rolloff_var | ... | mfcc16_var | mfcc17_mean | mfcc17_var | mfcc18_mean | mfcc18_var | mfcc19_mean | mfcc19_var | mfcc20_mean | mfcc20_var | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 687 | 0.516547 | 0.072241 | 0.267380 | 0.001175 | 3338.581900 | 172002.893292 | 2697.128636 | 45771.294278 | 6670.863091 | 3.556853e+05 | ... | 37.339474 | -8.121326 | 33.968277 | 4.910113 | 42.063385 | -2.474697 | 35.162354 | 3.192656 | 36.478157 | metal |
| 500 | 0.344511 | 0.085002 | 0.046747 | 0.001542 | 1503.869486 | 554576.511533 | 1754.216082 | 283554.933422 | 2799.283099 | 2.685679e+06 | ... | 50.311016 | -1.503434 | 41.141155 | 0.221949 | 55.707256 | -1.991485 | 50.006485 | -3.353825 | 49.906403 | jazz |
| 332 | 0.368345 | 0.090390 | 0.111073 | 0.004402 | 2446.919077 | 490397.099115 | 2449.159840 | 215375.540632 | 4958.057490 | 2.650020e+06 | ... | 78.892769 | -1.054999 | 79.877068 | 4.496278 | 112.834435 | -0.978958 | 75.059898 | -5.256925 | 120.275269 | disco |
| 979 | 0.360042 | 0.083953 | 0.116724 | 0.000789 | 2148.410463 | 253618.158995 | 2107.165355 | 72155.551685 | 4479.264304 | 9.787046e+05 | ... | 37.060532 | -13.479134 | 50.848667 | 3.308529 | 47.726006 | -3.704957 | 56.781952 | 1.085497 | 54.243389 | rock |
| 817 | 0.425788 | 0.091852 | 0.139799 | 0.003601 | 1803.774378 | 659241.158049 | 1973.418903 | 201432.199120 | 3777.969679 | 2.632339e+06 | ... | 64.068756 | -2.219202 | 99.249870 | 5.304260 | 64.088127 | -6.597187 | 62.661850 | -2.923168 | 67.490440 | reggae |
| 620 | 0.495959 | 0.072854 | 0.117362 | 0.000867 | 2657.912854 | 189139.438926 | 2345.662472 | 32730.579626 | 5358.261979 | 5.918222e+05 | ... | 27.937113 | -10.676390 | 26.519361 | 3.875155 | 25.613684 | -4.943561 | 24.334734 | 3.255899 | 25.199259 | metal |
| 814 | 0.395137 | 0.093939 | 0.114246 | 0.004025 | 1716.249594 | 920189.339374 | 2062.885827 | 358557.016423 | 3790.901258 | 4.734865e+06 | ... | 66.090370 | -4.590122 | 72.595345 | 4.261040 | 63.185764 | -2.127876 | 50.693245 | -3.665569 | 89.750290 | reggae |
| 516 | 0.249535 | 0.087563 | 0.060560 | 0.001276 | 1465.857446 | 143302.098295 | 1738.858902 | 58868.399307 | 2822.406728 | 7.392007e+05 | ... | 109.811813 | -0.027696 | 113.660950 | 2.098475 | 160.025497 | 1.109709 | 136.810165 | 2.935807 | 95.914490 | jazz |
| 518 | 0.353474 | 0.087755 | 0.052264 | 0.000316 | 1993.352766 | 64753.479332 | 2127.165109 | 36027.039069 | 4248.194549 | 3.987029e+05 | ... | 57.230133 | -1.110214 | 48.080849 | -0.784249 | 57.033504 | -2.984207 | 55.737625 | 0.350456 | 64.126846 | jazz |
| 940 | 0.416089 | 0.087772 | 0.142935 | 0.003150 | 3009.958707 | 435134.775688 | 2778.049758 | 135548.871316 | 6131.200719 | 1.788624e+06 | ... | 42.315434 | -3.953057 | 48.761936 | -3.092345 | 49.514446 | -2.731183 | 58.219994 | -0.909785 | 63.111858 | rock |

10 rows × 58 columns
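The title refers to naive Bayes with continuous distributions. One common choice, assumed here rather than taken from the notebook, is to model each feature within each class as an independent normal distribution, so that the predicted class c maximizes log P(c) + Σ_j log N(x_j | μ_cj, σ²_cj). The sketch below is a minimal NumPy version of that Gaussian naive Bayes rule. The class name `GaussianNaiveBayes` is illustrative, and the `var_smoothing` term mirrors the idea behind the parameter of the same name in scikit-learn's `GaussianNB`. Such a model could be fitted on numeric features like those in the tables above, with the `genre_dict` codes as classes.

```python
import numpy as np


class GaussianNaiveBayes:
    # Minimal Gaussian naive Bayes: every feature is modelled per class as N(mean, var)

    def fit(self, X, y, var_smoothing=1e-9):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        # per-class feature means and variances, shape (n_classes, n_features)
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        variances = np.array([X[y == c].var(axis=0) for c in self.classes_])
        # small additive term keeps every variance strictly positive
        eps = var_smoothing * max(X.var(axis=0).max(), 1e-12)
        self.vars_ = variances + eps
        # class priors estimated from class frequencies
        self.log_priors_ = np.log(np.array([np.mean(y == c) for c in self.classes_]))
        return self

    def joint_log_likelihood(self, X):
        X = np.asarray(X, dtype=float)
        # broadcast to shape (n_samples, n_classes, n_features)
        diff = X[:, None, :] - self.means_[None, :, :]
        log_pdf = -0.5 * (np.log(2.0 * np.pi * self.vars_)[None, :, :]
                          + diff ** 2 / self.vars_[None, :, :])
        # naive independence: sum per-feature log densities, then add the log prior
        return log_pdf.sum(axis=2) + self.log_priors_[None, :]

    def predict(self, X):
        return self.classes_[np.argmax(self.joint_log_likelihood(X), axis=1)]


# Synthetic demo: two well-separated clusters labelled with genre codes 1 and 2
rng = np.random.default_rng(1)
X_demo = np.vstack([rng.normal(0.0, 1.0, (50, 3)), rng.normal(5.0, 1.0, (50, 3))])
y_demo = np.array([1] * 50 + [2] * 50)

model = GaussianNaiveBayes().fit(X_demo, y_demo)
print(model.predict([[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]]))  # -> [1 2]
```

Working in log space avoids the numerical underflow that multiplying many small densities would cause, which matters with the dozens of feature columns in GTZAN.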
\n", "