DL_Projekt/6_Projekt.ipynb
2024-06-09 16:41:14 +02:00

1576 lines
58 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "5K66MlDpZDnE"
},
"source": [
"## Analiza sentymentu w opiniach z Twitter'a\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OaqaYfQFZDnH"
},
"source": [
"### Download dataset and prepare data\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "23k83t7RNCJa"
},
"source": [
"#### Installation of packages\n"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I5pSpk6PNCJb",
"outputId": "3f30ecd9-104a-496a-fd52-447f4d64e814"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.4)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2024.1)\n",
"Requirement already satisfied: numpy>=1.21.0 in /usr/local/lib/python3.10/dist-packages (from pandas) (1.25.2)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.2.2)\n",
"Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.25.2)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.11.4)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.4.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.5.0)\n",
"Requirement already satisfied: emoji in /usr/local/lib/python3.10/dist-packages (2.12.1)\n",
"Requirement already satisfied: typing-extensions>=4.7.0 in /usr/local/lib/python3.10/dist-packages (from emoji) (4.12.1)\n",
"Requirement already satisfied: gensim in /usr/local/lib/python3.10/dist-packages (4.3.2)\n",
"Requirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.10/dist-packages (from gensim) (1.25.2)\n",
"Requirement already satisfied: scipy>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from gensim) (1.11.4)\n",
"Requirement already satisfied: smart-open>=1.8.1 in /usr/local/lib/python3.10/dist-packages (from gensim) (6.4.0)\n"
]
}
],
"source": [
"%pip install pandas\n",
"%pip install scikit-learn\n",
"%pip install emoji\n",
"%pip install gensim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FA_aZGAkNCJd"
},
"source": [
"#### Importing libraries\n"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {
"id": "yQvOCaX2NCJd"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"import emoji\n",
"from gensim.utils import simple_preprocess"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gp8ITdbPNCJe"
},
"source": [
"#### Download the dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DlcNiu4UNCJe",
"outputId": "015b3ad1-6b9d-4845-dd98-0b0c085b12c9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Dataset URL: https://www.kaggle.com/datasets/jp797498e/twitter-entity-sentiment-analysis\n",
"License(s): CC0-1.0\n",
"twitter-entity-sentiment-analysis.zip: Skipping, found more recently modified local copy (use --force to force download)\n"
]
}
],
"source": [
"!kaggle datasets download -d jp797498e/twitter-entity-sentiment-analysis"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yU4XFDrUNCJf"
},
"source": [
"#### Unzip the dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G2gaml-MNCJf",
"outputId": "e327c071-a0cd-480f-92d3-66388fd4dfcb"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Archive: twitter-entity-sentiment-analysis.zip\n",
" inflating: twitter_training.csv \n",
" inflating: twitter_validation.csv \n"
]
}
],
"source": [
"!unzip -o twitter-entity-sentiment-analysis.zip"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBO6YHwyNCJg"
},
"source": [
"#### Load the dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"id": "9KlnXJTtNCJg"
},
"outputs": [],
"source": [
"cols = [\"tweetid\", \"entity\", \"sentiment\", \"content\"]\n",
"twitter_training = pd.read_csv(\"twitter_training.csv\", names=cols)\n",
"twitter_validation = pd.read_csv(\"twitter_validation.csv\", names=cols)\n",
"dataset = pd.concat([twitter_training, twitter_validation])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XIslo9YQNCJg"
},
"source": [
"#### Info about the dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rnh5-0SZNCJh",
"outputId": "99319b5c-f4e2-4aee-e963-13e8d2e938ee"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 75682 entries, 0 to 999\n",
"Data columns (total 4 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 tweetid 75682 non-null int64 \n",
" 1 entity 75682 non-null object\n",
" 2 sentiment 75682 non-null object\n",
" 3 content 74996 non-null object\n",
"dtypes: int64(1), object(3)\n",
"memory usage: 2.9+ MB\n"
]
}
],
"source": [
"dataset.info()"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rpHNMU57NCJh",
"outputId": "576fba81-c5fc-47ee-aae9-4f1734081e97"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(75682, 4)"
]
},
"metadata": {},
"execution_count": 104
}
],
"source": [
"dataset.shape"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eylMuu0GNCJj",
"outputId": "d04a8e0a-42ac-4f70-f277-5b9300e97016"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"sentiment\n",
"Negative 22808\n",
"Positive 21109\n",
"Neutral 18603\n",
"Irrelevant 13162\n",
"Name: count, dtype: int64"
]
},
"metadata": {},
"execution_count": 105
}
],
"source": [
"dataset[\"sentiment\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fm7H57JINCJj",
"outputId": "6af989a7-c3e7-4666-afef-c2859265d027"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tweetid 0\n",
"entity 0\n",
"sentiment 0\n",
"content 686\n",
"dtype: int64"
]
},
"metadata": {},
"execution_count": 106
}
],
"source": [
"dataset.isna().sum()"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AF_ZNH6pNCJk",
"outputId": "f3191e1e-1176-4c08-9c3e-31eee38020d8"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3217"
]
},
"metadata": {},
"execution_count": 107
}
],
"source": [
"dataset.duplicated().sum()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LeIs8ceHNCJl"
},
"source": [
"#### Prepare the dataset\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GJfxQkbWNCJl"
},
"source": [
"##### Drop tweetid and entity columns\n"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {
"id": "X3GwAqSQNCJl"
},
"outputs": [],
"source": [
"dataset = dataset.drop(columns=[\"tweetid\", \"entity\"], axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WpBDzbx6NCJm"
},
"source": [
"##### Drop null values\n"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {
"id": "ixlmP6cwNCJm"
},
"outputs": [],
"source": [
"dataset.dropna(inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z_0UNES2NCJm"
},
"source": [
"##### Remove emojis\n"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {
"id": "I9Mr8rAQNCJm"
},
"outputs": [],
"source": [
"dataset[\"content\"] = dataset[\"content\"].apply(\n",
" lambda x: emoji.replace_emoji(x, replace=\"\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wLptSqrzNCJm"
},
"source": [
"##### Simple Preprocess\n"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {
"id": "gw8HC9XBNCJm"
},
"outputs": [],
"source": [
"dataset[\"content\"] = dataset[\"content\"].apply(lambda x: \" \".join(simple_preprocess(x)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vayfwdnkNCJm"
},
"source": [
"##### Drop null values\n"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {
"id": "6r1_Hk1JNCJn"
},
"outputs": [],
"source": [
"dataset.dropna(inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k2aDiHxrNCJn"
},
"source": [
"##### Drop duplicates\n"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {
"id": "56YaoLvjNCJn"
},
"outputs": [],
"source": [
"dataset.drop_duplicates(inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "exgWXEmNNCJn"
},
"source": [
"#### Info about the dataset after cleaning\n"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1oaFMRqANCJn",
"outputId": "05560ac6-9dd5-4397-8730-59239af28fc6"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 65839 entries, 0 to 991\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 sentiment 65839 non-null object\n",
" 1 content 65839 non-null object\n",
"dtypes: object(2)\n",
"memory usage: 1.5+ MB\n"
]
}
],
"source": [
"dataset.info()"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "92f8IAAINCJo",
"outputId": "383826e9-6f8b-4e66-c7f9-2efacb8a5c96"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(65839, 2)"
]
},
"metadata": {},
"execution_count": 115
}
],
"source": [
"dataset.shape"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "F7a05XcCNCJo",
"outputId": "4b189aa3-8df7-44be-aa20-9066d4cde04a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"sentiment\n",
"Negative 20147\n",
"Positive 17868\n",
"Neutral 16193\n",
"Irrelevant 11631\n",
"Name: count, dtype: int64"
]
},
"metadata": {},
"execution_count": 116
}
],
"source": [
"dataset[\"sentiment\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GG3Qgk44NCJo",
"outputId": "4959a695-513d-47e0-cbe9-b05d149478cf"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"sentiment 0\n",
"content 0\n",
"dtype: int64"
]
},
"metadata": {},
"execution_count": 117
}
],
"source": [
"dataset.isna().sum()"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u5g9cVa1NCJo",
"outputId": "214e37d0-71b0-4616-fb37-2b47e365ee14"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0"
]
},
"metadata": {},
"execution_count": 118
}
],
"source": [
"dataset.duplicated().sum()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eVZuSi-SNCJo"
},
"source": [
"#### Split the dataset into training and testing sets\n"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {
"id": "BCy_q1GHNCJp"
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(\n",
" dataset[\"content\"], dataset[\"sentiment\"], test_size=0.2, random_state=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-qSKofcjNCJt",
"outputId": "d2e09e82-4174-4e87-a0be-5e9158668733"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((52671,), (13168,), (52671,), (13168,))"
]
},
"metadata": {},
"execution_count": 120
}
],
"source": [
"X_train.shape, X_test.shape, y_train.shape, y_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UmMEl5AYNCJt"
},
"source": [
"### TD-IDF - Logistic Regression\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2_scDhXqZDnJ"
},
"source": [
"#### Importing libraries\n"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {
"id": "ugm_fVSiZDnK"
},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import classification_report"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X8DY5eStNCJu"
},
"source": [
"#### Text Vectorization Using TF-IDF\n"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {
"id": "IBAy8zjcNCJu"
},
"outputs": [],
"source": [
"vectorizer = TfidfVectorizer()\n",
"X_train_tfidf = vectorizer.fit_transform(X_train)\n",
"X_test_tfidf = vectorizer.transform(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rS5pptZINCJu"
},
"source": [
"#### Training a Logistic Regression model\n"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "m3tmiTWVNCJu",
"outputId": "e9372b78-9ea9-4a4a-8289-f3e9ee6ba511"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LogisticRegression(max_iter=1000)"
],
"text/html": [
"<style>#sk-container-id-3 {color: black;background-color: white;}#sk-container-id-3 pre{padding: 0;}#sk-container-id-3 div.sk-toggleable {background-color: white;}#sk-container-id-3 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-3 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-3 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-3 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-3 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-3 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-3 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-3 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-3 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-3 div.sk-item {position: relative;z-index: 1;}#sk-container-id-3 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-3 div.sk-item::before, #sk-container-id-3 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-3 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-3 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-3 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-3 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-3 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-3 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-3 div.sk-label-container {text-align: center;}#sk-container-id-3 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-3 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(max_iter=1000)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(max_iter=1000)</pre></div></div></div></div></div>"
]
},
"metadata": {},
"execution_count": 123
}
],
"source": [
"model = LogisticRegression(solver=\"lbfgs\", penalty=\"l2\", max_iter=1000)\n",
"model.fit(X_train_tfidf, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GiY_P6PKNCJu"
},
"source": [
"#### Predicting\n"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {
"id": "CJ_9qh6ONCJu"
},
"outputs": [],
"source": [
"y_pred = model.predict(X_test_tfidf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "koeb78PsNCJu"
},
"source": [
"#### Classification report\n"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hsABx8mJNCJv",
"outputId": "c4c23ca6-c88a-4db9-fe66-36a7febd3594"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" precision recall f1-score support\n",
"\n",
" Irrelevant 0.82 0.70 0.75 2304\n",
" Negative 0.80 0.86 0.83 4024\n",
" Neutral 0.79 0.74 0.77 3169\n",
" Positive 0.78 0.82 0.80 3671\n",
"\n",
" accuracy 0.79 13168\n",
" macro avg 0.80 0.78 0.79 13168\n",
"weighted avg 0.79 0.79 0.79 13168\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y17ccTy1NCJv"
},
"source": [
"### TD-IDF - Random Forest Classifier\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yyk_baF-NCJv"
},
"source": [
"#### Importing libraries\n"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {
"id": "-xjXLHpQNCJv"
},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import classification_report"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tl6mOx92NCJw"
},
"source": [
"#### Text Vectorization Using TF-IDF\n"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {
"id": "bE9h15BcNCJw"
},
"outputs": [],
"source": [
"vectorizer = TfidfVectorizer()\n",
"X_train_tfidf = vectorizer.fit_transform(X_train)\n",
"X_test_tfidf = vectorizer.transform(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cNUGJWXINCJw"
},
"source": [
"#### Training a Random Forest Classifier model\n"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "WTrPtycbNCJw",
"outputId": "e97b690c-f698-414a-cc40-4843d12e2073"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomForestClassifier()"
],
"text/html": [
"<style>#sk-container-id-4 {color: black;background-color: white;}#sk-container-id-4 pre{padding: 0;}#sk-container-id-4 div.sk-toggleable {background-color: white;}#sk-container-id-4 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-4 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-4 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-4 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-4 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-4 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-4 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-4 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-4 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-4 div.sk-item {position: relative;z-index: 1;}#sk-container-id-4 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-4 div.sk-item::before, #sk-container-id-4 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-4 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-4 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-4 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-4 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-4 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-4 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-4 div.sk-label-container {text-align: center;}#sk-container-id-4 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-4 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-4\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" checked><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier()</pre></div></div></div></div></div>"
]
},
"metadata": {},
"execution_count": 128
}
],
"source": [
"model = RandomForestClassifier(criterion=\"gini\")\n",
"model.fit(X_train_tfidf, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPlAbp8PNCJx"
},
"source": [
"#### Predicting\n"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {
"id": "0ePAr1uZNCJx"
},
"outputs": [],
"source": [
"y_pred = model.predict(X_test_tfidf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oPnSaSB-NCJx"
},
"source": [
"#### Classification report\n"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gqRJRLcKNCJx",
"outputId": "b4b1bbfb-5b76-4936-cb74-e200dc72e1c6"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" precision recall f1-score support\n",
"\n",
" Irrelevant 0.95 0.87 0.91 2304\n",
" Negative 0.92 0.95 0.93 4024\n",
" Neutral 0.94 0.91 0.93 3169\n",
" Positive 0.90 0.94 0.92 3671\n",
"\n",
" accuracy 0.93 13168\n",
" macro avg 0.93 0.92 0.92 13168\n",
"weighted avg 0.93 0.93 0.92 13168\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18jz3yhuNCJy"
},
"source": [
"### Word2Vec - LSTM\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0gizZeVCNCJy"
},
"source": [
"#### Installation of packages\n"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sy0x-OwPNCJy",
"outputId": "9815f9df-920a-48c0-f8c3-c174b2544ee4"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.15.0)\n",
"Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n",
"Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n",
"Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25)\n",
"Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.5.4)\n",
"Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n",
"Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.9.0)\n",
"Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1)\n",
"Requirement already satisfied: ml-dtypes~=0.2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n",
"Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.25.2)\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.3.0)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.0)\n",
"Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (67.7.2)\n",
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n",
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.4.0)\n",
"Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.1)\n",
"Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.14.1)\n",
"Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.0)\n",
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.64.1)\n",
"Requirement already satisfied: tensorboard<2.16,>=2.15 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.2)\n",
"Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.0)\n",
"Requirement already satisfied: keras<2.16,>=2.15.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.0)\n",
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.43.0)\n",
"Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (2.27.0)\n",
"Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (1.2.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (3.6)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (2.31.0)\n",
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (0.7.2)\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (3.0.3)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (5.3.3)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (0.4.0)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (4.9)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow) (1.3.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (2024.6.2)\n",
"Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow) (2.1.5)\n",
"Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (0.6.0)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow) (3.2.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.25.2)\n"
]
}
],
"source": [
"%pip install tensorflow\n",
"%pip install numpy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nyh33SHPNCJy"
},
"source": [
"#### Importing libraries\n"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {
"id": "WGINcl6pNCJy"
},
"outputs": [],
"source": [
"from gensim.models import Word2Vec\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from sklearn.calibration import LabelEncoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JrQ66Il0NCJy"
},
"source": [
"#### Function to convert text to Word2Vec vectors\n"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {
"id": "3MEhNRL0NCJz"
},
"outputs": [],
"source": [
"def text_to_vector(text, word2vec, vector_size):\n",
" words = simple_preprocess(text)\n",
" text_vector = np.zeros(vector_size)\n",
" word_count = 0\n",
" for word in words:\n",
" if word in word2vec.wv:\n",
" text_vector += word2vec.wv[word]\n",
" word_count += 1\n",
" if word_count > 0:\n",
" text_vector /= word_count\n",
" return text_vector"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JH_t6_ZtNCJz"
},
"source": [
"#### Tokenize texts\n"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {
"id": "KVzBCEbWNCJz"
},
"outputs": [],
"source": [
"tokenized_text = dataset[\"content\"].apply(lambda x: x.split())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o_fLSc_uNCJz"
},
"source": [
"#### Vector size parameter\n"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {
"id": "sLY4J1nTNCJ0"
},
"outputs": [],
"source": [
"vector_size = 100"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XIMHHrRqNCJ0"
},
"source": [
"#### Train Word2Vec model\n"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {
"id": "UysosPtiNCJ1"
},
"outputs": [],
"source": [
"model_word2vec = Word2Vec(\n",
" tokenized_text, window=5, min_count=2, workers=4, vector_size=vector_size, epochs=20\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xufmoLDlNCJ1"
},
"source": [
"#### Convert texts to Word2Vec vectors\n"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {
"id": "QrY3vXcXNCJ1"
},
"outputs": [],
"source": [
"train_vectors = np.array(\n",
" [text_to_vector(text, model_word2vec, vector_size) for text in X_train]\n",
")\n",
"\n",
"test_vectors = np.array(\n",
" [text_to_vector(text, model_word2vec, vector_size) for text in X_test]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-3hER130NCJ2"
},
"source": [
"#### Find the maximum sequence length in the training set\n"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {
"id": "gcxbUr4lNCJ2"
},
"outputs": [],
"source": [
"max_len = max(len(seq) for seq in train_vectors)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ezJ1OadFNCJ3"
},
"source": [
"#### Pad sequences to the same length\n"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {
"id": "1oGWbRZtNCJ3"
},
"outputs": [],
"source": [
"X_train_emb = tf.keras.preprocessing.sequence.pad_sequences(\n",
" train_vectors, maxlen=max_len, dtype=\"float32\", padding=\"post\"\n",
")\n",
"X_test_emb = tf.keras.preprocessing.sequence.pad_sequences(\n",
" test_vectors, maxlen=max_len, dtype=\"float32\", padding=\"post\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LaBGGtm3NCJ4"
},
"source": [
"#### Encode labels\n"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {
"id": "Xe96PgKtNCJ4"
},
"outputs": [],
"source": [
"label_encoder = LabelEncoder()\n",
"y_train_enc = label_encoder.fit_transform(y_train)\n",
"y_test_enc = label_encoder.transform(y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P_fI-cZHNCJ4"
},
"source": [
"#### Define LSTM model\n"
]
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "sEUnVQJEP-hy"
}
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {
"id": "pF5HZSRKNCJ4"
},
"outputs": [],
"source": [
"model = tf.keras.Sequential(\n",
" [\n",
" tf.keras.layers.Embedding(input_dim=X_train_emb.shape[1], output_dim=100),\n",
" tf.keras.layers.LSTM(128),\n",
" tf.keras.layers.Dense(64, activation=\"relu\"),\n",
" tf.keras.layers.Dense(32, activation=\"relu\"),\n",
" tf.keras.layers.Dense(4, activation=\"softmax\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YNk8m5lnNCJ4"
},
"source": [
"#### Compile the model\n"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {
"id": "IqDLE1FuNCJ5"
},
"outputs": [],
"source": [
"model.compile(\n",
" optimizer=tf.optimizers.Adam(learning_rate=1e-3),\n",
" loss=\"sparse_categorical_crossentropy\",\n",
" metrics=[\"accuracy\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cIRnkLT0NCJ5"
},
"source": [
"#### Train the model\n"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QFpStkX9NCJ5",
"outputId": "af188e61-04a4-45e7-e6f3-ef50cba478da"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/50\n",
"823/823 [==============================] - 10s 9ms/step - loss: 1.3439 - accuracy: 0.3438\n",
"Epoch 2/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.3261 - accuracy: 0.3678\n",
"Epoch 3/50\n",
"823/823 [==============================] - 6s 8ms/step - loss: 1.3163 - accuracy: 0.3774\n",
"Epoch 4/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.3020 - accuracy: 0.3975\n",
"Epoch 5/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2904 - accuracy: 0.4119\n",
"Epoch 6/50\n",
"823/823 [==============================] - 8s 9ms/step - loss: 1.2814 - accuracy: 0.4186\n",
"Epoch 7/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2741 - accuracy: 0.4262\n",
"Epoch 8/50\n",
"823/823 [==============================] - 8s 9ms/step - loss: 1.2667 - accuracy: 0.4325\n",
"Epoch 9/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2588 - accuracy: 0.4372\n",
"Epoch 10/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.2513 - accuracy: 0.4407\n",
"Epoch 11/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2451 - accuracy: 0.4450\n",
"Epoch 12/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 1.2365 - accuracy: 0.4491\n",
"Epoch 13/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2291 - accuracy: 0.4560\n",
"Epoch 14/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.2218 - accuracy: 0.4593\n",
"Epoch 15/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.2144 - accuracy: 0.4636\n",
"Epoch 16/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.2066 - accuracy: 0.4669\n",
"Epoch 17/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.1989 - accuracy: 0.4707\n",
"Epoch 18/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.1887 - accuracy: 0.4759\n",
"Epoch 19/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.1810 - accuracy: 0.4803\n",
"Epoch 20/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.1717 - accuracy: 0.4846\n",
"Epoch 21/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.1631 - accuracy: 0.4883\n",
"Epoch 22/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 1.1533 - accuracy: 0.4948\n",
"Epoch 23/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.1426 - accuracy: 0.4983\n",
"Epoch 24/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.1338 - accuracy: 0.5040\n",
"Epoch 25/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.1229 - accuracy: 0.5075\n",
"Epoch 26/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 1.1126 - accuracy: 0.5125\n",
"Epoch 27/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.1042 - accuracy: 0.5167\n",
"Epoch 28/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 1.0920 - accuracy: 0.5237\n",
"Epoch 29/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.0809 - accuracy: 0.5266\n",
"Epoch 30/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 1.0730 - accuracy: 0.5307\n",
"Epoch 31/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.0628 - accuracy: 0.5357\n",
"Epoch 32/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.0536 - accuracy: 0.5422\n",
"Epoch 33/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.0399 - accuracy: 0.5480\n",
"Epoch 34/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.0350 - accuracy: 0.5503\n",
"Epoch 35/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.0237 - accuracy: 0.5553\n",
"Epoch 36/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 1.0217 - accuracy: 0.5550\n",
"Epoch 37/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 1.0073 - accuracy: 0.5633\n",
"Epoch 38/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 0.9927 - accuracy: 0.5703\n",
"Epoch 39/50\n",
"823/823 [==============================] - 6s 8ms/step - loss: 0.9848 - accuracy: 0.5732\n",
"Epoch 40/50\n",
"823/823 [==============================] - 6s 8ms/step - loss: 0.9786 - accuracy: 0.5748\n",
"Epoch 41/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 0.9735 - accuracy: 0.5774\n",
"Epoch 42/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 0.9633 - accuracy: 0.5839\n",
"Epoch 43/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 0.9530 - accuracy: 0.5873\n",
"Epoch 44/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 0.9506 - accuracy: 0.5893\n",
"Epoch 45/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 0.9364 - accuracy: 0.5958\n",
"Epoch 46/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 0.9260 - accuracy: 0.6006\n",
"Epoch 47/50\n",
"823/823 [==============================] - 7s 9ms/step - loss: 0.9257 - accuracy: 0.6008\n",
"Epoch 48/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 0.9155 - accuracy: 0.6048\n",
"Epoch 49/50\n",
"823/823 [==============================] - 7s 8ms/step - loss: 0.9103 - accuracy: 0.6066\n",
"Epoch 50/50\n",
"823/823 [==============================] - 6s 7ms/step - loss: 0.8999 - accuracy: 0.6122\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x790a27fa2b60>"
]
},
"metadata": {},
"execution_count": 155
}
],
"source": [
"model.fit(X_train_emb, y_train_enc, epochs=50, batch_size=64)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CAoGoGZ7NCJ5"
},
"source": [
"#### Predicting\n"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LCtJlNP9NCJ5",
"outputId": "15942b31-270a-4817-9b55-f5a48663dbed"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"412/412 [==============================] - 2s 4ms/step\n"
]
}
],
"source": [
"y_pred = model.predict(X_test_emb)\n",
"\n",
"y_preds_argmax = []\n",
"for i in range(len(y_pred)):\n",
" y_preds_argmax.append(y_pred[i].argmax())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ee3GIUHJNCJ6"
},
"source": [
"#### Classification report\n"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MMCmZDLgNCJ6",
"outputId": "67ad7bb2-386c-432f-dd95-0c3105a13f0c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.32 0.20 0.25 2304\n",
" 1 0.46 0.62 0.53 4024\n",
" 2 0.44 0.43 0.44 3169\n",
" 3 0.45 0.39 0.42 3671\n",
"\n",
" accuracy 0.44 13168\n",
" macro avg 0.42 0.41 0.41 13168\n",
"weighted avg 0.43 0.44 0.42 13168\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_test_enc, y_preds_argmax))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}