diff --git a/.gitignore b/.gitignore
index 46117a8..dfdfd6a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ venv*
.vscode*
__pycache__*
music_genre.csv
-music_genre.model
\ No newline at end of file
+*.model
+.ipynb_checkpoints*
\ No newline at end of file
diff --git a/Bayes.ipynb b/Bayes.ipynb
new file mode 100644
index 0000000..b0ba291
--- /dev/null
+++ b/Bayes.ipynb
@@ -0,0 +1,1121 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe)\n",
+ "Zasady zaliczenia: 40 punktów podzielone następująco:\n",
+ "- 10 pkt - prezentacja projektu\n",
+ "- 15 pkt - implementacja, w tym:\n",
+ "- 5 pkt - zgodność z tematem,\n",
+ "- 5 pkt - jakość kodu,\n",
+ "- 5 pkt - poprawność implementacji\n",
+ "- 10 pkt - efekt \"wow\"\n",
+ "- 5 pkt - aktywność wszystkich członków grupy\n",
+ "\n",
+ "Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe). Implementacja powinna założyć, że cechy są ciągłe (do wyboru rozkład normalny i jądrowe wygładzenie). Na wejściu oczekiwany jest zbiór, który zawiera p-cech ciągłych, wektor etykiet oraz wektor prawdopodobieństw a priori dla klas. Na wyjściu otrzymujemy prognozowane etykiety oraz prawdopodobieństwa a posteriori. Dodatkową wartością może być wizualizacja obszarów decyzyjnych w przypadku dwóch cech.\n",
+ "\n",
+ "```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: pandas==1.2.4 in c:\\users\\annad\\anaconda3\\lib\\site-packages (1.2.4)\n",
+ "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (2.8.1)\n",
+ "Requirement already satisfied: numpy>=1.16.5 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (1.20.3)\n",
+ "Requirement already satisfied: pytz>=2017.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (2020.1)\n",
+ "Requirement already satisfied: six>=1.5 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from python-dateutil>=2.7.3->pandas==1.2.4) (1.15.0)\n",
+ "Requirement already satisfied: numpy==1.20.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (1.20.3)\n",
+ "Requirement already satisfied: sklearn==0.0 in c:\\users\\annad\\anaconda3\\lib\\site-packages (0.0)\n",
+ "Requirement already satisfied: scikit-learn in c:\\users\\annad\\appdata\\roaming\\python\\python38\\site-packages (from sklearn==0.0) (0.24.2)\n",
+ "Requirement already satisfied: joblib>=0.11 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (0.17.0)\n",
+ "Requirement already satisfied: numpy>=1.13.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (1.20.3)\n",
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (2.1.0)\n",
+ "Requirement already satisfied: scipy>=0.19.1 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (1.5.2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install pandas==1.2.4\n",
+ "!pip install numpy==1.20.3\n",
+ "!pip install scikit-learn==0.24.2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import typing\n",
+ "import os, pickle\n",
+ "from sklearn.naive_bayes import GaussianNB\n",
+ "from sklearn.metrics import confusion_matrix, accuracy_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Wczytywanie i normalizacja danych"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Stałe\n",
+ "genre_dict = {\n",
+ " \"blues\" : 1,\n",
+ " \"classical\" : 2,\n",
+ " \"country\" : 3,\n",
+ " \"disco\" : 4,\n",
+ " \"hiphop\" : 5,\n",
+ " \"jazz\" : 6,\n",
+ " \"metal\" : 7,\n",
+ " \"pop\" : 8,\n",
+ " \"reggae\" : 9,\n",
+ " \"rock\" : 10\n",
+ "}\n",
+ "filename = 'music_genre.csv'\n",
+ "model_path = 'model.model'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Preparing data...\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " genre | \n",
+ " chroma_stft_mean | \n",
+ " chroma_stft_var | \n",
+ " rms_mean | \n",
+ " rms_var | \n",
+ " spectral_centroid_mean | \n",
+ " spectral_centroid_var | \n",
+ " spectral_bandwidth_mean | \n",
+ " spectral_bandwidth_var | \n",
+ " rolloff_mean | \n",
+ " ... | \n",
+ " mfcc16_mean | \n",
+ " mfcc16_var | \n",
+ " mfcc17_mean | \n",
+ " mfcc17_var | \n",
+ " mfcc18_mean | \n",
+ " mfcc18_var | \n",
+ " mfcc19_mean | \n",
+ " mfcc19_var | \n",
+ " mfcc20_mean | \n",
+ " mfcc20_var | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0.350088 | \n",
+ " 0.088757 | \n",
+ " 0.130228 | \n",
+ " 0.002827 | \n",
+ " 1784.165850 | \n",
+ " 1.297741e+05 | \n",
+ " 2002.449060 | \n",
+ " 85882.761315 | \n",
+ " 3805.839606 | \n",
+ " ... | \n",
+ " 0.752740 | \n",
+ " 52.420910 | \n",
+ " -1.690215 | \n",
+ " 36.524071 | \n",
+ " -0.408979 | \n",
+ " 41.597103 | \n",
+ " -2.303523 | \n",
+ " 55.062923 | \n",
+ " 1.221291 | \n",
+ " 46.936035 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0.340914 | \n",
+ " 0.094980 | \n",
+ " 0.095948 | \n",
+ " 0.002373 | \n",
+ " 1530.176679 | \n",
+ " 3.758501e+05 | \n",
+ " 2039.036516 | \n",
+ " 213843.755497 | \n",
+ " 3550.522098 | \n",
+ " ... | \n",
+ " 0.927998 | \n",
+ " 55.356403 | \n",
+ " -0.731125 | \n",
+ " 60.314529 | \n",
+ " 0.295073 | \n",
+ " 48.120598 | \n",
+ " -0.283518 | \n",
+ " 51.106190 | \n",
+ " 0.531217 | \n",
+ " 45.786282 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 0.363637 | \n",
+ " 0.085275 | \n",
+ " 0.175570 | \n",
+ " 0.002746 | \n",
+ " 1552.811865 | \n",
+ " 1.564676e+05 | \n",
+ " 1747.702312 | \n",
+ " 76254.192257 | \n",
+ " 3042.260232 | \n",
+ " ... | \n",
+ " 2.451690 | \n",
+ " 40.598766 | \n",
+ " -7.729093 | \n",
+ " 47.639427 | \n",
+ " -1.816407 | \n",
+ " 52.382141 | \n",
+ " -3.439720 | \n",
+ " 46.639660 | \n",
+ " -2.231258 | \n",
+ " 30.573025 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 0.404785 | \n",
+ " 0.093999 | \n",
+ " 0.141093 | \n",
+ " 0.006346 | \n",
+ " 1070.106615 | \n",
+ " 1.843559e+05 | \n",
+ " 1596.412872 | \n",
+ " 166441.494769 | \n",
+ " 2184.745799 | \n",
+ " ... | \n",
+ " 0.780874 | \n",
+ " 44.427753 | \n",
+ " -3.319597 | \n",
+ " 50.206673 | \n",
+ " 0.636965 | \n",
+ " 37.319130 | \n",
+ " -0.619121 | \n",
+ " 37.259739 | \n",
+ " -3.407448 | \n",
+ " 31.949339 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 0.308526 | \n",
+ " 0.087841 | \n",
+ " 0.091529 | \n",
+ " 0.002303 | \n",
+ " 1835.004266 | \n",
+ " 3.433999e+05 | \n",
+ " 1748.172116 | \n",
+ " 88445.209036 | \n",
+ " 3579.757627 | \n",
+ " ... | \n",
+ " -4.520576 | \n",
+ " 86.099236 | \n",
+ " -5.454034 | \n",
+ " 75.269707 | \n",
+ " -0.916874 | \n",
+ " 53.613918 | \n",
+ " -4.404827 | \n",
+ " 62.910812 | \n",
+ " -11.703234 | \n",
+ " 55.195160 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 1 | \n",
+ " 0.302456 | \n",
+ " 0.087532 | \n",
+ " 0.103494 | \n",
+ " 0.003981 | \n",
+ " 1831.993940 | \n",
+ " 1.030482e+06 | \n",
+ " 1729.653287 | \n",
+ " 201910.508633 | \n",
+ " 3481.517592 | \n",
+ " ... | \n",
+ " -5.576589 | \n",
+ " 72.549225 | \n",
+ " -1.838263 | \n",
+ " 68.702026 | \n",
+ " -2.783800 | \n",
+ " 42.447453 | \n",
+ " -3.047909 | \n",
+ " 39.808784 | \n",
+ " -8.109991 | \n",
+ " 46.311005 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 1 | \n",
+ " 0.291328 | \n",
+ " 0.093981 | \n",
+ " 0.141874 | \n",
+ " 0.008803 | \n",
+ " 1459.366472 | \n",
+ " 4.378594e+05 | \n",
+ " 1389.009131 | \n",
+ " 185023.239545 | \n",
+ " 2795.610963 | \n",
+ " ... | \n",
+ " -10.068051 | \n",
+ " 83.248245 | \n",
+ " -10.913176 | \n",
+ " 56.902153 | \n",
+ " -6.971336 | \n",
+ " 38.231800 | \n",
+ " -3.436505 | \n",
+ " 48.235741 | \n",
+ " -6.483466 | \n",
+ " 70.170364 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 1 | \n",
+ " 0.307955 | \n",
+ " 0.092903 | \n",
+ " 0.131822 | \n",
+ " 0.005531 | \n",
+ " 1451.667066 | \n",
+ " 4.495682e+05 | \n",
+ " 1577.270941 | \n",
+ " 168211.938804 | \n",
+ " 2954.836760 | \n",
+ " ... | \n",
+ " -8.426083 | \n",
+ " 70.438438 | \n",
+ " -10.568935 | \n",
+ " 52.090893 | \n",
+ " -10.784515 | \n",
+ " 60.461330 | \n",
+ " -4.690678 | \n",
+ " 65.547516 | \n",
+ " -8.630722 | \n",
+ " 56.401436 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 1 | \n",
+ " 0.408879 | \n",
+ " 0.086512 | \n",
+ " 0.142416 | \n",
+ " 0.001507 | \n",
+ " 1719.368948 | \n",
+ " 1.632828e+05 | \n",
+ " 2031.740381 | \n",
+ " 105542.718193 | \n",
+ " 3782.316288 | \n",
+ " ... | \n",
+ " -1.452559 | \n",
+ " 50.563751 | \n",
+ " -7.041824 | \n",
+ " 28.894934 | \n",
+ " 2.695248 | \n",
+ " 36.889568 | \n",
+ " 3.412305 | \n",
+ " 33.698597 | \n",
+ " -2.715692 | \n",
+ " 36.418430 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 1 | \n",
+ " 0.273950 | \n",
+ " 0.092316 | \n",
+ " 0.081314 | \n",
+ " 0.004347 | \n",
+ " 1817.150863 | \n",
+ " 2.982361e+05 | \n",
+ " 1973.773306 | \n",
+ " 114070.112591 | \n",
+ " 3943.490565 | \n",
+ " ... | \n",
+ " -1.179920 | \n",
+ " 59.314602 | \n",
+ " -1.916804 | \n",
+ " 58.418438 | \n",
+ " -2.292661 | \n",
+ " 83.205231 | \n",
+ " 2.881967 | \n",
+ " 77.082222 | \n",
+ " -4.235203 | \n",
+ " 91.468811 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
10 rows × 58 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " genre chroma_stft_mean chroma_stft_var rms_mean rms_var \\\n",
+ "0 1 0.350088 0.088757 0.130228 0.002827 \n",
+ "1 1 0.340914 0.094980 0.095948 0.002373 \n",
+ "2 1 0.363637 0.085275 0.175570 0.002746 \n",
+ "3 1 0.404785 0.093999 0.141093 0.006346 \n",
+ "4 1 0.308526 0.087841 0.091529 0.002303 \n",
+ "5 1 0.302456 0.087532 0.103494 0.003981 \n",
+ "6 1 0.291328 0.093981 0.141874 0.008803 \n",
+ "7 1 0.307955 0.092903 0.131822 0.005531 \n",
+ "8 1 0.408879 0.086512 0.142416 0.001507 \n",
+ "9 1 0.273950 0.092316 0.081314 0.004347 \n",
+ "\n",
+ " spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean \\\n",
+ "0 1784.165850 1.297741e+05 2002.449060 \n",
+ "1 1530.176679 3.758501e+05 2039.036516 \n",
+ "2 1552.811865 1.564676e+05 1747.702312 \n",
+ "3 1070.106615 1.843559e+05 1596.412872 \n",
+ "4 1835.004266 3.433999e+05 1748.172116 \n",
+ "5 1831.993940 1.030482e+06 1729.653287 \n",
+ "6 1459.366472 4.378594e+05 1389.009131 \n",
+ "7 1451.667066 4.495682e+05 1577.270941 \n",
+ "8 1719.368948 1.632828e+05 2031.740381 \n",
+ "9 1817.150863 2.982361e+05 1973.773306 \n",
+ "\n",
+ " spectral_bandwidth_var rolloff_mean ... mfcc16_mean mfcc16_var \\\n",
+ "0 85882.761315 3805.839606 ... 0.752740 52.420910 \n",
+ "1 213843.755497 3550.522098 ... 0.927998 55.356403 \n",
+ "2 76254.192257 3042.260232 ... 2.451690 40.598766 \n",
+ "3 166441.494769 2184.745799 ... 0.780874 44.427753 \n",
+ "4 88445.209036 3579.757627 ... -4.520576 86.099236 \n",
+ "5 201910.508633 3481.517592 ... -5.576589 72.549225 \n",
+ "6 185023.239545 2795.610963 ... -10.068051 83.248245 \n",
+ "7 168211.938804 2954.836760 ... -8.426083 70.438438 \n",
+ "8 105542.718193 3782.316288 ... -1.452559 50.563751 \n",
+ "9 114070.112591 3943.490565 ... -1.179920 59.314602 \n",
+ "\n",
+ " mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var mfcc19_mean mfcc19_var \\\n",
+ "0 -1.690215 36.524071 -0.408979 41.597103 -2.303523 55.062923 \n",
+ "1 -0.731125 60.314529 0.295073 48.120598 -0.283518 51.106190 \n",
+ "2 -7.729093 47.639427 -1.816407 52.382141 -3.439720 46.639660 \n",
+ "3 -3.319597 50.206673 0.636965 37.319130 -0.619121 37.259739 \n",
+ "4 -5.454034 75.269707 -0.916874 53.613918 -4.404827 62.910812 \n",
+ "5 -1.838263 68.702026 -2.783800 42.447453 -3.047909 39.808784 \n",
+ "6 -10.913176 56.902153 -6.971336 38.231800 -3.436505 48.235741 \n",
+ "7 -10.568935 52.090893 -10.784515 60.461330 -4.690678 65.547516 \n",
+ "8 -7.041824 28.894934 2.695248 36.889568 3.412305 33.698597 \n",
+ "9 -1.916804 58.418438 -2.292661 83.205231 2.881967 77.082222 \n",
+ "\n",
+ " mfcc20_mean mfcc20_var \n",
+ "0 1.221291 46.936035 \n",
+ "1 0.531217 45.786282 \n",
+ "2 -2.231258 30.573025 \n",
+ "3 -3.407448 31.949339 \n",
+ "4 -11.703234 55.195160 \n",
+ "5 -8.109991 46.311005 \n",
+ "6 -6.483466 70.170364 \n",
+ "7 -8.630722 56.401436 \n",
+ "8 -2.715692 36.418430 \n",
+ "9 -4.235203 91.468811 \n",
+ "\n",
+ "[10 rows x 58 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "if os.path.isfile(filename):\n",
+ " print(\"Loading prepared data...\")\n",
+ " data = pd.read_csv(filename)\n",
+ "else:\n",
+ " print(\"Preparing data...\")\n",
+ " data = pd.read_csv('music_genre_raw.csv')\n",
+ " column = data[\"label\"].apply(lambda x: genre_dict[x])\n",
+ " data.insert(0, 'genre', column)  # insert() has no dtype argument; its 4th positional parameter is allow_duplicates\n",
+ " data = data.drop(columns=['filename', 'label', 'length'])\n",
+ " data.to_csv(filename, index=False)\n",
+ "display(data.head(10))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Podział danych na zbiory train i test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " chroma_stft_mean | \n",
+ " chroma_stft_var | \n",
+ " rms_mean | \n",
+ " rms_var | \n",
+ " spectral_centroid_mean | \n",
+ " spectral_centroid_var | \n",
+ " spectral_bandwidth_mean | \n",
+ " spectral_bandwidth_var | \n",
+ " rolloff_mean | \n",
+ " rolloff_var | \n",
+ " ... | \n",
+ " mfcc16_mean | \n",
+ " mfcc16_var | \n",
+ " mfcc17_mean | \n",
+ " mfcc17_var | \n",
+ " mfcc18_mean | \n",
+ " mfcc18_var | \n",
+ " mfcc19_mean | \n",
+ " mfcc19_var | \n",
+ " mfcc20_mean | \n",
+ " mfcc20_var | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 687 | \n",
+ " 0.516547 | \n",
+ " 0.072241 | \n",
+ " 0.267380 | \n",
+ " 0.001175 | \n",
+ " 3338.581900 | \n",
+ " 172002.893292 | \n",
+ " 2697.128636 | \n",
+ " 45771.294278 | \n",
+ " 6670.863091 | \n",
+ " 3.556853e+05 | \n",
+ " ... | \n",
+ " 8.023183 | \n",
+ " 37.339474 | \n",
+ " -8.121326 | \n",
+ " 33.968277 | \n",
+ " 4.910113 | \n",
+ " 42.063385 | \n",
+ " -2.474697 | \n",
+ " 35.162354 | \n",
+ " 3.192656 | \n",
+ " 36.478157 | \n",
+ "
\n",
+ " \n",
+ " 500 | \n",
+ " 0.344511 | \n",
+ " 0.085002 | \n",
+ " 0.046747 | \n",
+ " 0.001542 | \n",
+ " 1503.869486 | \n",
+ " 554576.511533 | \n",
+ " 1754.216082 | \n",
+ " 283554.933422 | \n",
+ " 2799.283099 | \n",
+ " 2.685679e+06 | \n",
+ " ... | \n",
+ " -1.957420 | \n",
+ " 50.311016 | \n",
+ " -1.503434 | \n",
+ " 41.141155 | \n",
+ " 0.221949 | \n",
+ " 55.707256 | \n",
+ " -1.991485 | \n",
+ " 50.006485 | \n",
+ " -3.353825 | \n",
+ " 49.906403 | \n",
+ "
\n",
+ " \n",
+ " 332 | \n",
+ " 0.368345 | \n",
+ " 0.090390 | \n",
+ " 0.111073 | \n",
+ " 0.004402 | \n",
+ " 2446.919077 | \n",
+ " 490397.099115 | \n",
+ " 2449.159840 | \n",
+ " 215375.540632 | \n",
+ " 4958.057490 | \n",
+ " 2.650020e+06 | \n",
+ " ... | \n",
+ " 0.122951 | \n",
+ " 78.892769 | \n",
+ " -1.054999 | \n",
+ " 79.877068 | \n",
+ " 4.496278 | \n",
+ " 112.834435 | \n",
+ " -0.978958 | \n",
+ " 75.059898 | \n",
+ " -5.256925 | \n",
+ " 120.275269 | \n",
+ "
\n",
+ " \n",
+ " 979 | \n",
+ " 0.360042 | \n",
+ " 0.083953 | \n",
+ " 0.116724 | \n",
+ " 0.000789 | \n",
+ " 2148.410463 | \n",
+ " 253618.158995 | \n",
+ " 2107.165355 | \n",
+ " 72155.551685 | \n",
+ " 4479.264304 | \n",
+ " 9.787046e+05 | \n",
+ " ... | \n",
+ " -0.621152 | \n",
+ " 37.060532 | \n",
+ " -13.479134 | \n",
+ " 50.848667 | \n",
+ " 3.308529 | \n",
+ " 47.726006 | \n",
+ " -3.704957 | \n",
+ " 56.781952 | \n",
+ " 1.085497 | \n",
+ " 54.243389 | \n",
+ "
\n",
+ " \n",
+ " 817 | \n",
+ " 0.425788 | \n",
+ " 0.091852 | \n",
+ " 0.139799 | \n",
+ " 0.003601 | \n",
+ " 1803.774378 | \n",
+ " 659241.158049 | \n",
+ " 1973.418903 | \n",
+ " 201432.199120 | \n",
+ " 3777.969679 | \n",
+ " 2.632339e+06 | \n",
+ " ... | \n",
+ " 3.633915 | \n",
+ " 64.068756 | \n",
+ " -2.219202 | \n",
+ " 99.249870 | \n",
+ " 5.304260 | \n",
+ " 64.088127 | \n",
+ " -6.597187 | \n",
+ " 62.661850 | \n",
+ " -2.923168 | \n",
+ " 67.490440 | \n",
+ "
\n",
+ " \n",
+ " 620 | \n",
+ " 0.495959 | \n",
+ " 0.072854 | \n",
+ " 0.117362 | \n",
+ " 0.000867 | \n",
+ " 2657.912854 | \n",
+ " 189139.438926 | \n",
+ " 2345.662472 | \n",
+ " 32730.579626 | \n",
+ " 5358.261979 | \n",
+ " 5.918222e+05 | \n",
+ " ... | \n",
+ " 5.089191 | \n",
+ " 27.937113 | \n",
+ " -10.676390 | \n",
+ " 26.519361 | \n",
+ " 3.875155 | \n",
+ " 25.613684 | \n",
+ " -4.943561 | \n",
+ " 24.334734 | \n",
+ " 3.255899 | \n",
+ " 25.199259 | \n",
+ "
\n",
+ " \n",
+ " 814 | \n",
+ " 0.395137 | \n",
+ " 0.093939 | \n",
+ " 0.114246 | \n",
+ " 0.004025 | \n",
+ " 1716.249594 | \n",
+ " 920189.339374 | \n",
+ " 2062.885827 | \n",
+ " 358557.016423 | \n",
+ " 3790.901258 | \n",
+ " 4.734865e+06 | \n",
+ " ... | \n",
+ " 3.066329 | \n",
+ " 66.090370 | \n",
+ " -4.590122 | \n",
+ " 72.595345 | \n",
+ " 4.261040 | \n",
+ " 63.185764 | \n",
+ " -2.127876 | \n",
+ " 50.693245 | \n",
+ " -3.665569 | \n",
+ " 89.750290 | \n",
+ "
\n",
+ " \n",
+ " 516 | \n",
+ " 0.249535 | \n",
+ " 0.087563 | \n",
+ " 0.060560 | \n",
+ " 0.001276 | \n",
+ " 1465.857446 | \n",
+ " 143302.098295 | \n",
+ " 1738.858902 | \n",
+ " 58868.399307 | \n",
+ " 2822.406728 | \n",
+ " 7.392007e+05 | \n",
+ " ... | \n",
+ " 2.887793 | \n",
+ " 109.811813 | \n",
+ " -0.027696 | \n",
+ " 113.660950 | \n",
+ " 2.098475 | \n",
+ " 160.025497 | \n",
+ " 1.109709 | \n",
+ " 136.810165 | \n",
+ " 2.935807 | \n",
+ " 95.914490 | \n",
+ "
\n",
+ " \n",
+ " 518 | \n",
+ " 0.353474 | \n",
+ " 0.087755 | \n",
+ " 0.052264 | \n",
+ " 0.000316 | \n",
+ " 1993.352766 | \n",
+ " 64753.479332 | \n",
+ " 2127.165109 | \n",
+ " 36027.039069 | \n",
+ " 4248.194549 | \n",
+ " 3.987029e+05 | \n",
+ " ... | \n",
+ " 12.366530 | \n",
+ " 57.230133 | \n",
+ " -1.110214 | \n",
+ " 48.080849 | \n",
+ " -0.784249 | \n",
+ " 57.033504 | \n",
+ " -2.984207 | \n",
+ " 55.737625 | \n",
+ " 0.350456 | \n",
+ " 64.126846 | \n",
+ "
\n",
+ " \n",
+ " 940 | \n",
+ " 0.416089 | \n",
+ " 0.087772 | \n",
+ " 0.142935 | \n",
+ " 0.003150 | \n",
+ " 3009.958707 | \n",
+ " 435134.775688 | \n",
+ " 2778.049758 | \n",
+ " 135548.871316 | \n",
+ " 6131.200719 | \n",
+ " 1.788624e+06 | \n",
+ " ... | \n",
+ " -5.717880 | \n",
+ " 42.315434 | \n",
+ " -3.953057 | \n",
+ " 48.761936 | \n",
+ " -3.092345 | \n",
+ " 49.514446 | \n",
+ " -2.731183 | \n",
+ " 58.219994 | \n",
+ " -0.909785 | \n",
+ " 63.111858 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
10 rows × 57 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " chroma_stft_mean chroma_stft_var rms_mean rms_var \\\n",
+ "687 0.516547 0.072241 0.267380 0.001175 \n",
+ "500 0.344511 0.085002 0.046747 0.001542 \n",
+ "332 0.368345 0.090390 0.111073 0.004402 \n",
+ "979 0.360042 0.083953 0.116724 0.000789 \n",
+ "817 0.425788 0.091852 0.139799 0.003601 \n",
+ "620 0.495959 0.072854 0.117362 0.000867 \n",
+ "814 0.395137 0.093939 0.114246 0.004025 \n",
+ "516 0.249535 0.087563 0.060560 0.001276 \n",
+ "518 0.353474 0.087755 0.052264 0.000316 \n",
+ "940 0.416089 0.087772 0.142935 0.003150 \n",
+ "\n",
+ " spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean \\\n",
+ "687 3338.581900 172002.893292 2697.128636 \n",
+ "500 1503.869486 554576.511533 1754.216082 \n",
+ "332 2446.919077 490397.099115 2449.159840 \n",
+ "979 2148.410463 253618.158995 2107.165355 \n",
+ "817 1803.774378 659241.158049 1973.418903 \n",
+ "620 2657.912854 189139.438926 2345.662472 \n",
+ "814 1716.249594 920189.339374 2062.885827 \n",
+ "516 1465.857446 143302.098295 1738.858902 \n",
+ "518 1993.352766 64753.479332 2127.165109 \n",
+ "940 3009.958707 435134.775688 2778.049758 \n",
+ "\n",
+ " spectral_bandwidth_var rolloff_mean rolloff_var ... mfcc16_mean \\\n",
+ "687 45771.294278 6670.863091 3.556853e+05 ... 8.023183 \n",
+ "500 283554.933422 2799.283099 2.685679e+06 ... -1.957420 \n",
+ "332 215375.540632 4958.057490 2.650020e+06 ... 0.122951 \n",
+ "979 72155.551685 4479.264304 9.787046e+05 ... -0.621152 \n",
+ "817 201432.199120 3777.969679 2.632339e+06 ... 3.633915 \n",
+ "620 32730.579626 5358.261979 5.918222e+05 ... 5.089191 \n",
+ "814 358557.016423 3790.901258 4.734865e+06 ... 3.066329 \n",
+ "516 58868.399307 2822.406728 7.392007e+05 ... 2.887793 \n",
+ "518 36027.039069 4248.194549 3.987029e+05 ... 12.366530 \n",
+ "940 135548.871316 6131.200719 1.788624e+06 ... -5.717880 \n",
+ "\n",
+ " mfcc16_var mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var \\\n",
+ "687 37.339474 -8.121326 33.968277 4.910113 42.063385 \n",
+ "500 50.311016 -1.503434 41.141155 0.221949 55.707256 \n",
+ "332 78.892769 -1.054999 79.877068 4.496278 112.834435 \n",
+ "979 37.060532 -13.479134 50.848667 3.308529 47.726006 \n",
+ "817 64.068756 -2.219202 99.249870 5.304260 64.088127 \n",
+ "620 27.937113 -10.676390 26.519361 3.875155 25.613684 \n",
+ "814 66.090370 -4.590122 72.595345 4.261040 63.185764 \n",
+ "516 109.811813 -0.027696 113.660950 2.098475 160.025497 \n",
+ "518 57.230133 -1.110214 48.080849 -0.784249 57.033504 \n",
+ "940 42.315434 -3.953057 48.761936 -3.092345 49.514446 \n",
+ "\n",
+ " mfcc19_mean mfcc19_var mfcc20_mean mfcc20_var \n",
+ "687 -2.474697 35.162354 3.192656 36.478157 \n",
+ "500 -1.991485 50.006485 -3.353825 49.906403 \n",
+ "332 -0.978958 75.059898 -5.256925 120.275269 \n",
+ "979 -3.704957 56.781952 1.085497 54.243389 \n",
+ "817 -6.597187 62.661850 -2.923168 67.490440 \n",
+ "620 -4.943561 24.334734 3.255899 25.199259 \n",
+ "814 -2.127876 50.693245 -3.665569 89.750290 \n",
+ "516 1.109709 136.810165 2.935807 95.914490 \n",
+ "518 -2.984207 55.737625 0.350456 64.126846 \n",
+ "940 -2.731183 58.219994 -0.909785 63.111858 \n",
+ "\n",
+ "[10 rows x 57 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "X = data.drop([\"genre\"], axis=1)\n",
+ "Y = data[\"genre\"]\n",
+ "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.20, random_state = 0)  # explicit integer seed keeps the split reproducible\n",
+ "display(X_train.head(10))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Liczba krotek dla poszczególnych gatunków z podziałem na test/train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "blues\ttest: 15\ttrain: 85\tall: 100\n",
+ "classical\ttest: 11\ttrain: 89\tall: 100\n",
+ "country\ttest: 27\ttrain: 73\tall: 100\n",
+ "disco\ttest: 22\ttrain: 78\tall: 100\n",
+ "hiphop\ttest: 23\ttrain: 77\tall: 100\n",
+ "jazz\ttest: 18\ttrain: 82\tall: 100\n",
+ "metal\ttest: 20\ttrain: 80\tall: 100\n",
+ "pop\ttest: 24\ttrain: 76\tall: 100\n",
+ "reggae\ttest: 15\ttrain: 85\tall: 100\n",
+ "rock\ttest: 25\ttrain: 75\tall: 100\n"
+ ]
+ }
+ ],
+ "source": [
+ "for name, code in genre_dict.items():  # per-genre row counts: test / train / whole data set\n",
+ "    total = int((data[\"genre\"] == code).sum())\n",
+ "    n_train = int((Y_train == code).sum())\n",
+ "    n_test = int((Y_test == code).sum())\n",
+ "    print(f\"{name}\\ttest: {n_test}\\ttrain: {n_train}\\tall: {total}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Wczytywanie modelu (utworzenie nowego w przypadku jego braku)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Creating model\n",
+ "Model saved (model.model)\n"
+ ]
+ }
+ ],
+ "source": [
+ "if os.path.isfile(model_path):\n",
+ " print(\"Loading model\")\n",
+ " with open(model_path, 'rb') as file:\n",
+ " model = pickle.load(file)\n",
+ " print(f\"Model loaded ({model_path})\")\n",
+ "else:\n",
+ " print(\"Creating model\")\n",
+ " model = GaussianNB()\n",
+ " model.fit(X_train, Y_train)\n",
+ " with open(model_path, 'wb') as file:\n",
+ " pickle.dump(model, file)\n",
+ " print(f\"Model saved ({model_path})\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Ewaluacja"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Zbiór treningowy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(Train data) Confusion matrix:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "array([[27, 1, 13, 1, 0, 2, 32, 0, 9, 0],\n",
+ " [ 0, 82, 0, 0, 0, 0, 3, 0, 3, 1],\n",
+ " [ 8, 0, 37, 11, 1, 0, 8, 1, 6, 1],\n",
+ " [ 1, 0, 3, 35, 5, 0, 23, 3, 3, 5],\n",
+ " [ 1, 0, 5, 14, 28, 0, 12, 7, 8, 2],\n",
+ " [12, 18, 2, 11, 1, 17, 14, 1, 0, 6],\n",
+ " [ 0, 0, 1, 8, 0, 0, 70, 0, 0, 1],\n",
+ " [ 0, 0, 2, 20, 4, 1, 0, 42, 5, 2],\n",
+ " [ 5, 0, 12, 4, 6, 0, 4, 7, 45, 2],\n",
+ " [ 1, 0, 10, 17, 2, 0, 29, 1, 6, 9]], dtype=int64)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(Train data) Accuracy:\n",
+ "0.49\n"
+ ]
+ }
+ ],
+ "source": [
+ "Y_train_predicted = model.predict(X_train)\n",
+ "cm = confusion_matrix(Y_train, Y_train_predicted)\n",
+ "ac = accuracy_score(Y_train, Y_train_predicted)\n",
+ "print(\"(Train data) Confusion matrix:\")\n",
+ "display(cm)\n",
+ "print(\"(Train data) Accuracy:\")\n",
+ "print(ac)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Zbiór testowy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Confusion matrix:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "array([[ 3, 0, 2, 0, 0, 0, 7, 0, 3, 0],\n",
+ " [ 1, 10, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 6, 0, 9, 4, 0, 0, 4, 0, 2, 2],\n",
+ " [ 1, 0, 1, 6, 3, 0, 4, 0, 3, 4],\n",
+ " [ 1, 0, 3, 2, 5, 0, 6, 1, 4, 1],\n",
+ " [ 1, 4, 2, 2, 0, 4, 3, 1, 1, 0],\n",
+ " [ 0, 0, 1, 4, 1, 0, 14, 0, 0, 0],\n",
+ " [ 0, 0, 1, 4, 2, 0, 1, 13, 1, 2],\n",
+ " [ 1, 0, 1, 1, 2, 0, 0, 1, 7, 2],\n",
+ " [ 1, 1, 5, 6, 1, 1, 9, 0, 0, 1]], dtype=int64)"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy:\n",
+ "0.36\n"
+ ]
+ }
+ ],
+ "source": [
+ "Y_test_predicted = model.predict(X_test)\n",
+ "cm = confusion_matrix(Y_test, Y_test_predicted)\n",
+ "ac = accuracy_score(Y_test, Y_test_predicted)\n",
+ "print(\"Confusion matrix:\")\n",
+ "display(cm)\n",
+ "print(\"Accuracy:\")\n",
+ "print(ac)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Przykładowe porównania"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Y: 10\tPredicted: 7\n",
+ "Y: 9\tPredicted: 9\n",
+ "Y: 3\tPredicted: 3\n",
+ "Y: 6\tPredicted: 6\n",
+ "Y: 7\tPredicted: 7\n",
+ "Y: 10\tPredicted: 7\n",
+ "Y: 1\tPredicted: 1\n",
+ "Y: 3\tPredicted: 10\n",
+ "Y: 4\tPredicted: 7\n",
+ "Y: 8\tPredicted: 10\n"
+ ]
+ }
+ ],
+ "source": [
+ "for expected, predicted in zip(Y_test.to_numpy()[:10], Y_test_predicted[:10]):  # first ten (true, predicted) label pairs\n",
+ "    print(f\"Y: {expected}\\tPredicted: {predicted}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Porównanie accuracy na zbiorach train/test dla różnych modeli typu NB\n",
+ "\n",
+ "| Class | Acc_Train | Acc_Test | Uwagi | \n",
+ "|---|---|---|---|\n",
+ "| GaussianNB | 0.49 | 0.36 | - |\n",
+ "| MultinomialNB | - | - | Brak obsługi ujemnego inputu |\n",
+ "| ComplementNB | - | - | Brak obsługi ujemnego inputu |\n",
+ "| BernoulliNB | 0.35125 | 0.305 | - |\n",
+ "| CategoricalNB | - | - | Brak obsługi ujemnego inputu |"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/Readme.md b/Readme.md
index 7dafd35..8785370 100644
--- a/Readme.md
+++ b/Readme.md
@@ -10,4 +10,12 @@ Zasady zaliczenia: 40 punktów podzielone następująco:
Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe). Implementacja powinna założyć, że cechy są ciągłe (do wyboru rozkład normalny i jądrowe wygładzenie). Na wejściu oczekiwany jest zbiór, który zawiera p-cech ciągłych, wektor etykiet oraz wektor prawdopodobieństw a priori dla klas. Na wyjściu otrzymujemy prognozowane etykiety oraz prawdopodobieństwa a posteriori. Dodatkową wartością może być wizualizacja obszarów decyzyjnych w przypadku dwóch cech.
-```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.```
\ No newline at end of file
+```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.```
+
+| Class | Acc_Train | Acc_Test | Uwagi |
+|---|---|---|---|
+| GaussianNB | 0.49 | 0.36 | - |
+| MultinomialNB | - | - | Brak obsługi ujemnego inputu |
+| ComplementNB | - | - | Brak obsługi ujemnego inputu |
+| BernoulliNB | 0.35125 | 0.305 | - |
+| CategoricalNB | - | - | Brak obsługi ujemnego inputu |
\ No newline at end of file
diff --git a/datapreparator.py b/datapreparator.py
index 408a72b..8c20e86 100644
--- a/datapreparator.py
+++ b/datapreparator.py
@@ -21,12 +21,19 @@ class DataPreparator:
def prepare_data(df: pd.DataFrame) -> pd.DataFrame:
data = deepcopy(df)
column = df["label"].apply(lambda x: DataPreparator.genre_dict[x])
- data.insert(0, 'genre', column, 'float')
+ data.insert(0, 'genre', column)  # insert() has no dtype argument; its 4th positional parameter is allow_duplicates
data = data.drop(columns=['filename', 'label', 'length'])
return data
def train_test_split(df: pd.DataFrame) -> typing.Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
- X = df.drop(["genre"], axis=1)
+ # NOTE(review): only three features are kept here; the full set was df.drop(["genre"], axis=1) - TODO confirm this is intended
+ X = df[["chroma_stft_mean", "chroma_stft_var", "rms_mean"]]
Y = df["genre"]
- return train_test_split(X, Y, test_size = 0.20, random_state = False)
\ No newline at end of file
+ return train_test_split(X, Y, test_size = 0.20, random_state = 0)  # explicit integer seed keeps the split reproducible
+
+
+ def print_df_info(df: pd.DataFrame) -> None:
+     # Print, for every genre name in genre_dict, how many rows of df carry its numeric code.
+     for name, code in DataPreparator.genre_dict.items():
+         print(f"Key: {name}\tCount: {len(df[df['genre'] == code])}")
\ No newline at end of file
diff --git a/main.py b/main.py
index be5f612..616be32 100644
--- a/main.py
+++ b/main.py
@@ -13,11 +13,10 @@ else:
X_train, X_test, Y_train, Y_test = DataPreparator.train_test_split(data)
-bayes = Bayes('music_genre.model')
+bayes = Bayes('_model.model')
if(not bayes.model_exists):
bayes.train(X_train, Y_train)
-
Y_predicted = bayes.predict(X_train)
eval_result = bayes.eval(Y_train, Y_predicted)
print("Train:")
@@ -25,5 +24,11 @@ print(eval_result[1])
Y_predicted = bayes.predict(X_test)
eval_result = bayes.eval(Y_test, Y_predicted)
+
print("Test:")
print(eval_result[1])
+
+#Result preview
+# for i in range(100):
+# print(f"Expected: {Y_test.to_numpy()[i]}\tPred: {Y_predicted[i]}")
+DataPreparator.print_df_info(data)
\ No newline at end of file