From 0bb7eb7a58fa5f1e03f08835e2f78e87b2774ae8 Mon Sep 17 00:00:00 2001 From: Anna Nowak Date: Wed, 26 May 2021 21:08:58 +0200 Subject: [PATCH] Notebook --- .gitignore | 3 +- Bayes.ipynb | 1121 +++++++++++++++++++++++++++++++++++++++++++++ Readme.md | 10 +- datapreparator.py | 13 +- main.py | 9 +- 5 files changed, 1149 insertions(+), 7 deletions(-) create mode 100644 Bayes.ipynb diff --git a/.gitignore b/.gitignore index 46117a8..dfdfd6a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ venv* .vscode* __pycache__* music_genre.csv -music_genre.model \ No newline at end of file +*.model +.ipynb_checkpoints* \ No newline at end of file diff --git a/Bayes.ipynb b/Bayes.ipynb new file mode 100644 index 0000000..b0ba291 --- /dev/null +++ b/Bayes.ipynb @@ -0,0 +1,1121 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe)\n", + "Zasady zaliczenia: 40 punktów podzielone następująco:\n", + "- 10 pkt - prezentacja projektu\n", + "- 15 pkt - implementacja, w tym:\n", + "- 5 pkt - zgodność z tematem,\n", + "- 5 pkt - jakość kodu,\n", + "- 5 pkt - poprawność implementacji\n", + "- 10 pkt - efekt \"wow\"\n", + "- 5 pkt - aktywność wszystkich członków grupy\n", + "\n", + "Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe). Implementacja powinna założyć, że cechy są ciągłe (do wyboru rozkład normalny i jądrowe wygładzenie). Na wejściu oczekiwany jest zbiór, który zawiera p-cech ciągłych, wektor etykiet oraz wektor prawdopodobieństw a priori dla klas. Na wyjściu otrzymujemy prognozowane etykiety oraz prawdopodobieństwa a posteriori. Dodatkową wartością może być wizualizacja obszarów decyzyjnych w przypadku dwóch cech.\n", + "\n", + "```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pandas==1.2.4 in c:\\users\\annad\\anaconda3\\lib\\site-packages (1.2.4)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (2.8.1)\n", + "Requirement already satisfied: numpy>=1.16.5 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (1.20.3)\n", + "Requirement already satisfied: pytz>=2017.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from pandas==1.2.4) (2020.1)\n", + "Requirement already satisfied: six>=1.5 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from python-dateutil>=2.7.3->pandas==1.2.4) (1.15.0)\n", + "Requirement already satisfied: numpy==1.20.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (1.20.3)\n", + "Requirement already satisfied: sklearn==0.0 in c:\\users\\annad\\anaconda3\\lib\\site-packages (0.0)\n", + "Requirement already satisfied: scikit-learn in c:\\users\\annad\\appdata\\roaming\\python\\python38\\site-packages (from sklearn==0.0) (0.24.2)\n", + "Requirement already satisfied: joblib>=0.11 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (0.17.0)\n", + "Requirement already satisfied: numpy>=1.13.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (1.20.3)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (2.1.0)\n", + "Requirement already satisfied: scipy>=0.19.1 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn==0.0) (1.5.2)\n" + ] + } + ], + "source": [ + "!pip install pandas==1.2.4\n", + "!pip install numpy==1.20.3\n", + "!pip install sklearn==0.0" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "import pandas as pd\n", + "import numpy as np\n", + "import typing\n", + "import os, pickle\n", + "from sklearn.naive_bayes import GaussianNB\n", + "from sklearn.metrics import confusion_matrix, accuracy_score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Wczytywanie i normalizacja danych" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Stałe\n", + "genre_dict = {\n", + " \"blues\" : 1,\n", + " \"classical\" : 2,\n", + " \"country\" : 3,\n", + " \"disco\" : 4,\n", + " \"hiphop\" : 5,\n", + " \"jazz\" : 6,\n", + " \"metal\" : 7,\n", + " \"pop\" : 8,\n", + " \"reggae\" : 9,\n", + " \"rock\" : 10\n", + "}\n", + "filename = 'music_genre.csv'\n", + "model_path = 'model.model'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preparing data...\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
genrechroma_stft_meanchroma_stft_varrms_meanrms_varspectral_centroid_meanspectral_centroid_varspectral_bandwidth_meanspectral_bandwidth_varrolloff_mean...mfcc16_meanmfcc16_varmfcc17_meanmfcc17_varmfcc18_meanmfcc18_varmfcc19_meanmfcc19_varmfcc20_meanmfcc20_var
010.3500880.0887570.1302280.0028271784.1658501.297741e+052002.44906085882.7613153805.839606...0.75274052.420910-1.69021536.524071-0.40897941.597103-2.30352355.0629231.22129146.936035
110.3409140.0949800.0959480.0023731530.1766793.758501e+052039.036516213843.7554973550.522098...0.92799855.356403-0.73112560.3145290.29507348.120598-0.28351851.1061900.53121745.786282
210.3636370.0852750.1755700.0027461552.8118651.564676e+051747.70231276254.1922573042.260232...2.45169040.598766-7.72909347.639427-1.81640752.382141-3.43972046.639660-2.23125830.573025
310.4047850.0939990.1410930.0063461070.1066151.843559e+051596.412872166441.4947692184.745799...0.78087444.427753-3.31959750.2066730.63696537.319130-0.61912137.259739-3.40744831.949339
410.3085260.0878410.0915290.0023031835.0042663.433999e+051748.17211688445.2090363579.757627...-4.52057686.099236-5.45403475.269707-0.91687453.613918-4.40482762.910812-11.70323455.195160
510.3024560.0875320.1034940.0039811831.9939401.030482e+061729.653287201910.5086333481.517592...-5.57658972.549225-1.83826368.702026-2.78380042.447453-3.04790939.808784-8.10999146.311005
610.2913280.0939810.1418740.0088031459.3664724.378594e+051389.009131185023.2395452795.610963...-10.06805183.248245-10.91317656.902153-6.97133638.231800-3.43650548.235741-6.48346670.170364
710.3079550.0929030.1318220.0055311451.6670664.495682e+051577.270941168211.9388042954.836760...-8.42608370.438438-10.56893552.090893-10.78451560.461330-4.69067865.547516-8.63072256.401436
810.4088790.0865120.1424160.0015071719.3689481.632828e+052031.740381105542.7181933782.316288...-1.45255950.563751-7.04182428.8949342.69524836.8895683.41230533.698597-2.71569236.418430
910.2739500.0923160.0813140.0043471817.1508632.982361e+051973.773306114070.1125913943.490565...-1.17992059.314602-1.91680458.418438-2.29266183.2052312.88196777.082222-4.23520391.468811
\n", + "

10 rows × 58 columns

\n", + "
" + ], + "text/plain": [ + " genre chroma_stft_mean chroma_stft_var rms_mean rms_var \\\n", + "0 1 0.350088 0.088757 0.130228 0.002827 \n", + "1 1 0.340914 0.094980 0.095948 0.002373 \n", + "2 1 0.363637 0.085275 0.175570 0.002746 \n", + "3 1 0.404785 0.093999 0.141093 0.006346 \n", + "4 1 0.308526 0.087841 0.091529 0.002303 \n", + "5 1 0.302456 0.087532 0.103494 0.003981 \n", + "6 1 0.291328 0.093981 0.141874 0.008803 \n", + "7 1 0.307955 0.092903 0.131822 0.005531 \n", + "8 1 0.408879 0.086512 0.142416 0.001507 \n", + "9 1 0.273950 0.092316 0.081314 0.004347 \n", + "\n", + " spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean \\\n", + "0 1784.165850 1.297741e+05 2002.449060 \n", + "1 1530.176679 3.758501e+05 2039.036516 \n", + "2 1552.811865 1.564676e+05 1747.702312 \n", + "3 1070.106615 1.843559e+05 1596.412872 \n", + "4 1835.004266 3.433999e+05 1748.172116 \n", + "5 1831.993940 1.030482e+06 1729.653287 \n", + "6 1459.366472 4.378594e+05 1389.009131 \n", + "7 1451.667066 4.495682e+05 1577.270941 \n", + "8 1719.368948 1.632828e+05 2031.740381 \n", + "9 1817.150863 2.982361e+05 1973.773306 \n", + "\n", + " spectral_bandwidth_var rolloff_mean ... mfcc16_mean mfcc16_var \\\n", + "0 85882.761315 3805.839606 ... 0.752740 52.420910 \n", + "1 213843.755497 3550.522098 ... 0.927998 55.356403 \n", + "2 76254.192257 3042.260232 ... 2.451690 40.598766 \n", + "3 166441.494769 2184.745799 ... 0.780874 44.427753 \n", + "4 88445.209036 3579.757627 ... -4.520576 86.099236 \n", + "5 201910.508633 3481.517592 ... -5.576589 72.549225 \n", + "6 185023.239545 2795.610963 ... -10.068051 83.248245 \n", + "7 168211.938804 2954.836760 ... -8.426083 70.438438 \n", + "8 105542.718193 3782.316288 ... -1.452559 50.563751 \n", + "9 114070.112591 3943.490565 ... -1.179920 59.314602 \n", + "\n", + " mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var mfcc19_mean mfcc19_var \\\n", + "0 -1.690215 36.524071 -0.408979 41.597103 -2.303523 55.062923 \n", + "1 -0.731125 60.314529 0.295073 48.120598 -0.283518 51.106190 \n", + "2 -7.729093 47.639427 -1.816407 52.382141 -3.439720 46.639660 \n", + "3 -3.319597 50.206673 0.636965 37.319130 -0.619121 37.259739 \n", + "4 -5.454034 75.269707 -0.916874 53.613918 -4.404827 62.910812 \n", + "5 -1.838263 68.702026 -2.783800 42.447453 -3.047909 39.808784 \n", + "6 -10.913176 56.902153 -6.971336 38.231800 -3.436505 48.235741 \n", + "7 -10.568935 52.090893 -10.784515 60.461330 -4.690678 65.547516 \n", + "8 -7.041824 28.894934 2.695248 36.889568 3.412305 33.698597 \n", + "9 -1.916804 58.418438 -2.292661 83.205231 2.881967 77.082222 \n", + "\n", + " mfcc20_mean mfcc20_var \n", + "0 1.221291 46.936035 \n", + "1 0.531217 45.786282 \n", + "2 -2.231258 30.573025 \n", + "3 -3.407448 31.949339 \n", + "4 -11.703234 55.195160 \n", + "5 -8.109991 46.311005 \n", + "6 -6.483466 70.170364 \n", + "7 -8.630722 56.401436 \n", + "8 -2.715692 36.418430 \n", + "9 -4.235203 91.468811 \n", + "\n", + "[10 rows x 58 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "if os.path.isfile(filename):\n", + " print(\"Loading prepared data...\")\n", + " data = pd.read_csv(filename)\n", + "else:\n", + " print(\"Preparing data...\")\n", + " data = pd.read_csv('music_genre_raw.csv')\n", + " column = data[\"label\"].apply(lambda x: genre_dict[x])\n", + " data.insert(0, 'genre', column, 'int')\n", + " data = data.drop(columns=['filename', 'label', 'length'])\n", + " data.to_csv(filename, index=False)\n", + "display(data.head(10))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Podział danych na zbiory train i test" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chroma_stft_meanchroma_stft_varrms_meanrms_varspectral_centroid_meanspectral_centroid_varspectral_bandwidth_meanspectral_bandwidth_varrolloff_meanrolloff_var...mfcc16_meanmfcc16_varmfcc17_meanmfcc17_varmfcc18_meanmfcc18_varmfcc19_meanmfcc19_varmfcc20_meanmfcc20_var
6870.5165470.0722410.2673800.0011753338.581900172002.8932922697.12863645771.2942786670.8630913.556853e+05...8.02318337.339474-8.12132633.9682774.91011342.063385-2.47469735.1623543.19265636.478157
5000.3445110.0850020.0467470.0015421503.869486554576.5115331754.216082283554.9334222799.2830992.685679e+06...-1.95742050.311016-1.50343441.1411550.22194955.707256-1.99148550.006485-3.35382549.906403
3320.3683450.0903900.1110730.0044022446.919077490397.0991152449.159840215375.5406324958.0574902.650020e+06...0.12295178.892769-1.05499979.8770684.496278112.834435-0.97895875.059898-5.256925120.275269
9790.3600420.0839530.1167240.0007892148.410463253618.1589952107.16535572155.5516854479.2643049.787046e+05...-0.62115237.060532-13.47913450.8486673.30852947.726006-3.70495756.7819521.08549754.243389
8170.4257880.0918520.1397990.0036011803.774378659241.1580491973.418903201432.1991203777.9696792.632339e+06...3.63391564.068756-2.21920299.2498705.30426064.088127-6.59718762.661850-2.92316867.490440
6200.4959590.0728540.1173620.0008672657.912854189139.4389262345.66247232730.5796265358.2619795.918222e+05...5.08919127.937113-10.67639026.5193613.87515525.613684-4.94356124.3347343.25589925.199259
8140.3951370.0939390.1142460.0040251716.249594920189.3393742062.885827358557.0164233790.9012584.734865e+06...3.06632966.090370-4.59012272.5953454.26104063.185764-2.12787650.693245-3.66556989.750290
5160.2495350.0875630.0605600.0012761465.857446143302.0982951738.85890258868.3993072822.4067287.392007e+05...2.887793109.811813-0.027696113.6609502.098475160.0254971.109709136.8101652.93580795.914490
5180.3534740.0877550.0522640.0003161993.35276664753.4793322127.16510936027.0390694248.1945493.987029e+05...12.36653057.230133-1.11021448.080849-0.78424957.033504-2.98420755.7376250.35045664.126846
9400.4160890.0877720.1429350.0031503009.958707435134.7756882778.049758135548.8713166131.2007191.788624e+06...-5.71788042.315434-3.95305748.761936-3.09234549.514446-2.73118358.219994-0.90978563.111858
\n", + "

10 rows × 57 columns

\n", + "
" + ], + "text/plain": [ + " chroma_stft_mean chroma_stft_var rms_mean rms_var \\\n", + "687 0.516547 0.072241 0.267380 0.001175 \n", + "500 0.344511 0.085002 0.046747 0.001542 \n", + "332 0.368345 0.090390 0.111073 0.004402 \n", + "979 0.360042 0.083953 0.116724 0.000789 \n", + "817 0.425788 0.091852 0.139799 0.003601 \n", + "620 0.495959 0.072854 0.117362 0.000867 \n", + "814 0.395137 0.093939 0.114246 0.004025 \n", + "516 0.249535 0.087563 0.060560 0.001276 \n", + "518 0.353474 0.087755 0.052264 0.000316 \n", + "940 0.416089 0.087772 0.142935 0.003150 \n", + "\n", + " spectral_centroid_mean spectral_centroid_var spectral_bandwidth_mean \\\n", + "687 3338.581900 172002.893292 2697.128636 \n", + "500 1503.869486 554576.511533 1754.216082 \n", + "332 2446.919077 490397.099115 2449.159840 \n", + "979 2148.410463 253618.158995 2107.165355 \n", + "817 1803.774378 659241.158049 1973.418903 \n", + "620 2657.912854 189139.438926 2345.662472 \n", + "814 1716.249594 920189.339374 2062.885827 \n", + "516 1465.857446 143302.098295 1738.858902 \n", + "518 1993.352766 64753.479332 2127.165109 \n", + "940 3009.958707 435134.775688 2778.049758 \n", + "\n", + " spectral_bandwidth_var rolloff_mean rolloff_var ... mfcc16_mean \\\n", + "687 45771.294278 6670.863091 3.556853e+05 ... 8.023183 \n", + "500 283554.933422 2799.283099 2.685679e+06 ... -1.957420 \n", + "332 215375.540632 4958.057490 2.650020e+06 ... 0.122951 \n", + "979 72155.551685 4479.264304 9.787046e+05 ... -0.621152 \n", + "817 201432.199120 3777.969679 2.632339e+06 ... 3.633915 \n", + "620 32730.579626 5358.261979 5.918222e+05 ... 5.089191 \n", + "814 358557.016423 3790.901258 4.734865e+06 ... 3.066329 \n", + "516 58868.399307 2822.406728 7.392007e+05 ... 2.887793 \n", + "518 36027.039069 4248.194549 3.987029e+05 ... 12.366530 \n", + "940 135548.871316 6131.200719 1.788624e+06 ... -5.717880 \n", + "\n", + " mfcc16_var mfcc17_mean mfcc17_var mfcc18_mean mfcc18_var \\\n", + "687 37.339474 -8.121326 33.968277 4.910113 42.063385 \n", + "500 50.311016 -1.503434 41.141155 0.221949 55.707256 \n", + "332 78.892769 -1.054999 79.877068 4.496278 112.834435 \n", + "979 37.060532 -13.479134 50.848667 3.308529 47.726006 \n", + "817 64.068756 -2.219202 99.249870 5.304260 64.088127 \n", + "620 27.937113 -10.676390 26.519361 3.875155 25.613684 \n", + "814 66.090370 -4.590122 72.595345 4.261040 63.185764 \n", + "516 109.811813 -0.027696 113.660950 2.098475 160.025497 \n", + "518 57.230133 -1.110214 48.080849 -0.784249 57.033504 \n", + "940 42.315434 -3.953057 48.761936 -3.092345 49.514446 \n", + "\n", + " mfcc19_mean mfcc19_var mfcc20_mean mfcc20_var \n", + "687 -2.474697 35.162354 3.192656 36.478157 \n", + "500 -1.991485 50.006485 -3.353825 49.906403 \n", + "332 -0.978958 75.059898 -5.256925 120.275269 \n", + "979 -3.704957 56.781952 1.085497 54.243389 \n", + "817 -6.597187 62.661850 -2.923168 67.490440 \n", + "620 -4.943561 24.334734 3.255899 25.199259 \n", + "814 -2.127876 50.693245 -3.665569 89.750290 \n", + "516 1.109709 136.810165 2.935807 95.914490 \n", + "518 -2.984207 55.737625 0.350456 64.126846 \n", + "940 -2.731183 58.219994 -0.909785 63.111858 \n", + "\n", + "[10 rows x 57 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X = data.drop([\"genre\"], axis=1)\n", + "Y = data[\"genre\"]\n", + "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.20, random_state = False)\n", + "display(X_train.head(10))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ilość krotek dla poszczególnych gatunków z podziałem na test/train" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "blues\ttest: 15\ttrain: 85\tall: 100\n", + "classical\ttest: 11\ttrain: 89\tall: 100\n", + "country\ttest: 27\ttrain: 73\tall: 100\n", + "disco\ttest: 22\ttrain: 78\tall: 100\n", + "hiphop\ttest: 23\ttrain: 77\tall: 100\n", + "jazz\ttest: 18\ttrain: 82\tall: 100\n", + "metal\ttest: 20\ttrain: 80\tall: 100\n", + "pop\ttest: 24\ttrain: 76\tall: 100\n", + "reggae\ttest: 15\ttrain: 85\tall: 100\n", + "rock\ttest: 25\ttrain: 75\tall: 100\n" + ] + } + ], + "source": [ + "for key in genre_dict.keys():\n", + " count = len(data[data[\"genre\"]==genre_dict[key]])\n", + " count_train = len(X_train[Y_train==genre_dict[key]])\n", + " count_test = len(X_test[Y_test==genre_dict[key]])\n", + " print(f\"{key}\\ttest: {count_test}\\ttrain: {count_train}\\tall: {count}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wczytywanie modelu (utworzenie nowego w przypadku jego braku)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating model\n", + "Model saved (model.model)\n" + ] + } + ], + "source": [ + "if os.path.isfile(model_path):\n", + " print(\"Loading model\")\n", + " with open(model_path, 'rb') as file:\n", + " model = pickle.load(file)\n", + " print(f\"Model loaded ({model_path})\")\n", + "else:\n", + " print(\"Creating model\")\n", + " model = GaussianNB()\n", + " model.fit(X_train, Y_train)\n", + " with open(model_path, 'wb') as file:\n", + " pickle.dump(model, file)\n", + " print(f\"Model saved ({model_path})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ewaluacja" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Zbiór trenujący" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Train data) Confusion matrix:\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[27, 1, 13, 1, 0, 2, 32, 0, 9, 0],\n", + " [ 0, 82, 0, 0, 0, 0, 3, 0, 3, 1],\n", + " [ 8, 0, 37, 11, 1, 0, 8, 1, 6, 1],\n", + " [ 1, 0, 3, 35, 5, 0, 23, 3, 3, 5],\n", + " [ 1, 0, 5, 14, 28, 0, 12, 7, 8, 2],\n", + " [12, 18, 2, 11, 1, 17, 14, 1, 0, 6],\n", + " [ 0, 0, 1, 8, 0, 0, 70, 0, 0, 1],\n", + " [ 0, 0, 2, 20, 4, 1, 0, 42, 5, 2],\n", + " [ 5, 0, 12, 4, 6, 0, 4, 7, 45, 2],\n", + " [ 1, 0, 10, 17, 2, 0, 29, 1, 6, 9]], dtype=int64)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Train data) Accuracy:\n", + "0.49\n" + ] + } + ], + "source": [ + "Y_train_predicted = model.predict(X_train)\n", + "cm = confusion_matrix(Y_train, Y_train_predicted)\n", + "ac = accuracy_score(Y_train, Y_train_predicted)\n", + "print(\"(Train data) Confusion matrix:\")\n", + "display(cm)\n", + "print(\"(Train data) Accuracy:\")\n", + "print(ac)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Zbiór testowy" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix:\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[ 3, 0, 2, 0, 0, 0, 7, 0, 3, 0],\n", + " [ 1, 10, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [ 6, 0, 9, 4, 0, 0, 4, 0, 2, 2],\n", + " [ 1, 0, 1, 6, 3, 0, 4, 0, 3, 4],\n", + " [ 1, 0, 3, 2, 5, 0, 6, 1, 4, 1],\n", + " [ 1, 4, 2, 2, 0, 4, 3, 1, 1, 0],\n", + " [ 0, 0, 1, 4, 1, 0, 14, 0, 0, 0],\n", + " [ 0, 0, 1, 4, 2, 0, 1, 13, 1, 2],\n", + " [ 1, 0, 1, 1, 2, 0, 0, 1, 7, 2],\n", + " [ 1, 1, 5, 6, 1, 1, 9, 0, 0, 1]], dtype=int64)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy:\n", + "0.36\n" + ] + } + ], + "source": [ + "Y_test_predicted = model.predict(X_test)\n", + "cm = confusion_matrix(Y_test, Y_test_predicted)\n", + "ac = accuracy_score(Y_test, Y_test_predicted)\n", + "print(\"Confusion matrix:\")\n", + "display(cm)\n", + "print(\"Accuracy:\")\n", + "print(ac)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Przykładowe porównania" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Y: 10\tPredicted: 7\n", + "Y: 9\tPredicted: 9\n", + "Y: 3\tPredicted: 3\n", + "Y: 6\tPredicted: 6\n", + "Y: 7\tPredicted: 7\n", + "Y: 10\tPredicted: 7\n", + "Y: 1\tPredicted: 1\n", + "Y: 3\tPredicted: 10\n", + "Y: 4\tPredicted: 7\n", + "Y: 8\tPredicted: 10\n" + ] + } + ], + "source": [ + "for i in range(10):\n", + " print(f\"Y: {Y_test.to_numpy()[i]}\\tPredicted: {Y_test_predicted[i]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Porównanie accuracy na zbiorach train/test dla różnych modeli typu NB\n", + "\n", + "| Class | Acc_Train | Acc_Test | Uwagi | \n", + "|---|---|---|---|\n", + "| GaussianNB | 0.49 | 0.36 | - |\n", + "| MultinomialNB | - | - | Brak obsługi ujemnego inputu |\n", + "| ComplementNB | - | - | Brak obsługi ujemnego inputu |\n", + "| BernoulliNB | 0.35125 | 0.305 | - |\n", + "| CategoricalNB | - | - | Brak obsługi ujemnego inputu |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Readme.md b/Readme.md index 7dafd35..8785370 100644 --- a/Readme.md +++ b/Readme.md @@ -10,4 +10,12 @@ Zasady zaliczenia: 40 punktów podzielone następująco: Klasyfikacja za pomocą naiwnej metody bayesowskiej (rozkłady ciągłe). Implementacja powinna założyć, że cechy są ciągłe (do wyboru rozkład normalny i jądrowe wygładzenie). Na wejściu oczekiwany jest zbiór, który zawiera p-cech ciągłych, wektor etykiet oraz wektor prawdopodobieństw a priori dla klas. Na wyjściu otrzymujemy prognozowane etykiety oraz prawdopodobieństwa a posteriori. Dodatkową wartością może być wizualizacja obszarów decyzyjnych w przypadku dwóch cech. -```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.``` \ No newline at end of file +```Termin oddania na Moodle: do 31 maja. Prezentacja projektów 1 czerwca na ćwiczeniach.``` + +| Class | Acc_Train | Acc_Test | Uwagi | +|---|---|---|---| +| GaussianNB | 0.49 | 0.36 | - | +| MultinomialNB | - | - | Bez ujemnego inputu | +| ComplementNB | - | - | Bez ujemnego inputu | +| BernoulliNB | 0.35125 | 0.305 | - | +| CategoricalNB | - | - | Bez ujemnego inputu | \ No newline at end of file diff --git a/datapreparator.py b/datapreparator.py index 408a72b..8c20e86 100644 --- a/datapreparator.py +++ b/datapreparator.py @@ -21,12 +21,19 @@ class DataPreparator: def prepare_data(df: pd.DataFrame) -> pd.DataFrame: data = deepcopy(df) column = df["label"].apply(lambda x: DataPreparator.genre_dict[x]) - data.insert(0, 'genre', column, 'float') + data.insert(0, 'genre', column, 'int') data = data.drop(columns=['filename', 'label', 'length']) return data def train_test_split(df: pd.DataFrame) -> typing.Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]: - X = df.drop(["genre"], axis=1) + #X = df.drop(["genre"], axis=1) + X = df[["chroma_stft_mean","chroma_stft_var","rms_mean"]] Y = df["genre"] - return train_test_split(X, Y, test_size = 0.20, random_state = False) \ No newline at end of file + return train_test_split(X, Y, test_size = 0.20, random_state = False) + + + def print_df_info(df: pd.DataFrame) -> None: + for key in DataPreparator.genre_dict.keys(): + count = len(df[df["genre"]==DataPreparator.genre_dict[key]]) + print(f"Key: {key}\tCount: {count}") \ No newline at end of file diff --git a/main.py b/main.py index be5f612..616be32 100644 --- a/main.py +++ b/main.py @@ -13,11 +13,10 @@ else: X_train, X_test, Y_train, Y_test = DataPreparator.train_test_split(data) -bayes = Bayes('music_genre.model') +bayes = Bayes('_model.model') if(not bayes.model_exists): bayes.train(X_train, Y_train) - Y_predicted = bayes.predict(X_train) eval_result = bayes.eval(Y_train, Y_predicted) print("Train:") @@ -25,5 +24,11 @@ print(eval_result[1]) Y_predicted = bayes.predict(X_test) eval_result = bayes.eval(Y_test, Y_predicted) + print("Test:") print(eval_result[1]) + +#Result preview +# for i in range(100): +# print(f"Expected: {Y_test.to_numpy()[i]}\tPred: {Y_predicted[i]}") +DataPreparator.print_df_info(data) \ No newline at end of file