DL_PROJEKT/dl_projekt2.ipynb

1162 lines
213 KiB
Plaintext
Raw Permalink Normal View History

2024-06-09 13:18:59 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### Importy"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true,
"ExecuteTime": {
"start_time": "2024-06-09T12:46:01.122296Z",
"end_time": "2024-06-09T12:46:01.143307Z"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import os\n",
"import tensorflow as tf\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, OneHotEncoder\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.naive_bayes import MultinomialNB\n",
"from sklearn.metrics import accuracy_score, confusion_matrix\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from keras.preprocessing.text import Tokenizer\n",
"from keras.models import Sequential\n",
"from keras.layers import Embedding, LSTM, Dense, Dropout\n",
"from keras.optimizers import Adam\n",
"from transformers import pipeline\n",
"from tqdm import tqdm\n",
"from keras_preprocessing.sequence import pad_sequences\n",
"from sklearn.metrics import classification_report"
]
},
{
"cell_type": "markdown",
"source": [
"### Pobieranie danych"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Looks like you're using an outdated API Version, please consider updating (server 1.6.14 / client 1.6.11)\n",
"Dataset URL: https://www.kaggle.com/datasets/shivamkushwaha/bbc-full-text-document-classification\n",
"License(s): DbCL-1.0\n",
"Downloading bbc-full-text-document-classification.zip to C:\\Users\\adamw\\PycharmProjects\\pythonProject\\dl_projekt\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" 0%| | 0.00/5.59M [00:00<?, ?B/s]\n",
" 18%|#7 | 1.00M/5.59M [00:00<00:03, 1.58MB/s]\n",
" 36%|###5 | 2.00M/5.59M [00:00<00:01, 3.16MB/s]\n",
" 72%|#######1 | 4.00M/5.59M [00:00<00:00, 6.34MB/s]\n",
"100%|##########| 5.59M/5.59M [00:00<00:00, 6.05MB/s]\n"
]
}
],
"source": [
"!kaggle datasets download -d shivamkushwaha/bbc-full-text-document-classification --unzip"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-08T22:27:32.388353Z",
"end_time": "2024-06-08T22:27:39.576352Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Sprawdzenie dostępności GPU"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num GPUs Available: 1\n",
"[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
]
}
],
"source": [
"# Check GPU availability\n",
"# NOTE(review): TensorFlow is already imported by the imports cell; confirm that setting\n",
"# these environment variables after the import still has the intended effect.\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"# Query the device list once and reuse it for both the count and the listing.\n",
"physical_devices = tf.config.list_physical_devices('GPU')\n",
"print(\"Num GPUs Available: \", len(physical_devices))\n",
"print(physical_devices)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:44:59.718075Z",
"end_time": "2024-06-09T12:44:59.731077Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Ładowanie danych"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 40,
"outputs": [],
"source": [
"# Walk the dataset folder and collect one record per article file.\n",
"# The first line of each file is the title, the remaining lines form the body text,\n",
"# and the name of the containing folder is the category label.\n",
"datapath = 'bbc/'\n",
"directory, file, title, text, label = [], [], [], [], []\n",
"for dirname, _, filenames in os.walk(datapath):\n",
"    for filename in filenames:\n",
"        if filename == 'README.TXT':\n",
"            continue\n",
"        fullpathfile = os.path.join(dirname, filename)\n",
"        # errors='ignore' drops bytes that are not valid UTF-8 instead of raising.\n",
"        with open(fullpathfile, 'r', encoding=\"utf8\", errors='ignore') as infile:\n",
"            lines = [line.replace('\\n', '') for line in infile]\n",
"        if not lines:\n",
"            # An empty file has no title line; skip it so the five lists stay the same length\n",
"            # (otherwise zip() in the next cell would pair titles with the wrong articles).\n",
"            continue\n",
"        directory.append(dirname)\n",
"        file.append(filename)\n",
"        # basename copes with both '/' and the platform separator that os.walk may insert.\n",
"        label.append(os.path.basename(dirname))\n",
"        title.append(lines[0])\n",
"        # Every body line is prefixed with a single space, exactly as before.\n",
"        text.append(''.join(' ' + line for line in lines[1:]))\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:53.380989Z",
"end_time": "2024-06-09T12:50:53.698989Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Konwersja na DataFrame"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 41,
"outputs": [],
"source": [
"# Build the frame directly from the three columns that are kept; the original built\n",
"# `directory` and `file` columns too and dropped them immediately afterwards.\n",
"df = pd.DataFrame(list(zip(title, text, label)), columns=['title', 'text', 'label'])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:53.930080Z",
"end_time": "2024-06-09T12:50:53.986076Z"
}
}
},
{
"cell_type": "code",
"execution_count": 42,
"outputs": [
{
"data": {
"text/plain": " title \\\n0 Ad sales boost Time Warner profit \n1 Dollar gains on Greenspan speech \n2 Yukos unit buyer faces loan claim \n3 High fuel prices hit BA's profits \n4 Pernod takeover talk lifts Domecq \n\n text label \n0 Quarterly profits at US media giant TimeWarn... business \n1 The dollar has hit its highest level against... business \n2 The owners of embattled Russian oil giant Yu... business \n3 British Airways has blamed high fuel prices ... business \n4 Shares in UK drinks and food firm Allied Dom... business ",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>title</th>\n <th>text</th>\n <th>label</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>Ad sales boost Time Warner profit</td>\n <td>Quarterly profits at US media giant TimeWarn...</td>\n <td>business</td>\n </tr>\n <tr>\n <th>1</th>\n <td>Dollar gains on Greenspan speech</td>\n <td>The dollar has hit its highest level against...</td>\n <td>business</td>\n </tr>\n <tr>\n <th>2</th>\n <td>Yukos unit buyer faces loan claim</td>\n <td>The owners of embattled Russian oil giant Yu...</td>\n <td>business</td>\n </tr>\n <tr>\n <th>3</th>\n <td>High fuel prices hit BA's profits</td>\n <td>British Airways has blamed high fuel prices ...</td>\n <td>business</td>\n </tr>\n <tr>\n <th>4</th>\n <td>Pernod takeover talk lifts Domecq</td>\n <td>Shares in UK drinks and food firm Allied Dom...</td>\n <td>business</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:54.222077Z",
"end_time": "2024-06-09T12:50:54.271074Z"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "(2225, 3)"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-08T22:27:40.840352Z",
"end_time": "2024-06-08T22:27:40.908354Z"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "array(['business', 'entertainment', 'politics', 'sport', 'tech'],\n dtype=object)"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[\"label\"].unique()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-08T22:27:40.847351Z",
"end_time": "2024-06-08T22:27:40.984351Z"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "title 0\ntext 0\nlabel 0\ndtype: int64"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.isnull().sum() # Sprawdzenie brakujących wartości"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-08T22:27:40.865352Z",
"end_time": "2024-06-08T22:27:41.021352Z"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "<Figure size 600x600 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhkAAAIhCAYAAADjM6hLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABNM0lEQVR4nO3de3zP9f//8fs2tpmzOY5yyobNZiwjY0M/EiVDJZHkg0yor/MpORYihxFRYhRySKUU+ShnTSYhx2mb04Y5zsb2+v3R1/vr3Tb2Zq+9N27Xy2WXy96v5+vweD3f773f9z1fr/fr5WAYhiEAAIBs5mjvAgAAwMOJkAEAAExByAAAAKYgZAAAAFMQMgAAgCkIGQAAwBSEDAAAYApCBgAAMAUhA0C24dp+AO5EyMBDaebMmfLy8sr0sa0GDBigunXr6siRIxm2nzx5UjVr1tSQIUPuexu37dy5U15eXtq5c+cDrysnRUZGqkePHpm29+/fXwEBATp27FgOVnX/YmNj5eXlpVWrVtll+/d6zWalvqZNm2bLa/JOQ4YMUdOmTbN1nXh45bN3AUBud/z4cf3www8KDw9XtWrVMpxnzpw58vPz05gxY3K4utxjxYoVmQaIw4cPa8OGDZozZ46qVq2aw5XlTR06dFCjRo3sXUY6vXv3VpcuXexdBvIIQgZwD+7u7lq/fr3Kly+f6TxvvvmmSpYsKWdn5xysLO8oXbr0PfsQ1sqWLauyZcvau4x0Hn/8cXuXgDyEwyV4ZK1bt06hoaHy9/dXw4YNNWrUKF26dMnSfuPGDY0ePVrPPfecWrRooWeeeUYLFiywWse5c+c0ePBgvfzyywoKCtKrr76q33///a7bPXXqlN555x3Vq1dPfn5+eu2113TgwIG7LnP48GH17NlTderUUZ06dRQWFqaYmBhL++1DLNu3b1fnzp3l6+urkJAQrVixQufOnVOfPn3k7++v4OBgLVy40GrdiYmJGjVqlJ566inVqlVLL774orZv3241j5eXl5YsWaLhw4erXr168vf3V79+/ZSQkCDpnyH01atXKy4uzmoI/8qVK5o4caLat2+vZ555Rq1bt9ZXX31lWW/btm315ptvWm3r6aefVkhIiNW03r1764033kjXL4cOHZKXl5d++ukny7TffvtNXl5e+uijjyzTLl68qBo1aujbb7/NtI9//PFHPf/88/L19VXbtm116NChDLfXp08f1a9fX97e3mrUqJHGjRunGzduZLmvbtuwYYNCQ0NVq1YtNWzYUOPGjdP169ct7bYe4jMMQ0OHDpWvr6+2bNmS4TyxsbEaNGiQgoKC5O3trQYNGmjQoEG6ePGiZZ79+/frtddeU926deXv76+uXbtq7969lvaMDpesWLFCrVq1ko+Pj0JCQjRz5kylpqZmuXY8vAgZeCTNnj1b77zzjmrXrq0ZM2YoLCxM69evV+fOnS0fGBMmTNAvv/yiwYMHa8GCBWrWrJkmTZqklStXSpKuXbumjh07aufOnRo4cKBmzZolFxcXdevWTdHR0Rlu98KFC3r55Zf1559/auTIkfrwww+VlpamTp06ZXqo4cSJE3r55Zd1/vx5ffDBBxo/frxiYmLUsWNHnT9/3mred955R02bNtXcuXNVuXJlvfvuu+rSpYuqVaum2bNny9fXVxMnTtS+ffskScnJyXrttde0ceNGvf3225o1a5bKli2r7t27pwsa06ZNU1pamqZOnapBgwZp06ZNmjBhgqR/QkBwcLBKlSqlZcuWKSQkRDdu3NArr7yib775Rt27d9fs2bNVt25dDR8+XB9//LEkKTg4WLt27bJ8IMXGxiomJkanT5+2hKibN29q+/bt6YKHJFWvXl3lypXTtm3bLNNu1/3bb79Zpm3dulWOjo6ZHn74+eef1bdvX3l5eSk8PFwtW7bUwIEDreY5d+6cOnXqpKSkJL3//vv65JNP1KpVKy1evFiLFi3Kcl9J0j
fffKOwsDBVqVJF4eHh6tOnj9auXavevXvf98mz48aN07fffqtZs2YpKCgoXXtSUpK6dOmiY8eO6d1339WCBQvUpUsXfffdd5o2bZok6erVq+revbuKFy+umTNnatq0aUpKStIbb7yhK1euZLjduXPnauTIkWrQoIE+/vhjderUSZ988olGjhx5X/uBh4wBPIRmzJhheHp6Zvg4MTHR8PHxMUaOHGm1zO7duw1PT08jIiLCMAzDaNGihTFixAireWbNmmVs2rTJMAzDWLx4seHl5WUcOHDA0n79+nWjefPmxvLlyzOsa+rUqUatWrWM2NhYy7Tk5GSjWbNmxltvvWUYhmHs2LHD8PT0NHbs2GEYhmG88847xlNPPWVcuXLFsszFixeNunXrGu+//77VMpMnT7bMs3fvXsPT09MYOHCgZdqFCxcMT09P47PPPjMMwzCWLVtmeHp6Gnv37rXMk5aWZnTq1MkIDQ21TPP09DQ6duxotS9DhgwxateubXk8ePBgo0mTJpbHS5YsMTw9PY09e/ZYLTds2DCjVq1axsWLF43ff//dap7ly5cbzZs3N+rUqWOsXLnSMAzD2L59u+Hp6WnExMRk2KcjR440mjdvbnncsWNHo23btoaPj49x48YNwzAMY9CgQcarr76a4fKGYRihoaFGhw4drKbNnTvX8PT0tNTx66+/Gp06dbJ6HgzDMFq3bm1069Yty32VlpZmNG7c2HjjjTes5tm2bZvh6elpeX39+zX8bzExMZb6pkyZYnh7e1uWva1JkybG4MGDDcMwjAMHDhgdO3Y0/v77b6t5evbsabRo0cIwDMPyfERGRlraT548aUyaNMk4ffq0YRjWz/Ply5cNX19fY9SoUVbrXL58ueHp6WkcPnw40/rxaGAkA4+cvXv3KiUlRa1bt7aaHhAQoPLly2vXrl2SpMDAQC1fvlz/+c9/FBERoZiYGIWFhVn+o46MjFSFChVUo0YNyzoKFCig9evXq0OHDhlue/v27apRo4bKlCmjW7du6datW3J0dFTjxo2t/hu/044dO1SvXj25urpalilUqJACAgLSLePv72/53d3dXZLk5+dnmVa8eHFJsvxXun37dpUqVUre3t6WdaempqpJkybav3+/1eGj2rVrW22rbNmySkpKyrBmSdq1a5fKly9vVZMkPf/880pOTlZUVJR8fX1VvHhxy37s2LFDgYGB8vPz0+7duyVJv/zyi6pVq6YKFSpkuJ2QkBBFR0fr9OnTun79uvbt26devXopJSVFUVFRMgxDW7ZsyXAkRPrnsNiff/6pJk2aWE1v2bKl1eOgoCBFRETIxcVFR48e1caNGzVnzhxduHBBKSkpVvPera+OHz+uM2fOqGnTppY+v3Xrlp588kkVKlRIW7duzbRPM7JkyRLNmzdPrVq1ynQfJalGjRpaunSpypcvr+joaG3evFkLFizQ8ePHLfVXq1ZNJUqUUK9evTRq1Cj99NNPKlmypAYOHJjh+SG///67bty4kW5fbh9OsXVf8PDhxE88cm5/cJYsWTJdW8mSJS0fwMOHD1fZsmW1du1ajR07VmPHjpW/v79Gjx6t6tWrKzEx0fJBnlWJiYk6efKkvL29M2zP6EM7MTFR69at07p169K1lShRwupxoUKF0s1ToECBu9YTHx+faT3x8fEqWrRohutxdHS869D+pUuXVKpUqXTTb/f75cuXLQFr+/btCgsL044dOzRs2DB5eHhoxYoVkqRff/01XQC4U4MGDeTi4qJt27apZMmSyp8/v5o2bapKlSpp165dKliwoBISEjJdx6VLl2QYhiWA3Va6dGmrx7cPfyxZskTXr19XuXLl5OvrKxcXl3TrvFtfJSYmSpLee+89vffee+mWPXfuXKb7mpFDhw4pKChI3377rV577TXVrFkz03k/++wzffzxx0pMTFTJkiXl4+OjAgUKWF7zBQsW1JIlSzRnzhx9//33WrZsmVxdXdWmTRuNGDEi3Y
nNt/cls68u27ovePgQMvDIuf2hmZCQoCpVqli1xcfH67HHHpMkOTs7680339Sbb76pU6dOadOmTZo9e7b+53/+R99
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Number of articles per category.\n",
"label_counts = df['label'].value_counts()\n",
"\n",
"# Explicit figure/axes objects instead of the implicit pyplot state machine.\n",
"fig, ax = plt.subplots(figsize=(6, 6))\n",
"sns.barplot(x=label_counts.index, y=label_counts.values, ax=ax)\n",
"ax.set_title(\"Ilość elementów w danej klasie\")\n",
"ax.set_xlabel(\"Type\")\n",
"ax.set_ylabel(\"Number of Articles\")\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-08T22:27:40.889352Z",
"end_time": "2024-06-08T22:27:41.292351Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Podział danych na zbiory treningowy, walidacyjny i testowy"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 43,
"outputs": [],
"source": [
"# Two-stage split with the same options: first hold out 20% as the test set,\n",
"# then hold out 20% of the remainder as the validation set.\n",
"split_options = dict(test_size=0.2, random_state=42)\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(df['text'], df['label'], **split_options)\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full, **split_options)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:57.515115Z",
"end_time": "2024-06-09T12:50:57.532113Z"
}
}
},
{
"cell_type": "code",
"execution_count": 44,
"outputs": [
{
"data": {
"text/plain": "((1424,), (356,), (445,))"
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape, X_val.shape, X_test.shape"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:57.772166Z",
"end_time": "2024-06-09T12:50:57.817245Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Tf-idf z wykorzystaniem Naive Bayes"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"# The vectorizer is fitted on the training split only; the test split is transformed\n",
"# with the already-fitted vocabulary and weights.\n",
"tfidf_vectorizer = TfidfVectorizer(\n",
"    stop_words='english',\n",
"    max_features=5000,\n",
")\n",
"X_train_tfidf = tfidf_vectorizer.fit_transform(X_train)\n",
"X_test_tfidf = tfidf_vectorizer.transform(X_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:45:08.641703Z",
"end_time": "2024-06-09T12:45:09.322735Z"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "MultinomialNB()",
"text/html": "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-to
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# fit() returns the estimator itself, so construction and training can be chained.\n",
"nb_classifier = MultinomialNB().fit(X_train_tfidf, y_train)\n",
"nb_classifier"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:45:09.298705Z",
"end_time": "2024-06-09T12:45:09.332701Z"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Naive Bayes Accuracy: 0.9707865168539326\n"
]
}
],
"source": [
"# Score the Naive Bayes baseline on the held-out test split.\n",
"y_pred = nb_classifier.predict(X_test_tfidf)\n",
"nb_accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
"print(f'Naive Bayes Accuracy: {nb_accuracy}')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:45:09.392705Z",
"end_time": "2024-06-09T12:45:09.449702Z"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x700 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAw0AAAJuCAYAAADy9u6gAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABaQ0lEQVR4nO3deZyN5f/H8fcZszLWsWSLSJaZMWYxsoUJiYookS/RQmUpZRv7rhDK2KVUinzRpigtUskyDE0oW5YmmsGoMWOOMfP7o5/zPcfNnWXO3LO8nj3O49G5zn3O/ZlznJnzOe/rum9bVlZWlgAAAADgKjysLgAAAABA7kbTAAAAAMAUTQMAAAAAUzQNAAAAAEzRNAAAAAAwRdMAAAAAwBRNAwAAAABTNA0AAAAATNE0AABwGc57CgCuaBqAAqZ79+6qU6eOfvrppyveHhUVpWHDhl3XYw4bNkxRUVHZUZ6p2bNnq2bNmi6XevXqqX379lq+fLnb9+8Ohw8f1tixY9WyZUvVrVtXzZs31wsvvKB9+/a5bZ9r165VixYtFBQUpNGjR2fLYx4/flw1a9bU6tWrs+XxrmVfNWvW1IoVK664zd9//63g4GDVrFlTW7Zsua7Hnzt3rl5//fV/3e5G3isAkFd5Wl0AgJx38eJFRUdHa/Xq1fL29r7px3v22WfVo0ePbKjs2lz6oJiZmamUlBR9++23GjNmjAoVKqSHH344x+q4WZ9//rmGDBmiGjVq6JlnnlGlSpV04sQJLV26VJ07d9a8efPUuHHjbN/v+PHjVbVqVb300ksqV65ctjxm2bJltWLFCt16663Z8njXwsPDQ+vWrdMjjzxiuO2LL76Q3W6/ocd99dVX1a9fv3/dLiYmRv7+/je0DwDIa2gagAKoaNGi2r9/v+bMmaOBAwfe9OPl5AdFSapXr57L9bvuukv79u3T8uXL80zTcPToUQ0dOlRNmzbVrFmzVKhQIcdtrVu3VteuXTV06FB99dVX2dLYOUtOTlbjxo3VoEGDbHtMb29vw+vibmFhYdqyZYtOnz6tUqVKudy2du1a1a5dW3v37nXb/uvUqeO2xwaA3IbpSUABVLt2bXXo0EGLFy9WfHy86bbnz5/XK6+8otatWysoKEhhYWHq1auXy4cx5+lJo0aNUuPGjXXx4kWXx5k0aZIaNGigCxcuSJJ+/fVX9enTR2FhYQoLC1Pfvn117NixG/6ZihUrJpvN5jK2YcMGPfroowoNDVVQUJDatGmjZcuWSZIyMjLUpEkTvfjii4bHat26tUaOHOm4vnLlSrVr105BQUFq3ry5Zs+e7fLznT59Wi+++KIaN26s4OBgtW/fXh988IFpvW+//bbsdrtGjhzp0jBIkp+fn4YOHapOnTrp7NmzjvFPP/1UHTt2VGhoqBo3bqzRo0e73D579my1atVK33zzje6//34FBQXpnnvucdSyZcsW1axZU5I0Z84c1axZU8ePH1f37t3VvXt3lxoubXtpak9mZqZmzpypqKgoBQUFKSoqSq+88orj9bzS9KTffvtNAwYMUOPGjVWvXj11795dsbGxjtsv3eezzz7TgAEDFBoaqsjISI0cOVKpqammz58ktWrVSh4eHvriiy9cxs+cOaMff/xR7dq1M9xn27ZteuKJJ1S/fn3HzzF79mxlZmZKkuP5iYmJcfz/pec1JiZGkZGRatKkic6ePesyPWnKlCmqWbOmfvzxR8e+Vq9erZo1a/7rvwUAyAtoGoACavjw4SpZsqSio6NNp3EMGTJEq1atUu/evbVkyRJFR0dr//79evHFF6+4WLR9+/ZKSkpymUeemZmpzz77TO3atZOXl5cOHz6sLl266NSpU3r55Zc1adIkHTt2TF27dtWpU6f+tfaMjAzH5a+//tInn3yib7/9Vv/5z38c23zzzTfq27evAgMDNXfuXM2ePVuVK1fW+PHjtWvXLnl6eqpDhw7asGGDUlJSHPeLjY
3VkSNH1LFjR0nSggULNGrUKDVs2FDz589Xt27dtGjRIo0aNcpxn8GDB+vgwYMaN26cFi1apDp16mjo0KEuHyAvt2nTJtWpU+eq04MaNmyogQMHqkyZMpL+mWf/wgsvqF69enrttdfUt29frV+/Xt27d9f58+cd90tMTNT48ePVo0cPLVy4UJUqVdLQoUN18OBBBQYGOqZ2PfTQQ1qxYoXKli37r8+3JC1atEjvvfee+vbtqyVLlqhr1656/fXXNW/evCtuf+DAAXXs2FHHjx/XyJEjNX36dNlsNj322GPaunWry7ZjxoxRxYoVNXfuXD3xxBP673//e9XHdVasWDE1btxY69atcxlfv369KlSooLp167qM79u3Tz179lSJEiU0c+ZMzZs3TxEREYqJidFnn30mSYbn55KEhARt3LhRM2fOVHR0tIoXL+7y2AMHDlTVqlU1ZswY2e12JSQkaNKkSbr33nvVoUOHf/1ZACC3Y3oSUEAVL15c48eP1zPPPHPVaUp2u13nzp3TyJEj1bZtW0lSZGSkUlJS9NJLLykpKcnxofaS8PBwVaxYUZ988okaNWok6Z9vrRMTE9W+fXtJ/3yL6+fnpzfffNMxJ7xhw4Zq2bKlFi9erKFDh5rWHhgYaBiLiopy1Cj986H1wQcf1IgRIxxjoaGhatCggbZs2aKQkBB16tRJixYt0vr169WpUydJ0gcffKCqVasqLCxMf//9t+bOnatHHnnEkTw0adJEJUqU0MiRI9WrVy/VqFFDW7duVd++fdWyZUvHc1SiRAnTaUUnTpxQ7dq1TX/OS86ePat58+apc+fOLguX77jjDnXr1k2rVq1St27dJElpaWmaNGmSGjZsKEmqWrWqWrRooY0bN+rxxx93TCG65ZZbrms60datWxUUFOR4niIjI+Xn56eiRYtecfuYmBh5e3vrrbfecrzGzZs313333aepU6fqv//9r2PbZs2aOV7zhg0b6vvvv9c333xzxRTocvfee6+GDx/uMkVp7dq1Lv8WLtm3b58aNWqkadOmycPjn+/MGjdurK+++kpbtmxRu3btrvr8ZGRkaOjQoYqIiLhiHb6+vnrppZf06KOPauHChdqxY4f8/f01bty4f/0ZACAvoGkACrCoqCg98MADWrx4sVq3bm34MO7t7e04iszJkyd1+PBh/fbbb/r6668l6YoJhc1m0wMPPKB3331XY8eOlbe3t9auXauqVasqJCREkvTjjz8qMjJSvr6+ysjIkCT5+/srIiJCP/zww7/W7fyBMy0tTT/99JPmz5+vJ554Qm+++aYKFSqkJ598UpJ07tw5HT58WEePHnUcMepS3bfddpvCw8P14YcfqlOnTjp//rw+++wzPfXUU5KknTt36vz584qKinLUeel5k6Tvv/9eNWrUUIMGDTR79mzt2bNHTZs2dfkQfDWFChUyTOG6mri4ONntdt13330u4xEREapYsaK2bt3qaBok1zUft9xyiyRd03QfMw0aNNArr7yiRx99VFFRUWrevLlLsnO5rVu3qkWLFi4LhT09PdWuXTvNmTNH586du2K9l2r+/fffr6muli1batSoUfriiy/0yCOP6M8//9T27ds1evRonT592mXbDh06qEOHDkpPT9fhw4d15MgR7d27VxcvXnRMszLzb01eaGioevbsqTlz5igrK0tvvPGGIZEAgLyKpgEo4EaOHKnNmzcrOjpaq1atMty+adMmTZ48WYcOHVKRIkVUq1YtFS5cWNLVj2Xfvn17zZs3T5s2bVLTpk31+eef67HHHnPcnpycrE8//VSffvqp4b6XL2i9kuDgYJfrkZGRKlOmjAYPHqwvv/xSrVu31unTpzVmzBht2LBBNptNVapUcXxL7Fz3Qw89pOHDh+uPP/5QbGyszp0755hOkpycLEnq3bv3Fev4888/JUkzZ87U/Pnz9dlnn2n9+vXy8PBQo0aNNH78eFWsWPGK961QoYISEhKu+jNeuHBBZ8
+eVenSpR3rFkqXLm3YrnTp0vr7779dxvz8/Bz/f+kb9Zs978CTTz6pIkWKaNWqVZo+fbqmTZumGjVqaOTIkbrzzjs
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Pin the class order explicitly so that matrix rows/columns and the tick labels agree,\n",
"# and show the category names on the axes instead of bare indices.\n",
"class_names = nb_classifier.classes_\n",
"conf_matrix = confusion_matrix(y_test, y_pred, labels=class_names)\n",
"plt.figure(figsize=(10, 7))\n",
"sns.heatmap(\n",
"    conf_matrix,\n",
"    annot=True,\n",
"    fmt='d',\n",
"    cmap='Blues',\n",
"    xticklabels=class_names,\n",
"    yticklabels=class_names,\n",
")\n",
"plt.xlabel('Predicted Label')\n",
"plt.ylabel('True Label')\n",
"plt.title('Naive Bayes Confusion Matrix')\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:45:09.786701Z",
"end_time": "2024-06-09T12:45:10.254699Z"
}
}
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" business 0.97 0.97 0.97 115\n",
"entertainment 0.99 0.93 0.96 72\n",
" politics 0.93 0.97 0.95 76\n",
" sport 1.00 0.99 1.00 102\n",
" tech 0.96 0.99 0.98 80\n",
"\n",
" accuracy 0.97 445\n",
" macro avg 0.97 0.97 0.97 445\n",
" weighted avg 0.97 0.97 0.97 445\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test, y_pred))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:46:10.906888Z",
"end_time": "2024-06-09T12:46:10.938882Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## LSTM"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:40:34.922075Z",
"end_time": "2024-06-09T12:40:34.991158Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" business 0.97 0.97 0.97 115\n",
"entertainment 0.99 0.93 0.96 72\n",
" politics 0.93 0.97 0.95 76\n",
" sport 1.00 0.99 1.00 102\n",
" tech 0.96 0.99 0.98 80\n",
"\n",
" accuracy 0.97 445\n",
" macro avg 0.97 0.97 0.97 445\n",
" weighted avg 0.97 0.97 0.97 445\n",
"\n"
]
}
],
"execution_count": 27
},
{
"cell_type": "markdown",
"source": [
"### Przygotowanie danych"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 31,
"outputs": [],
"source": [
"label_encoder = LabelEncoder()\n",
"integer_encoded = label_encoder.fit_transform(y_train)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:49:28.830665Z",
"end_time": "2024-06-09T12:49:28.849664Z"
}
}
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [],
"source": [
"# NOTE(review): newer scikit-learn releases renamed `sparse` to `sparse_output` —\n",
"# confirm against the installed version before upgrading.\n",
"onehot_encoder = OneHotEncoder(sparse=False)\n",
"# The encoder is fitted on the training labels only; validation and test labels are\n",
"# pushed through the already-fitted label encoder and one-hot encoder.\n",
"integer_encoded = integer_encoded.reshape(-1, 1)\n",
"y_train = onehot_encoder.fit_transform(integer_encoded)\n",
"integer_encoded = label_encoder.transform(y_val).reshape(-1, 1)\n",
"y_val = onehot_encoder.transform(integer_encoded)\n",
"integer_encoded = label_encoder.transform(y_test).reshape(-1, 1)\n",
"y_test = onehot_encoder.transform(integer_encoded)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:49:29.040664Z",
"end_time": "2024-06-09T12:49:29.090663Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Tokenizacja"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 33,
"outputs": [],
"source": [
"# The word index is built from the training texts only.\n",
"tokenizer = Tokenizer(num_words=5000)\n",
"tokenizer.fit_on_texts(X_train)\n",
"# Replace each split with its integer-sequence encoding.\n",
"X_train, X_val, X_test = (\n",
"    tokenizer.texts_to_sequences(split) for split in (X_train, X_val, X_test)\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:49:29.345664Z",
"end_time": "2024-06-09T12:49:30.778251Z"
}
}
},
{
"cell_type": "code",
"execution_count": 34,
"outputs": [],
"source": [
"# Bring every sequence in every split to the same length (maxlen=1000),\n",
"# matching the input_length the embedding layer is created with.\n",
"X_train, X_val, X_test = (\n",
"    pad_sequences(split, maxlen=1000) for split in (X_train, X_val, X_test)\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:49:30.748251Z",
"end_time": "2024-06-09T12:49:30.831249Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 35,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"23/23 [==============================] - 8s 147ms/step - loss: 1.5697 - accuracy: 0.2725 - val_loss: 1.2654 - val_accuracy: 0.4410\n",
"Epoch 2/15\n",
"23/23 [==============================] - 2s 108ms/step - loss: 1.1485 - accuracy: 0.4698 - val_loss: 1.1248 - val_accuracy: 0.5197\n",
"Epoch 3/15\n",
"23/23 [==============================] - 2s 109ms/step - loss: 0.8194 - accuracy: 0.6678 - val_loss: 0.6958 - val_accuracy: 0.8090\n",
"Epoch 4/15\n",
"23/23 [==============================] - 2s 107ms/step - loss: 0.3153 - accuracy: 0.9178 - val_loss: 0.4955 - val_accuracy: 0.8624\n",
"Epoch 5/15\n",
"23/23 [==============================] - 2s 104ms/step - loss: 0.1949 - accuracy: 0.9396 - val_loss: 0.4209 - val_accuracy: 0.8567\n",
"Epoch 6/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0780 - accuracy: 0.9860 - val_loss: 0.5346 - val_accuracy: 0.8567\n",
"Epoch 7/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0673 - accuracy: 0.9874 - val_loss: 0.4814 - val_accuracy: 0.8904\n",
"Epoch 8/15\n",
"23/23 [==============================] - 2s 109ms/step - loss: 0.0527 - accuracy: 0.9888 - val_loss: 0.4456 - val_accuracy: 0.8792\n",
"Epoch 9/15\n",
"23/23 [==============================] - 2s 108ms/step - loss: 0.0259 - accuracy: 0.9944 - val_loss: 0.4536 - val_accuracy: 0.8736\n",
"Epoch 10/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0147 - accuracy: 0.9993 - val_loss: 0.4479 - val_accuracy: 0.8792\n",
"Epoch 11/15\n",
"23/23 [==============================] - 2s 107ms/step - loss: 0.0179 - accuracy: 0.9958 - val_loss: 0.5509 - val_accuracy: 0.8764\n",
"Epoch 12/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0269 - accuracy: 0.9951 - val_loss: 0.4670 - val_accuracy: 0.8764\n",
"Epoch 13/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0150 - accuracy: 0.9972 - val_loss: 0.5061 - val_accuracy: 0.8652\n",
"Epoch 14/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0094 - accuracy: 0.9993 - val_loss: 0.4863 - val_accuracy: 0.8708\n",
"Epoch 15/15\n",
"23/23 [==============================] - 2s 106ms/step - loss: 0.0050 - accuracy: 1.0000 - val_loss: 0.4513 - val_accuracy: 0.8904\n"
]
},
{
"data": {
"text/plain": "<keras.callbacks.History at 0x1d303c4ba00>"
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Embedding -> two stacked LSTM layers with dropout -> softmax classifier.\n",
"# The output width follows the number of one-hot label columns.\n",
"model = Sequential()\n",
"model.add(Embedding(input_dim=5000, output_dim=128, input_length=1000))\n",
"model.add(LSTM(128, return_sequences=True))  # pass the full sequence on to the next LSTM\n",
"model.add(Dropout(0.2))\n",
"model.add(LSTM(64))\n",
"model.add(Dropout(0.2))\n",
"model.add(Dense(y_train.shape[1], activation='softmax'))\n",
"\n",
"optimizer = Adam(learning_rate=0.001)\n",
"model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
"model.fit(X_train, y_train, epochs=15, batch_size=64, validation_data=(X_val, y_val))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:49:30.824250Z",
"end_time": "2024-06-09T12:50:14.427832Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Ocena modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 36,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/12 [==============================] - 0s 38ms/step - loss: 0.4513 - accuracy: 0.8904\n",
"14/14 [==============================] - 1s 34ms/step - loss: 0.5677 - accuracy: 0.8629\n",
"LSTM Validation Accuracy: 0.8904494643211365, loss: 0.4512944519519806\n",
"LSTM Test Accuracy: 0.8629213571548462, loss: 0.5676819086074829\n"
]
}
],
"source": [
"# evaluate() yields the loss followed by the accuracy metric the model was compiled with.\n",
"val_loss, val_accuracy = model.evaluate(x=X_val, y=y_val)\n",
"test_loss, test_accuracy = model.evaluate(x=X_test, y=y_test)\n",
"print(f'LSTM Validation Accuracy: {val_accuracy}, loss: {val_loss}')\n",
"print(f'LSTM Test Accuracy: {test_accuracy}, loss: {test_loss}')\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:14.434837Z",
"end_time": "2024-06-09T12:50:15.550828Z"
}
}
},
{
"cell_type": "code",
"execution_count": 37,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14/14 [==============================] - 1s 35ms/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\adamw\\PycharmProjects\\pythonProject\\venv\\lib\\site-packages\\sklearn\\preprocessing\\_label.py:154: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
}
],
"source": [
"y_pred = model.predict(X_test)\n",
"y_pred = onehot_encoder.inverse_transform(y_pred)\n",
"y_test = onehot_encoder.inverse_transform(y_test)\n",
"y_pred = label_encoder.inverse_transform(y_pred)\n",
"y_test = label_encoder.inverse_transform(y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:15.554828Z",
"end_time": "2024-06-09T12:50:16.967222Z"
}
}
},
{
"cell_type": "code",
"execution_count": 38,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x700 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwYAAAJuCAYAAAAKFhVXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABYDklEQVR4nO3dd3xT9f7H8XdK6YBSCmUPC7IsBUopgigoAjIcV1Dx4mAroAxBkCl7VIaAUNmgKFxBBFwgIi5Qkb0qQyjbMlqgldKRjvz+8GduIlyhmOS0Pa/n75HHj3xzevJJz03NJ+/v9xyLzWazCQAAAICpeRldAAAAAADj0RgAAAAAoDEAAAAAQGMAAAAAQDQGAAAAAERjAAAAAEA0BgAAAABEYwAAAABANAYAgDyG63ICgHvQGABwqY4dO6pjx4433e7KlSuKiopSixYtVKtWLTVo0ECdO3fWV199Zd9m27ZtqlGjxk1vZ8+eddr2hx9+uOFzxsbGOv3MzRw4cECvvfaamjZtqjp16qhFixYaOXKkzpw5c+u/kBx69913dd9996lOnTqaM2eOS/b55+9m27ZtLtnfrTyXK4/Dn6xWqyZNmqTPPvvsptvWqFFDs2fPvuV9AwAkb6MLAGA+aWlpeu6555SVlaUePXooJCREV69e1RdffKE+ffpo+PDh6ty5s8LCwrRy5Ur7z/3yyy8aN26cRo0apbCwMPt4qVKl9Ntvv0mSvLy8tGHDBjVu3Pi6512/fv0t17h8+XJNmjRJDRs21MCBA1WqVCmdOnVKixcv1saNG7V06VLddddd/+C3cL3k5GRNnjxZTZs2Vbdu3VShQgWX7PfP32PVqlVdsr9b4arj4OjixYtaunSpoqKibrrtypUrVaZMmdt6HgAwKxoDAB63YcMGxcbG6ssvv1SlSpXs4y1atFBaWppmzZql559/XgEBAapbt6798fT0dElS1apVncYd1atXT1999ZXGjBkjb2/nP3Hr169XaGioDh069Lf17dq1SxMnTtRzzz2nESNG2McbNmyoFi1aqG3btho+fLjWrFmTsxd+E0lJScrOzlaLFi109913u2y/f/09eoIrjsM/4enXCwD5AVOJAHhcQkKCJCk7O/u6x3r27KmXX35ZVqv1tvb98MMPKzExUT///LPT+OHDh3Xy5Em1adPmpvtYvHixihQpoldfffW6x4oXL66hQ4eqefPmSklJkSRlZWVp+fLleuyxx1SnTh01bdpU06ZNszcykjR06FB16dJFq1evVqtWrVSrVi09/vjj2rx5syRpzZo1atasmSRp+PDhqlGjhiSpWbNmGjp0qFMNa9ascZqGk5aWpjFjxuj+++9XrVq11Lp1ay1evNi+/Y2mEh04cEDdu3dXw4YNVa9ePfXq1UtHjx697me2bt2qbt26KTw8XPfdd5+mTp2qrKysm/4Ob+c4bNq0Sc8++6wiIiLsr2P58uWSpLNnz6p58+aSpGHDhtl/V0OHDlXnzp01evRo1atXTw8//LCysrKcphL16dNHtWvX1vHjx+3PNXv2bIWGhmr79u03fS0AYBY0BgA8rkmTJvL29lbnzp0VHR2tvXv3KiMjQ5JUp04dde/eXf7+/re176pVq6patWrasGGD0/i6devUoEEDlSxZ8m9/3maz6YcfflCjRo3+Zw0PP/ywevfurUKFCkmSRo0aZV8vMXfuXD333HNatmyZXn75ZaeFsjExMVq8eLH69eunt99+WwUKFFDfvn2VlJSkpk2bKjo6WpL00ksvOU2huplJkyZp8+bNGjJkiBYvXqzmzZtrypQpWr169Q23//nnn/XMM8/Yf3bChAk6d+6cOnTooNjYWKdtBw0apMjISM2bN0+PPvqoFi1apFWrVt20ppweh++++069e/dWWFiY5syZo9mzZ6tixYoaN26c9u3bp1KlSjn9fv78tyTt3LlT586d09tvv62BAweqQIECTvseM2aMChUqpN
GjR0v64zjMmzdP3bp1U4MGDW76WgDALJhKBMDjatSooRkzZmjs2LGaPXu2Zs+eLT8/P9WvX19PPfXULX2r/3fatGmj9957z2kay/r169WrV6+b/uyVK1eUnp5+y/P7jx07po8++kgDBw5Ujx49JEn33XefSpUqpcGDB2vz5s164IEHJElXr17VmjVrdMcdd0iSChUqpOeff14///yzWrVqpdDQUEnSHXfckaOpMNu3b9d9992nRx55RNIfU54KFSqk4ODgG27/5ptvKiQkRAsWLLB/iG7cuLEeeughzZo1S2+99ZZ92/bt26t3796SpEaNGmnTpk367rvv1KFDh5vWlZPjcOzYMbVr185p6lZERIQaNmyobdu2KTw83On3U7NmTft2mZmZGjdu3P9cU1CiRAmNHj1aAwYM0KpVq7R06VJVr15dr7zyyk1fAwCYCYkBAEO0bNlS3333nRYtWqRu3bqpSpUq+umnn9S/f3/169fvH52S8q/TWPbt26cLFy6oZcuWN/3ZPz8o38p0GUn2qSh/fij/0yOPPKICBQo4Td8pXry4vSmQZP8gm5qaekvP9b80bNhQH374oV588UUtW7ZMZ86cUe/evdW0adPrtk1JSdGBAwfUpk0bp2/WAwMD9eCDD143tSYiIsLpfpkyZexTqG4mJ8fhhRde0BtvvKFr164pJiZG69ev1/z58yXpptPKgoKCbrrQ+OGHH1arVq00atQonTlzRtOmTZOPj88tvQ4AMAsaAwCGKViwoJo0aaIhQ4ZozZo1+u6779SyZUt9+eWX+u677257v5UrV1ZoaKh9Gsv69evVuHFjFS1a9KY/W7RoURUuXFhxcXH/c5uUlBQlJSVJkv3//3VqjLe3t4oVK6arV6/ax/46NclisUi68VqLnBgxYoT69++vs2fPavz48WrRooU6dOigw4cPX7ft1atXZbPZVKJEieseK1GihFO9kuTn5+d038vL65abtpwch8uXL6tv376qX7++nn76ac2ePVvJycmSbn7dgsKFC99SPe3atVN2drYqVaqkypUr39LPAICZ0BgA8LgOHTpo2LBh142XLl1aEydOlPTH1JJ/4uGHH9ZXX32ljIwMbdiw4bpv9P9O48aNtW3bNqfFw44+/PBD3XPPPfrll1/sH3Lj4+OdtsnIyNCVK1dUrFix238R/++v6cVfv7H38fHRSy+9pC+++ELffvut/VvxgQMHXrevIkWKyGKx2BeAO4qPj1dQUNA/rtfRrR6HQYMG6cCBA3r33Xe1d+9effHFFxo+fLjL6khNTVVUVJSqV6+uX3/9VUuWLHHZvgEgv6AxAOBx5cuX14YNG254obATJ05IkqpXr/6PnqNNmzZKTEzUvHnzlJSUZD+jza3o1q2bEhMTNXPmzOsei4+P15IlS1S1alWFhYXZF6+uW7fOabt169YpKytLkZGR/+h1BAQE6Pz5805ju3btsv87LS1NrVq1sn/QLVeunJ577jk98sgjN0w9ChUqpFq1aumLL75wajiuXr2q77777h/X+1e3ehx27dqlli1bqmHDhvYpPn+esenPROWvi4pz4s0339T58+c1e/ZsPf/885o1a9Z1C60BwOxYfAzA5c6fP6933333uvHq1avr3nvv1YABA7Rt2zY99dRT6tSpkyIiIuTl5aUDBw5oyZIluv/++3X//ff/oxoqVqyo2rVra/78+XrooYfsZxC6FXXr1tUrr7yimTNnKjY2Vm3btlWxYsV09OhRLV68WOnp6famoWrVqmrXrp1mzZql1NRU3X333Tp06JCio6PVsGFDNWnS5B+9jgcffFDz58/X/PnzFR4erm+++cbpFKB+fn4KCwtTdHS0ChYsqBo1aujEiRNau3atWrVqdcN9Dhw4UN27d1ePHj307LPPKiMjQwsWLJDVarUvNHaVWz0OderU0WeffaawsDCVKVNGu3fv1oIFC2SxWOxrMIoUKSJJ2rp1q6pUqaLw8PBbqmH79u
1atmyZBgwYoEqVKql///766quvNHToUK1YseIfNRwAkJ/QGABwudOnT9/w6rRPPfWU7r33XlWoUEFr167V/Pnz9dl
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"conf_matrix = confusion_matrix(y_test, y_pred)\n",
"plt.figure(figsize=(10, 7))\n",
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')\n",
"plt.xlabel('Predicted Label')\n",
"plt.ylabel('True Label')\n",
"plt.title('LSTM Confusion Matrix')\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:16.939224Z",
"end_time": "2024-06-09T12:50:17.351224Z"
}
}
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" business 0.86 0.86 0.86 115\n",
"entertainment 0.92 0.79 0.85 72\n",
" politics 0.72 0.93 0.81 76\n",
" sport 0.96 0.88 0.92 102\n",
" tech 0.89 0.84 0.86 80\n",
"\n",
" accuracy 0.86 445\n",
" macro avg 0.87 0.86 0.86 445\n",
" weighted avg 0.87 0.86 0.86 445\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test, y_pred))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:50:17.366224Z",
"end_time": "2024-06-09T12:50:17.553225Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Trnsformers pipeline na pre-trenowanym modelu"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:39:04.264825Z",
"end_time": "2024-06-09T12:39:04.347285Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" business 0.86 0.82 0.84 115\n",
"entertainment 0.82 0.89 0.85 72\n",
" politics 0.88 0.83 0.85 76\n",
" sport 0.94 0.97 0.96 102\n",
" tech 0.81 0.82 0.82 80\n",
"\n",
" accuracy 0.87 445\n",
" macro avg 0.86 0.87 0.86 445\n",
" weighted avg 0.87 0.87 0.87 445\n",
"\n"
]
}
],
"execution_count": 16
},
{
"cell_type": "code",
"execution_count": 45,
"outputs": [],
"source": [
"# Encode labels\n",
"label_encoder = LabelEncoder()\n",
"df['label_encoded'] = label_encoder.fit_transform(df['label'])\n",
"\n",
"# Split data\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(df['text'], df['label_encoded'], test_size=0.2,\n",
" random_state=42)\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full, test_size=0.2, random_state=42)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:51:03.756087Z",
"end_time": "2024-06-09T12:51:03.772094Z"
}
}
},
{
"cell_type": "code",
"execution_count": 46,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"classifier = pipeline('text-classification', model='distilbert-base-uncased')\n",
"def get_predictions(texts):\n",
" predictions = []\n",
" for text in tqdm(texts, desc=\"Processing\"):\n",
" result = classifier(text, truncation=True)\n",
" predicted_label = int(result[0]['label'].split('_')[-1])\n",
" predictions.append(predicted_label)\n",
" return predictions\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:51:04.029028Z",
"end_time": "2024-06-09T12:51:04.740450Z"
}
}
},
{
"cell_type": "code",
"execution_count": 47,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing: 100%|██████████| 356/356 [02:00<00:00, 2.96it/s]\n",
"Processing: 100%|██████████| 445/445 [02:47<00:00, 2.66it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pre-trained Model Validation Accuracy: 0.20224719101123595\n",
"Pre-trained Model Test Accuracy: 0.25842696629213485\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"val_predictions = get_predictions(X_val)\n",
"test_predictions = get_predictions(X_test)\n",
"trans_val_accuracy = (val_predictions == y_val).mean()\n",
"trans_test_accuracy = (test_predictions == y_test).mean()\n",
"print(f'Pre-trained Model Validation Accuracy: {trans_val_accuracy}')\n",
"print(f'Pre-trained Model Test Accuracy: {trans_test_accuracy}')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:51:04.733447Z",
"end_time": "2024-06-09T12:55:51.983791Z"
}
}
},
{
"cell_type": "code",
"execution_count": 48,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x700 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAw0AAAJuCAYAAADy9u6gAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABb8ElEQVR4nO3dd3hT5f//8VdK6YBSRoFKQZlaSwulFBkKgogyBEEUBZUlCihDUfbeyBBEykYUBYUvAi4cCCKgIHuIzJYtii1SFFoa2ub3Bz/ySThwbDFpOp6P68p1kZPTnHdzp5p3Xvd9jsVms9kEAAAAALfg5ekCAAAAAGRvNA0AAAAATNE0AAAAADBF0wAAAADAFE0DAAAAAFM0DQAAAABM0TQAAAAAMEXTAAAAAMAUTQOQi3CtRiBv4W8eQFahaQBcpH379goNDXW6RUREqEGDBho1apQuXrzotmP//fff6t+/v3bs2OGS52vYsKEGDhzokuf6N2fOnFFoaKhWrlx508cHDhxoeF1vvLVv3z5Lav03hw8fVqtWrRQREaFmzZp5uhyX2rJli3r27Kl69eopMjJSjRs31sSJE3X+/Hm3HXPy5MmqWbOmqlWrpk8//dQlz7ly5UqFhobqzJkzLnm+jBwrNDRUx48fv+k+GzdutO+TGRn9m/+3vy8AyChvTxcA5CaVK1fWiBEj7PevXr2qX3/9VVOnTtXBgwf18ccfy2KxuPy4Bw8e1GeffaYnn3zSJc8XExOjgIAAlzzXf/XKK6+obdu29vuzZs3SgQMHFBMTY9+WXWqdOXOmzp49q5kzZ6pYsWKeLsdlpkyZogULFqhJkyYaMmSIihQposOHD2v+/Plas2aNFi9erFKlSrn0mEeOHNGCBQv09NNPq2XLlqpQoYJLnrdBgwZatmyZSpYs6ZLnywgvLy998803evnllw2PffXVV7f1nBn9my9ZsqSWLVumu+6667aOAwDX0TQALhQQEKBq1ao5bbvvvvt0+fJlvfPOO9q7d6/h8eyocuXKni7B7q677nL6wFOsWDH5+Phky9fxwoULuueee1S/fn1Pl+Iyq1ev1vz58zVo0CB16tTJvr127dqqX7++nnjiCY0bN86piXOFxMRESdJjjz2mGjVquOx5ixUrluUNXfXq1fX1118bmgar1aq1a9cqLCxMBw8edMuxs+vfCoCch+lJQBaIiIiQJJ09e1bStalMffv2Ve/evVWtWjV17txZkpSSkqJJkyapfv36ioiIUIsWLf71m8itW7eqQ4cOkqQOHTrYp+rc6hhnzpxR//79VbduXYWHh6tOnTrq37+/Lly4YH9Ox+lJ16c3fP311+rdu7eioqJUs2ZNDR06VElJSU61LF++XI899ph9WtaMGTOUlpbmtM+aNWv0+OOPq2rVqnriiSd06NCh23pNb7Ry5UpVrlxZy5cv1wMPPKCaNWsqNjZWaWlpmjdvnpo3b66qVauqWrVqatu2rX7++Wf7z86YMUOPPPKIfvjhB7Vo0UIRERFq3LixYUrMokWL1KRJE1WpUkX16tXTyJEjdenSJUlSaGiotm3bpu3btztNBzlx4oR69+6tBx54QNWqVVP79u21c+dO+3Nef33fe+89NWnSRJGRkVqxYoVmzJihJk2a6LvvvlPz5s1VpUoVtWzZUrt379aePXvUpk0bVa1aVc2bN9eWLVuc6jxy5Ii6deum6tWrq3r16urRo4dOnz5tf3zr1q0KDQ3V0qVL9dBDD6l69er66aefbvq6zps3T5UqVVLHjh0Nj5UrV079+vVTVFSUfW59SkqKZs6caX+dHn30Uc2bN0/p6en2n2vfvr2GDBmiefPmqUGDBqpSpYratm2rffv22cfj+vu4Y8eOatiwof01njFjhlMNM2bMcJra89dff+mNN97QAw88YH/NHMfxZtOTfvrpJz377LOKjo5WrVq19MYbb+j33393+pnKlStr7969euaZZ1SlSh
U99NBDevfdd2/6mt2oWbNmOnz4sGGK0saNG2WxWPTggw8afmb58uVq3bq1qlWrpqpVq6ply5b6+uuvJWXub95xelJaWpqeeuop1apVS3/99Zf9WAMHDlS1atV07NixDP0+APImmgYgC1z/sHDnnXfat3399dcqWLCgZs+erRdffFE2m009evTQ0qVL1blzZ82ePVtRUVHq06eP6Xzu8PBwDR8+XJI0fPhwp+lRNx4jOTlZHTp0UFxcnEaMGKF3331XHTp00OrVqzVt2jTT32HEiBEqXbq0Zs2apS5duuiTTz7R7Nmz7Y/PnTtXw4YNU506dTRnzhw999xzmj9/voYNG2bf5/vvv1fv3r0VGhqqmTNnqmnTpurXr1+mXkszaWlpWrhwocaNG6dBgwapYsWKmjJlimbNmqVnnnlGCxYs0JgxY5SYmKhXX31VycnJ9p+Nj4/X6NGj1aFDB82bN09lypTRgAEDFBcXJ0n68ssvNXnyZD333HN699131aNHD3322WcaM2aMJGnZsmWqXLmyKleurGXLlqlBgwaKjY1V69atdebMGQ0dOlRTpkyRxWJRx44dtW3bNqfaZ8yYoZdeekmTJk3SAw88IEn6448/9Oabb6p79+6aPn26/v77b/Xu3Vuvv/662rRpo5kzZ8pms6lPnz66cuWKpGvvtbZt2+r8+fOaOHGixo0bp9OnT6tdu3aG9QcxMTEaMGCAhg8frqioKMPrGR8fr0OHDqlBgwa3nFb37LPPqkuXLrJYLLLZbOrevbsWLFigNm3aaM6cOWrSpInefvttp/elJH377bdat26dhg4dqqlTpyohIUG9evVSWlqa2rRp4/SezkyK0a9fP8XFxWnUqFGaP3++KleurAEDBjg1iY4+/fRTvfDCCypVqpSmTp2qQYMGaffu3XrmmWecXq/09HS99tpratasmebNm6fq1atr0qRJ2rRp07/W9MADD6hw4cL65ptvnLZ/9dVXeuSRR5Q/f36n7UuWLNHw4cPVqFEjzZ07V1OmTJGPj4/69u2rP/74I1N/847y5cunN998U0lJSZo4caIkae3atVq1apX69+/vsilgAHInpicBLmSz2ZSammq/f/HiRW3bts3eAFxPHCQpf/78GjVqlHx8fCRd+7Zz06ZNmjZtmn0Rbb169ZScnKwpU6aoefPm8vY2/skGBASoUqVKkqRKlSrZ/32zYxw8eFB33HGHJk6caG9gateurb179xo+xN6ofv36GjBggCSpTp06+umnn/TDDz/ojTfe0D///GP/YD506FBJUt26dVWkSBENHTpUnTt31t13362ZM2eqatWqmjx5sv33k6S33noroy/xv+revbsaNGhgv//nn3+qT58+ToulfX191atXLx0+fNg+dSM5OVnjxo1TnTp1JF37Fv2hhx7Shg0bVLFiRW3btk1lypTRc889Jy8vL9WsWVMFChSwL3CvVq2afW3F9eccPXq0fHx89MEHH9gfa9CggZo3b65Jkybpk08+sdfUtGlTw/z05ORkjRgxwv5NdGxsrN566y2NGzdOTz31lCQpKSlJvXv31vHjxxUWFqaYmBj5+/vr/ffftx+zTp06atSokRYsWGAfQ+naB/4mTZrc8rW8/m17mTJlMvDKX/vmfPPmzZo6daoee+wxSdc+MPv5+Wn69Onq0KGD7r77bklSamqq3n33XXuNly9f1oABA3Tw4EFFREQ4vaczM11u27Zt6tGjhxo1aiRJqlmzpooUKWL/G3CUnp6uKVOmqG7duk7vwerVq6tZs2Z699131b9/f0nX/rZfeeUVtWnTRpIUHR2t7777Tj/88IP9fXwr3t7eatSokdMUpeTkZK1fv14zZ850Sp4k6fTp0+rSpYteeeUV+7bSpUurdevW2rlzpx577LEM/83fuOC7UqVK6tWrl9566y01atRII0eOVIMGDfTss8+a/g4AQNMAuND27dsVHh7utM3Ly0v333+/Ro8e7fRtbYUKFZ
w+yGzZskUWi0X169d3ajwaNmyozz//XEePHlVoaKjTNA9JN20kbnWMsLAwffTRR0pPT9eJEyd08uRJxcbG6tixY07
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"conf_matrix = confusion_matrix(y_test, test_predictions)\n",
"plt.figure(figsize=(10, 7))\n",
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')\n",
"plt.xlabel('Predicted Label')\n",
"plt.ylabel('True Label')\n",
"plt.title('Pre-trained Transformer Confusion Matrix')\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:55:51.986789Z",
"end_time": "2024-06-09T12:55:52.299790Z"
}
}
},
{
"cell_type": "code",
"execution_count": 49,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.26 0.99 0.41 115\n",
" 1 0.14 0.01 0.03 72\n",
" 2 0.00 0.00 0.00 76\n",
" 3 0.00 0.00 0.00 102\n",
" 4 0.00 0.00 0.00 80\n",
"\n",
" accuracy 0.26 445\n",
" macro avg 0.08 0.20 0.09 445\n",
"weighted avg 0.09 0.26 0.11 445\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\adamw\\PycharmProjects\\pythonProject\\venv\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1327: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"C:\\Users\\adamw\\PycharmProjects\\pythonProject\\venv\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1327: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"C:\\Users\\adamw\\PycharmProjects\\pythonProject\\venv\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1327: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
]
}
],
"source": [
"print(classification_report(y_test, test_predictions))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:55:52.303793Z",
"end_time": "2024-06-09T12:55:52.408796Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Podsumowanie"
],
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing: 17%|█▋ | 59/356 [00:12<01:01, 4.80it/s]\n",
"Exception ignored in: <function WeakKeyDictionary.__init__.<locals>.remove at 0x000002CD5CAEE290>\n",
"Traceback (most recent call last):\n",
" File \"C:\\Users\\adamw\\AppData\\Local\\Programs\\Python\\Python310\\lib\\weakref.py\", line 371, in remove\n",
" self = selfref()\n",
"KeyboardInterrupt: \n",
"\n",
"KeyboardInterrupt\n",
"\n"
]
}
],
"execution_count": 36
},
{
"cell_type": "code",
"execution_count": 50,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAAIhCAYAAACfXCH+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABWa0lEQVR4nO3deVhU5f//8dcACigugEalRi6BiogIormkoSkuJa65hKkt9nHLzCJ31MxM2xR3w/1jH03TMjVF08rc18wlt1yiFNdUlPX8/vDr/Jo46oyB4/J8XBfXxdznPue85wwc5sW5zz0WwzAMAQAAAABsuDi7AAAAAAC4GxGWAAAAAMAEYQkAAAAATBCWAAAAAMAEYQkAAAAATBCWAAAAAMAEYQkAAAAATBCWAAAAAMAEYQkAAAAATBCWANwRMTExCgwMtPmqUKGC6tSpoyFDhujChQs5vs///ve/CgwM1JIlS3J827kpMjJS77zzTq5sOyYmRjExMQ6t88477ygyMvKGyzdu3KjAwEBt3Ljx35Z3W/U56uOPP1ZQUJA2bNjg0Hq5+brcyMKFCxUYGKgTJ05IuvVrYbbOnWJPbbfDWc/nZm7nZ/6f64wdO1aBgYG5VSKAHOLm7AIAPDjKly+vwYMHWx+np6frl19+0UcffaS9e/dq7ty5slgsObKvtLQ0TZo0Sf/5z3/UpEmTHNnmnRIfHy8vLy9nl3FfOn/+vGbPnq2BAweqWrVqDq17N7wuXbt2VYcOHZxaw43kVm116tTR//73Pz300EM5vm1natWqlWrVquXsMgDcAmEJwB3j5eWlSpUq2bRVqVJFly9f1pgxY7Rz585sy2+XYRiaNm2aSpYsmSPbu5PKly/v7BLuW3ny5NHChQvl7+/v8Lp3w+vy2GOPObuEG8qt2nx8fOTj45Mr23amhx9+WA8//LCzywBwCwzDA+B0FSpUkCQlJSVZ25YuXarmzZsrNDRUNWrU0KBBg2yG6o0dO1bPPPOM4uPjFRERoZo1a+rChQvKzMzUnDlz1LJlS0VHR+vpp5/W6NGjlZqaKkkaMWKEIiIilJWVZd1Wv379FBgYqGPHjlnbpk+frsqVKystLU3vvPOOOnbsqAULFqhBgwaqUKGCmjZtqu+//97meWzevFkvvfSSqlSpogoVKigyMlJjx4617uvEiRMKDAzUsmXL1LNnT4WGhioiIkIDBgxQSkqKdTv/HO6VmpqqDz74QLVr11aFChX07LPPaunSpbc8rklJSerevbvCwsJUo0YNTZs2LVuf68fr2WefVcWKFVWnTh2b42Vmz549Cg8P1yuvvKK0tDSbZRkZGZKkxMREtWvXTqGhoapQoYKioqI0Z84ch+uLjIzUmDFjNHLkSFWvXl0VK1bUSy+9pN9++82m37p169SuXTuFhYWpatWqevPNN/XHH39Yl2dlZenjjz/Ws88+q8aNGysyMlIffvih0tPTrX0uXbqkYcOGqVatWqpUqZJatGihNWvW2NRyo2F4M2fOVNmyZXXu3Dlr27hx4xQYGKj169db2xITE1W2bFmdPHnS7uP0d/8c6paVlaXx48erTp06CgkJUdeuXU2HtNqzn1OnTik2NlZPPvmkQkND9cILL2j79u2SpJ49e+qpp56y+b2RpP79+6tBgwamtdn72m3ZskUvvPCCQkJCFBERodjYWJ09e9a63J5heIGBgZo7d67eeecdhYWFKSIiQu+++66uXr2qkSNHqlq1aqpatar69+9v87OdmpqqcePGKSoqSsHBwapfv74mT56c7Xl+/vnnatCggSpWrKgXXnjB5lx1XVJSknr37q2IiAiFhIToxRdf1J49e25YM8PwgHsDYQmA0x05ckSSVKJECUnS+PHj1bt3b1WqVEljxoxRt27d9O233yomJkZXr161rpeUlKS1a9fq448/Vt++fVWoUCENGjRII0aMUL169TRhwgS1b99es2fPVteuXWUYhurUqa
MLFy5o9+7d1u1cv3dl8+bN1rYffvhBNWrUUN68eSVJu3fv1meffaaePXtq3LhxcnV1VY8ePaxvTPft26eOHTuqcOHC+vjjjzVhwgSFh4crPj5ey5Yts3m+gwcPVrFixTR+/Hi99NJL+uKLLzRhwgTTY2MYhrp166bPP/9cnTp10oQJExQaGqo33nhDixYtuuExTUlJ0QsvvKBff/1Vw4YN08CBAzV//nzrm9/rbnW8/unQoUN66aWXFBISonHjxlmPjyTNmTNHa9eu1Zo1a9StWzcFBQVp/PjxGjt2rEqUKKGhQ4dq586dDtUnXQsihw8f1ogRI/Tuu+9q9+7dio2NtS5ftGiROnfurEceeUQfffSR+vbtq+3bt+v555/XmTNnJElTpkzR3Llz1a1bNyUkJKht27b67LPPrMc9MzNTnTt31tdff60uXbpo/PjxKlWqlLp166YtW7bc8DhfV6dOHRmGYXMflNnP1ffff6/y5cvLz8/PruN0K6NGjdK4cePUsmVLxcfHq3Dhwvrwww9t+tizn8uXL6tt27bauHGj3nrrLcXHx8vd3V2dO3fWb7/9ppYtW+rkyZM29+hcvXpVy5cvV7NmzW5Y361eu82bN6tjx47y8PDQJ598on79+mnTpk3q0KGDze+6vccib968io+PV3R0tGbNmqXo6Gj98ccfGj16tGJiYvTFF19o1qxZkq79br322muaOnWqWrVqpYkTJyoqKkqffPKJzXDh2bNna/Dgwapdu7bGjx+vkJAQDRw40GbfZ8+eVZs2bfTLL79o4MCB+vDDD5WVlaX27dvr0KFDDj0PAHcZAwDugBdeeMFo3769kZ6ebv06ffq0sXTpUiMiIsJ4/vnnjaysLOP8+fNGhQoVjIEDB9qsv3nzZiMgIMCYPXu2YRiGMWbMGCMgIMDYvHmztc+BAweMgIAAY9KkSTbrLlq0yAgICDDWrFljpKamGqGhocbEiRMNwzCMo0ePGgEBAUazZs2M2NhYwzAM48qVK0ZwcLCxYMECwzAMIzY21ggICDCOHj1q3eamTZuMgIAAY/ny5YZhGMaXX35pvPzyy0ZmZqa1T2ZmphEWFmZ9LsePHzcCAgKMPn362NQXExNjNGnSxPr46aefttby448/GgEBAcY333xjs06fPn2MGjVqGOnp6abHe/bs2UZgYKBx4MABa1tSUpIRFBRkvPDCC3Yfr+vP/+mnnzaOHTtm1KxZ0+jYsaNx9epVa/8NGzYYAQEBRsuWLY2zZ88aU6ZMsdZ/3blz52z2ZU9914/F008/bWRkZFjbxo4dawQEBBhnz541MjMzjRo1ahidO3e22d/Ro0eNoKAgY+TIkYZhGEbnzp2NTp062fSZNWuWsWjRIsMwDGP16tVGQECAsXLlSuvyzMxM4/nnnzfGjh1rreWfz+vvGjRoYH2tU1JSjKCgIKNZs2Y2z6dOnTrGmDFjDMMw7DpOCxYsMAICAozjx48bhvH/XwvDMIwLFy4YQUFBxqhRo2y28dJLL9msY89+Zs2aZQQGBhp79uyx9klJSTHq169vzJs3z8jMzDSeeuop4+2337Yu/+qrr4yyZcsaf/zxR7barh+vm712hmEYzz//vNGkSRObPocPHzbKlStn/V3/5zEwExAQYLRq1cr6OCMjw6hUqZIRGRlp8zvSpEkT4z//+Y9hGIaxZs0aIyAgwFiyZInNtsaNG2cEBAQYv/76q5GVlWU8+eSTRq9evWz6DBo0yAgICDA2bNhgGIZhfPTRR0ZwcLBx4sQJa5/U1FSjbt26Ro8ePQzD+P+/J9fXuX4OA3B3454lAHfM5s2bFRQUZNPm4uKi6tWra+jQobJYLNqxY4fS0tKyTcoQHh6uYsWKadOmTWrfvr21vVy5ctbvN23aJElq3LixzbqNGzdW3759tXHjRtWuXVs1atTQTz/9pC5dumj9+vUqWbKk6tevr3nz5km6NmtVenq6ateubd2Gj4+PzT0Z1+81uHLlii
QpOjpa0dHRSk1N1ZEjR3T06FHt3btXmZmZNkO9JGW7L+vhhx/W77//bnrM1q9fL4vFotq1a1uHuEnXhjh99dVXOnD
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "<Figure size 1000x600 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAAIhCAYAAACfXCH+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABR10lEQVR4nO3deVgV5f//8ReLgFsqaGSa5hKobCIoqZSI+5Jrmkuamqm55UfrY5r7bpp9AtwL07Qsc8td0bRyF5cyw9xyyVJccENBYH5/+PX8OjEWR4GD+nxcl9flueeemfcZzhnOi7nnPg6GYRgCAAAAAFhxtHcBAAAAAJATEZYAAAAAwARhCQAAAABMEJYAAAAAwARhCQAAAABMEJYAAAAAwARhCQAAAABMEJYAAAAAwARhCQCAHCCrviOe754HgPtHWAKQo3Xo0EHe3t5W/3x9fRUWFqaRI0fqypUrmb7Pzz//XN7e3lq5cmWmbzsrhYeH6913382SbXfo0EEdOnSwaZ13331X4eHh91y+c+dOeXt7a+fOnQ9a3n3VZ6sPP/xQPj4+2rFjh03r/dvP5c8//1S3bt30+++/P2iJVq5evar//ve/2rNnT6Zu935k5s86s5w5c0be3t5asmTJfa+zZMkSeXt768yZM1lVJgA7c7Z3AQDwbypUqKDhw4dbHt++fVs///yzpkyZol9++UVffPGFHBwcMmVfycnJmjlzpt588001btw4U7aZXaKiopQvXz57l/FISkhI0Pz58zV06FA9//zzNq37bz+Xbdu2acuWLQ9aYjq//PKLli9frpYtW2b6tm3l4+OjL7/8UmXLlrV3KZkqLCxMX375pZ588kl7lwIgixCWAOR4+fLlU8WKFa3aKleurBs3bigiIkIHDhxIt/x+GYahOXPmqFSpUpmyvexUoUIFe5fwyMqVK5eWLFmikiVL2rwuPxfz9/CjwN3dXe7u7vYuA0AWYhgegIeWr6+vJOns2bOWttWrV6tFixYKDAxU9erVNWzYMKuhepGRkapTp46ioqJUpUoVhYaG6sqVK0pNTdWCBQv08ssvq1mzZqpZs6YmT56spKQkSdL48eNVpUoVpaWlWbY1ePBgeXt769SpU5a2Tz/9VJUqVVJycrLeffddderUSYsXL1a9evXk6+urpk2b6rvvvrN6Hrt379brr7+uypUry9fXV+Hh4YqMjLTs6+7QnzVr1qhv374KDAxUlSpVNGTIECUmJlq28/fhXklJSXr//fdVo0YN+fr66qWXXtLq1av/9biePXtWvXv3VlBQkKpXr645c+ak63P3eL300kvy9/dXWFiY1fEyc+jQIQUHB+uNN95QcnKy1bKUlBRJUkxMjNq1a6fAwED5+vqqfv36WrBggc31hYeHKyIiQhMnTlS1atXk7++v119/Xb/99ptVv61bt6pdu3YKCgpSSEiIBgwYoD/++MOyPC0tTR9++KFeeuklNWrUSOHh4frggw90+/ZtS5/r169r9OjReuGFF1SxYkW1bNlSmzdvtqrlXsPwlixZokGDBkmSatWqZdVv0aJFatSokWXYaWRkpFJTUy3LL126pAEDBqh69ery8/NT06ZNtWzZMkl3hr117NhRktSxY0erIYr/9B6ZN2+eypUrp8uXL1v6T506Vd7e3tq+fbulLSYmRuXKldO5c+ckSb/++qu6d++uSpUqqVKlSurVq5dOnz5t6Z+RYXjh4eGKiorSuHHjFBISosDAQA0YMEA3btzQrFmz9OKLLyooKEh9+vSxqi+jr8X169erSZMm8vf3V/PmzRUXF5euhoSEBA0bNkzVqlWTn5+fWrdubfW8/45heMCjj7AE4KF14sQJSdIzzzwjSZo2bZr69++vihUrKiIiQr169dK6devUoUMH3bp1y7Le2bNntWXLFn344YcaNGiQChQooGHDhmn8+PGqXbu2pk+frvbt22v+/Pnq2bOnDMNQWFiYrly5ooMHD1
q2c/feld27d1vavv/+e1WvXl0uLi6SpIMHD+qTTz5R3759NXXqVDk5OalPnz6WD6dxcXHq1KmTChYsqA8//FDTp09XcHCwoqKitGbNGqvnO3z4cBUrVkzTpk3T66+/rq+//lrTp083PTaGYahXr15auHChOnfurOnTpyswMFD/+c9/LB+ozSQmJurVV1/Vr7/+qtGjR2vo0KFatGiR9u3bZ9Xv347X3x07dkyvv/66AgICNHXqVMvxkaQFCxZoy5Yt2rx5s3r16iUfHx9NmzZNkZGReuaZZzRq1CgdOHDApvqkOx/8jx8/rvHjx2vMmDE6ePCgBg4caFm+bNkydenSRUWLFtWUKVM0aNAg7du3T6+88oouXrwoSZo9e7a++OIL9erVS9HR0Wrbtq0++eQTy3FPTU1Vly5dtGLFCnXv3l3Tpk1T6dKl1atXrwzdKxQWFqY333xT0p3hej179pQkzZw5U0OHDlXVqlU1Y8YMtW/fXrNnz9bQoUMt677zzjs6duyYRo4cqdmzZ6tChQoaOHCgduzYIR8fHw0bNszys7o7jPXf3iNhYWEyDMPqviyz1/l3332nChUqyNPTUydOnFCbNm108eJFTZw4UWPHjtXp06fVtm1by3HMqOjoaP3xxx/68MMP9eabb2rlypVq2bKlfvjhB40ePVr9+/fXxo0bFRERYVknI6/FTZs2qW/fvvL29tbUqVPVoEEDvfPOO1b7TkpK0muvvaaNGzfqP//5j6KiovTUU0+pa9eu/xiYADziDADIwV599VWjffv2xu3bty3/Lly4YKxevdqoUqWK8corrxhpaWlGQkKC4evrawwdOtRq/d27dxteXl7G/PnzDcMwjIiICMPLy8vYvXu3pc+RI0cMLy8vY+bMmVbrLlu2zPDy8jI2b95sJCUlGYGBgcaMGTMMwzCMkydPGl5eXkbz5s2NgQMHGoZhGDdv3jT8/PyMxYsXG4ZhGAMHDjS8vLyMkydPWra5a9cuw8vLy1i7dq1hGIaxdOlSo2vXrkZqaqqlT2pqqhEUFGR5LqdPnza8vLyMt99+26q+Dh06GI0bN7Y8rlmzpqWWH374wfDy8jJWrVpltc7bb79tVK9e3bh9+7bp8Z4/f77h7e1tHDlyxNJ29uxZw8fHx3j11VczfLzuPv+aNWsap06dMkJDQ41OnToZt27dsvTfsWOH4eXlZbz88svGpUuXjNmzZ1vqv+vy5ctW+8pIfXePRc2aNY2UlBRLW2RkpOHl5WVcunTJSE1NNapXr2506dLFan8nT540fHx8jIkTJxqGYRhdunQxOnfubNXns88+M5YtW2YYhmFs2rTJ8PLyMjZs2GBZnpqaarzyyitGZGSkpZa/P6+/Wrx4seHl5WWcPn3aMAzDuHr1quHv728MGzbMqt9XX31leHl5Gb/++qthGIbh6+trTJ8+3Wq/EyZMMGJjY62O744dOwzDMDL8HqlXr56lT2JiouHj42M0b97c6viGhYUZERERhmEYRv/+/Y1q1aoZ165dsyy/fPmyERQUZEyYMMG0FjM1a9Y0XnjhBavXZv369Y3AwEDj6tWrlrbu3bsbTZo0MQwj46/FFi1aGK1atbLqM3PmTMPLy8vyfv3yyy8NLy8vY//+/ZY+aWlpRvv27Y0WLVoYhvH/34t31/n7zw7Ao4crSwByvN27d8vHx8fyr1q1aurfv798fX31wQcfyMHBQfv371dycnK6SRmCg4NVrFgx7dq1y6q9fPnylv/fXdaoUSOrPo0aNZKTk5N27twpFxcXVa9eXdu2bZMkbd++XaVKlVLdunUt6+/cuVO3b99WjRo1LNtwd3dXiRIlLI+feuopSdLNmzclSc2aNdPs2bN1+/ZtxcXFad26dYqIiFBqaqrVUC9J6e75eOqpp6yG4f3V9u3b5eDgoBo1aiglJcXyLzw8XPHx8Tpy5Ijpenv27FGJEiWsbsQvWrSo1b4zcrzuunHjhjp16qT4+HiNHDlSrq
6ukqQrV64oKipKkvT222+rUKFC6tq1qyZMmKAbN27o4MGDWr16tWbOnClJlmF7GanvLj8/Pzk5OVkdL+nOsT9x4oT
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"results = pd.DataFrame({\n",
" 'Model': ['Naive Bayes', 'LSTM', 'Pre-trained Transformer'],\n",
" 'Validation Accuracy': [nb_accuracy, val_accuracy, trans_val_accuracy],\n",
" 'Test Accuracy': [nb_accuracy, test_accuracy, trans_test_accuracy]\n",
"})\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x='Model', y='Validation Accuracy', data=results)\n",
"plt.title('Porównanie dokładności walidacyjnej modeli')\n",
"plt.ylabel('Validation Accuracy')\n",
"plt.show()\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x='Model', y='Test Accuracy', data=results)\n",
"plt.title('Porównanie dokładności testowej modeli')\n",
"plt.ylabel('Test Accuracy')\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:55:52.321794Z",
"end_time": "2024-06-09T12:55:52.681835Z"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}