539 lines
27 KiB
Plaintext
539 lines
27 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
|
|
"import torch\n",
|
|
"from transformers import TrainingArguments, Trainer\n",
|
|
"from transformers import BertTokenizer, BertForSequenceClassification\n",
|
|
"from transformers import EarlyStoppingCallback\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>id</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>topic</th>\n",
|
|
" <th>sign</th>\n",
|
|
" <th>date</th>\n",
|
|
" <th>text</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>14,May,2004</td>\n",
|
|
" <td>Info has been found (+/- 100 pages,...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>13,May,2004</td>\n",
|
|
" <td>These are the team members: Drewe...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>12,May,2004</td>\n",
|
|
" <td>In het kader van kernfusie op aarde...</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>12,May,2004</td>\n",
|
|
" <td>testing!!! testing!!!</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>3581210</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>33</td>\n",
|
|
" <td>InvestmentBanking</td>\n",
|
|
" <td>Aquarius</td>\n",
|
|
" <td>11,June,2004</td>\n",
|
|
" <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" id gender age topic sign date \\\n",
|
|
"0 2059027 male 15 Student Leo 14,May,2004 \n",
|
|
"1 2059027 male 15 Student Leo 13,May,2004 \n",
|
|
"2 2059027 male 15 Student Leo 12,May,2004 \n",
|
|
"3 2059027 male 15 Student Leo 12,May,2004 \n",
|
|
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
|
|
"\n",
|
|
" text \n",
|
|
"0 Info has been found (+/- 100 pages,... \n",
|
|
"1 These are the team members: Drewe... \n",
|
|
"2 In het kader van kernfusie op aarde... \n",
|
|
"3 testing!!! testing!!! \n",
|
|
"4 Thanks to Yahoo!'s Toolbar I can ... "
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Only the first 100 posts are used, so let pandas stop after 100 rows\n",
"# instead of parsing the entire file and slicing it afterwards. This also\n",
"# makes `data` a standalone frame rather than a slice of a larger one,\n",
"# which matters because a column is added to it further down.\n",
"data = pd.read_csv(\"data/blogtext.csv\", nrows=100)\n",
"data.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model typu encoder (BertForSequenceClassification)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']\n",
|
|
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
|
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
|
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
|
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Checkpoint shared by the tokenizer and the classifier.\n",
"model_name = 'bert-base-uncased'\n",
"tokenizer = BertTokenizer.from_pretrained(model_name)\n",
"# Four output units; the problem type is declared as multi-label.\n",
"model = BertForSequenceClassification.from_pretrained(\n",
"    model_name,\n",
"    num_labels=4,\n",
"    problem_type=\"multi_label_classification\",\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbqElEQVR4nO3df5RcZZ3n8feHhAShFQhgL0KUCMnOCeMsY5oEZpXpFgcadiTObJhN88OwA5tVJ3N29Lizcd2JJLp7Dq4L4x5wNLOwMIDpZOPqZLUdQEkfRpcfIcqvBhKaiJKIsNBELBFj4Lt/3CdOTVnVVd1V1dU8+bzOqdP3x/Pc+62bW5+6/VT1jSICMzPL1yGdLsDMzNrLQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvXWcpBFJvZ2uo5Mk/YGkpyWVJP12p+uxvDjora0kPSXpvRXLLpP07QPzEXFqRAzX2c5JkkLSzDaV2mmfBVZFRFdEfK9aAxV2SXp0imuz1zkHvRkwDd5A3gaM1GlzFvBm4O2STm9/SZYLB711XPlVv6TFku6X9JKkZyVdnZrdlX7uTcMbZ0o6RNJ/kvQDSc9J+htJR5Zt9wNp3QuS/qJiP1dK2izpFkkvAZelfd8taa+kZyRdK2lW2fZC0oclPSHpp5I+JelkSf831bupvH3Fc6xaq6TZkkrADOBBSU+Oc6hWAH8LDKXp8u3Pk3RXquubkq6TdEvZ+jNSnXslPXiwD5UdbBz0Nt18DvhcRLwJOBnYlJaflX4elYY37gYuS48+4O1AF3AtgKSFwOeBi4HjgSOBEyr2tRTYDBwF3Aq8CnwEOBY4Ezgb+HBFn3OBRcAZwJ8D64FLgLnAbwIDNZ5X1Voj4hcR0ZXa/LOIOLlaZ0mHA8tSnbcCyyveVL4E3AccA1wJXFrW9wTg68CngTnAx4AvSzquRq2WGQe9TYWvpivJvZL2UgRwLb8ETpF0bESUIuKecdpeDFwdEbsiogR8nCIAZ1KE4v+JiG9HxD5gDVB5Y6e7I+KrEfFaRPw8IrZHxD0RsT8ingK+CPxuRZ/PRMRLETECPALcnvb/E+AbQK0PUsertRF/CPwCuJ0itA8F/gWApLcCpwNrImJfRHwb2FLW9xJgKCKG0nO9A7gfOL/BfdvrnIPepsL7I+KoAw9+/Sq53OXAAuBxSdsk/f44bd8C/KBs/gfATKA7rXv6wIqIeBl4oaL/0+UzkhZI+pqkH6fhnP9CcXVf7tmy6Z9Xme+iuvFqbcQKYFN6E3oF+DL/MHzzFmAsPccDyp/b24ALK95s30Xxm44dBDr9AZTZPxIRTwADkg6huIrdLOkYfv1qHOBHFCF2wFuB/RTh+wzwTw+skPQGimGNf7S7ivm/Ar4HDETETyX9GcVvBq0wXq3jknQi8B5gsaR/mRYfDhwm6ViK5zpH0uFlYT+3bBNPAzdHxL9p8jnY65Sv6G1akXSJpOMi4jVgb1r8GvD/0s+3lzXfAHwkfRDZRXEFvjEi9lOMvb9P0u+ksewrAdXZ/RuBl4CSpN8APtSip1Wv1nouBXZSvHGdlh4LgN0Ub0o/oBiKuVLSLElnAu8r638LxbE4V9IMSYdJ6k1vIHYQcNDbdNMPjKRvonwOWJ7Gz18G/jPwnTT8cAZwA3AzxTdyvg+8AvwpQBpD/1NgkOKKtwQ8RzHOXcvHgIuAnwJ/DWxs4fOqWWsDVgCfj4gflz+AL/APwzcXU3yA/ALFh64bSc81Ip6m+OD5P1K8YT4N/Hv8+j9oyP/xiB0M0lX0XmB+RHy/w+W0naSNwOMR8clO12Kd53d0y5ak90k6XNIRFH95+jDwVGerag9Jp6fv9B8iqZ/iCv6rHS7LpgkHveVsKcWHoD8C5lMMA+X6K+w/AYYphqj+O/ChWrdSsIOPh27MzDLnK3ozs8xNu+/RH3vssXHSSSd1bP8/+9nPOOKIIzq2/3pcX3NcX3NcX3PaWd/27d
ufj4jqt7WIiGn1WLRoUXTS1q1bO7r/elxfc1xfc1xfc9pZH3B/1MhVD92YmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWVu2t0Cwczy0dc3tfsbGIC1a6d2nxNRr76tW9uzX1/Rm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpa5hoJeUr+kHZJGJa2usv4sSd+VtF/Ssirr3yRpt6RrW1G0mZk1rm7QS5oBXAecBywEBiQtrGj2Q+Ay4Es1NvMp4K7Jl2lmZpPVyBX9YmA0InZFxD5gEFha3iAinoqIh4DXKjtLWgR0A7e3oF4zM5sgFf/V4DgNiqGY/oi4Is1fCiyJiFVV2t4IfC0iNqf5Q4A7gUuA9wI9NfqtBFYCdHd3LxocHGzmOTWlVCrR1dXVsf3X4/qa4/qaM9H6du5sYzFVzJlTYmxs+h6/evUtWDD5bff19W2PiJ5q69p9C4QPA0MRsVtSzUYRsR5YD9DT0xO9vb1tLqu24eFhOrn/elxfc1xfcyZa31TfjmBgYJgNG3qndqcTUK++dt0CoZGg3wPMLZs/MS1rxJnAuyV9GOgCZkkqRcSvfaBrZmbt0UjQbwPmS5pHEfDLgYsa2XhEXHxgWtJlFEM3DnkzsylU98PYiNgPrAJuAx4DNkXEiKR1ki4AkHS6pN3AhcAXJY20s2gzM2tcQ2P0ETEEDFUsW1M2vY1iSGe8bdwI3DjhCs3MrCn+y1gzs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLXENBL6lf0g5Jo5JWV1l/lqTvStovaVnZ8tMk3S1pRNJDkv5VK4s3M7P66ga9pBnAdcB5wEJgQNLCimY/BC4DvlSx/GXgAxFxKtAP/KWko5qs2czMJmBmA20WA6MRsQtA0iCwFHj0QIOIeCqte628Y0TsLJv+kaTngOOAvc0WbmZmjVFEjN+gGIrpj4gr0vylwJKIWFWl7Y3A1yJic5V1i4GbgFMj4rWKdSuBlQDd3d2LBgcHJ/dsWqBUKtHV1dWx/dfj+prj+poz0fp27qzfppXmzCkxNjZ9j1+9+hYsmPy2+/r6tkdET7V1jVzRN03S8cDNwIrKkAeIiPXAeoCenp7o7e2dirKqGh4eppP7r8f1Ncf1NWei9a1d275aqhkYGGbDht6p3ekE1Ktv69b27LeRD2P3AHPL5k9Myxoi6U3A14FPRMQ9EyvPzMya1UjQbwPmS5onaRawHNjSyMZT+68Af1NtOMfMzNqvbtBHxH5gFXAb8BiwKSJGJK2TdAGApNMl7QYuBL4oaSR1/yPgLOAySQ+kx2nteCJmZlZdQ2P0ETEEDFUsW1M2vY1iSKey3y3ALU3WaGZmTfBfxpqZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmGgp6Sf2SdkgalbS6yvqzJH1X0n5JyyrWrZD0RHqsaFXhZmbWmLpBL2kGcB1wHrAQGJC0sKLZD4HLgC9V9J0DfBJYAiwGPinp6ObLNjOzRjVyRb8YGI2IXRGxDxgElpY3iIinIuIh4LWKvucCd0TEWES8CNwB9LegbjMza5AiYvwGxVBMf0RckeYvBZZExKoqbW8EvhYRm9P8x4DDIuLTaf4vgJ9HxGcr+q0EVgJ0d3cvGhwcbPZ5TVqpVKKrq6tj+6/H9TXH9TVnovXt3NnGYqqYM6fE2Nj0PX716luwYPLb7uvr2x4RPdXWzZz8ZlsnItYD6wF6enqit7e3Y7UMDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3
v228jQzR5gbtn8iWlZI5rpa2ZmLdBI0G8D5kuaJ2kWsBzY0uD2bwPOkXR0+hD2nLTMzMymSN2gj4j9wCqKgH4M2BQRI5LWSboAQNLpknYDFwJflDSS+o4Bn6J4s9gGrEvLzMxsijQ0Rh8RQ8BQxbI1ZdPbKIZlqvW9AbihiRrNzKwJ/stYM7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy1xDQS+pX9IOSaOSVldZP1vSxrT+XkknpeWHSrpJ0sOSHpP08RbXb2ZmddQNekkzgOuA84CFwICkhRXNLgdejIhTgGuAq9LyC4HZEfEOYBHwbw+8CZiZ2dRo5Ip+MTAaEbsiYh8wCCytaLMUuClNbwbOliQggCMkzQTeAOwDXmpJ5WZm1hBFxPgNpGVAf0RckeYvBZZExKqyNo+kNrvT/JPAEuAnwM3A2cDhwEciYn2VfawEVgJ0d3cvGhwcbMFTm5xSqURXV1fH9l+P62uO62vOROvbubONxVQxZ06JsbHpe/zq1bdgweS33dfXtz0ieqqtmzn5zTZkMfAq8BbgaODvJX0zInaVN0rhvx6gp6cnent721xWbcPDw3Ry//W4vua4vuZMtL61a9tXSzUDA8Ns2NA7tTudgHr1bd3anv02MnSzB5hbNn9iWla1TRqmORJ4AbgI+LuI+GVEPAd8B6j6jmNmZu3RSNBvA+ZLmidpFrAc2FLRZguwIk0vA+6MYkzoh8B7ACQdAZwBPN6Kws3MrDF1gz4i9gOrgNuAx4BNETEiaZ2kC1Kz64FjJI0CHwUOfAXzOqBL0gjFG8b/jIiHWv0kzMystobG6CNiCBiqWLambPoViq9SVvYrVVtuZmZTx38Za2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplrKOgl9UvaIWlU0uoq62dL2pjW3yvppLJ1vyXpbkkjkh6WdFgL6zczszrqBr2kGcB1wHnAQmBA0sKKZpcDL0bEKcA1wFWp70zgFuCDEXEq0Av8smXVm5lZXY1c0S8GRiNiV0TsAwaBpRVtlgI3penNwNmSBJwDPBQRDwJExAsR8WprSjczs0YoIsZvIC0D+iPiijR/KbAkIlaVtXkktdmd5p8ElgCXAIuANwPHAYMR8Zkq+1gJrATo7u5eNDg42IKnNjmlUomurq6O7b8e19cc19ecida3c2cbi6lizpwSY2PT9/jVq2/Bgslvu6+vb3tE9FRbN3Pym23ITOBdwOnAy8C3JG2PiG+VN4qI9cB6gJ6enujt7W1zWbUNDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3v220jQ7wHmls2fmJZVa7M7jcsfCbwA7AbuiojnASQNAe8EvkWb9PU1139gYOpPzolodX3tOrHMbPpoZIx+GzBf0jxJs4DlwJaKNluAFWl6GXBnFGNCtwHvkHR4egP4XeDR1pRuZmaNqHtFHxH7Ja2iCO0ZwA0RMSJpHXB/RGwBrgduljQKjFG8GRARL0q6muLNIoChiPh6m56LmZlV0dAYfUQMAUMVy9aUTb8CXFij7y0UX7E0M7MO8F/GmpllzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYaCnpJ/ZJ2SBqVtLrK+tmSNqb190o6qWL9WyWVJH2sRXWbmVmD6ga9pBnAdcB5wEJgQNLCimaXAy9GxCnANcBVFeuvBr7RfLlmZjZRjVzRLw
ZGI2JXROwDBoGlFW2WAjel6c3A2ZIEIOn9wPeBkZZUbGZmE9JI0J8APF02vzstq9omIvYDPwGOkdQF/AdgbfOlmpnZZCgixm8gLQP6I+KKNH8psCQiVpW1eSS12Z3mnwSWAKuB+yJik6QrgVJEfLbKPlYCKwG6u7sXDQ4OTvoJ7dw56a4AzJlTYmysq7mNtFGr61uwoGWbAqBUKtHVNX2Pn+trzkTra/b1OFGv99dvM6/Hvr6+7RHRU23dzAb67wHmls2fmJZVa7Nb0kzgSOAFirBfJukzwFHAa5JeiYhryztHxHpgPUBPT0/09vY2UFZ1a5v83WFgYJgNGya//3ZrdX1bt7ZsUwAMDw/TzL9fu7m+5ky0vmZfjxP1en/9tvr1eEAjQb8NmC9pHkWgLwcuqmizBVgB3A0sA+6M4leFdx9oUHZFfy1mZjZl6gZ9ROyXtAq4DZgB3BARI5LWAfdHxBbgeuBmSaPAGMWbgZmZTQONXNETEUPAUMWyNWXTrwAX1tnGlZOoz8zMmuS/jDUzy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLXUNBL6pe0Q9KopNVV1s+WtDGtv1fSSWn570naLunh9PM9La7fzMzqqBv0kmYA1wHnAQuBAUkLK5pdDrwYEacA1wBXpeXPA++LiHcAK4CbW1W4mZk1ppEr+sXAaETsioh9wCCwtKLNUuCmNL0ZOFuSIuJ7EfGjtHwEeIOk2a0o3MzMGqOIGL+BtAzoj4gr0vylwJKIWFXW5pHUZneafzK1eb5iOx+MiPdW2cdKYCVAd3f3osHBwUk/oZ07J90VgDlzSoyNdTW3kTZqdX0LFrRsUwCUSiW6uqbv8XN9zZlofc2+Hifq9f76beb12NfXtz0ieqqtmzn5zTZO0qkUwznnVFsfEeuB9QA9PT3R29s76X2tXTvprgAMDAyzYcPk999ura5v69aWbQqA4eFhmvn3azfX15yJ1tfs63GiXu+v31a/Hg9oZOhmDzC3bP7EtKxqG0kzgSOBF9L8icBXgA9ExJPNFmxmZhPTSNBvA+ZLmidpFrAc2FLRZgvFh60Ay4A7IyIkHQV8HVgdEd9pUc1mZjYBdYM+IvYDq4DbgMeATRExImmdpAtSs+uBYySNAh8FDnwFcxVwCrBG0gPp8eaWPwszM6upoTH6iBgChiqWrSmbfgW4sEq/TwOfbrJGMzNrgv8y1swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy9yU/OfgZrno62vv9gcGpv4/1J6I6V6fVecrejOzzDnozcwy11DQS+qXtEPSqKTVVdbPlrQxrb9X0kll6z6elu+QdG4LazczswbUDXpJM4DrgPOAhcCApIUVzS4HXoyIU4BrgKtS34XAcuBUoB/4fNqemZlNkUau6BcDoxGxKyL2AYPA0oo2S4Gb0vRm4GxJSssHI+IXEfF9YDRtz8zMpkgj37o5AXi6bH43sKRWm4jYL+knwDFp+T0VfU+o3IGklcDKNFuStKOh6ttgeJhjgec7tf96Wl2f1Kot/cq0Pn5M8/oOtvOv1V7v9TX5enxbrRXT4uuVEbEeWN/pOgAk3R8RPZ2uoxbX1xzX1xzX15xO1dfI0M0eYG7Z/IlpWdU2kmYCRwIvNNjXzMzaqJGg3wbMlzRP0iyKD1e3VLTZAqxI08uAOyMi0vLl6Vs584D5wH2tKd3MzBpRd+gmjbmvAm4DZgA3RMSIpHXA/RGxBbgeuFnSKDBG8WZAarcJeBTYD/xJRLzapufSKtNiCGkcrq85rq85rq85HalPxYW3mZnlyn8Za2aWOQe9mVnmDpqgl3
SDpOckPVK27EpJeyQ9kB7n1+g77i0g2ljfxrLanpL0QI2+T0l6OLW7vw21zZW0VdKjkkYk/bu0fI6kOyQ9kX4eXaP/itTmCUkrqrVpU33/VdLjkh6S9BVJR9Xo39bjV6fGjp+D49Q2Lc6/tI/DJN0n6cFU49q0fF667cpoqndWjf5tvRXLOPXdmvb5SHqNH1qj/6tlx7ryyy7Ni4iD4gGcBbwTeKRs2ZXAx+r0mwE8CbwdmAU8CCycivoq1v83YE2NdU8Bx7bx2B0PvDNNvxHYSXE7jM8Aq9Py1cBVVfrOAXaln0en6aOnqL5zgJlp+VXV6puK41enxo6fg7Vqmy7nX9qHgK40fShwL3AGsAlYnpZ/AfhQlb4L0zGbDcxLx3LGFNV3flonYEO1+lKfUjuP30FzRR8Rd1F8I2iiGrkFRNPGq0+SgD+iOFGmXEQ8ExHfTdM/BR6j+Avn8ltf3AS8v0r3c4E7ImIsIl4E7qC471Hb64uI2yNif2p2D8XfcXTEOMewEW09B+vV1unzL9UVEVFKs4emRwDvobjtCtQ+B9t+K5Za9UXEUFoXFF8t78g5eNAE/ThWpV/tb6gx9FDtFhCNvkBb5d3AsxHxRI31AdwuabuK20m0jYo7k/42xRVLd0Q8k1b9GOiu0mVKj19FfeX+GPhGjW5Tdvygao3T5hyscfymxfknaUYaPnqO4oLhSWBv2Zt5reMyJcevsr6IuLds3aHApcDf1eh+mKT7Jd0j6f2tru1gD/q/Ak4GTgOeofj1dDoaYPyrqXdFxDsp7jD6J5LOakcRkrqALwN/FhEvla9LVywd/a5urfokfYLi7zhurdF1So5fjRqnzTk4zr/vtDj/IuLViDiN4qp4MfAb7djPZFXWJ+k3y1Z/HrgrIv6+Rve3RXFrhIuAv5R0citrO6iDPiKeTf84rwF/TfVf5zp6GwcVt5T4Q2BjrTYRsSf9fA74Cm24Q2i6IvkycGtE/O+0+FlJx6f1x1NcyVSakuNXoz4kXQb8PnBxejP6NVNx/GrVOF3OwXGO37Q4/yr2txfYCpwJHJVqhNrHZUpfw2X19QNI+iRwHPDRcfocOIa7gGGK36pa5qAO+gMhlfwB8EiVZo3cAqKd3gs8HhG7q62UdISkNx6YpvgAstrzmLQ0Rns98FhEXF22qvzWFyuAv63S/TbgHElHp2GJc9KyttcnqR/4c+CCiHi5Rt+2H786NXb8HBzn3xemwfmXtn2c0remJL0B+D2KzxK2Utx2BWqfg22/FUuN+h6XdAXF51QD6c28Wt+jJc1O08cC/5zibgKt085PeqfTg+JXz2eAX1KM0V0O3Aw8DDxEcTIcn9q+BRgq63s+xTcRngQ+MVX1peU3Ah+saPur+ii+ifFgeoy0oz7gXRTDMg8BD6TH+RS3ov4W8ATwTWBOat8D/I+y/n9M8QHYKPCvp7C+UYqx2QPLvtCJ41enxo6fg7Vqmy7nX9rPbwHfSzU+QvoGUNr/fenf+n8Bs9PyC4B1Zf0/kY7dDuC8Kaxvf9rvgeN6YPmvXiPA76Rz4MH08/JW1+dbIJiZZe6gHroxMzsYOOjNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy9z/B8C5FPdY0UE0AAAAAElFTkSuQmCC",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<Figure size 10000x10000 with 0 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Create the figure *before* plotting so the size applies to the histogram.\n",
"# Calling plt.figure(...) after plt.hist(...) opens a second, empty figure.\n",
"fig, ax = plt.subplots(figsize=(10, 6), dpi=100)\n",
"n, bins, patches = ax.hist(data['age'], 4, density=True, facecolor='b', alpha=0.75)\n",
"\n",
"ax.set_title('Histogram of Age')\n",
"ax.set_xlabel('age')\n",
"ax.set_ylabel('density')\n",
"ax.grid(True)\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>id</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>topic</th>\n",
|
|
" <th>sign</th>\n",
|
|
" <th>date</th>\n",
|
|
" <th>text</th>\n",
|
|
" <th>label</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>14,May,2004</td>\n",
|
|
" <td>Info has been found (+/- 100 pages,...</td>\n",
|
|
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>13,May,2004</td>\n",
|
|
" <td>These are the team members: Drewe...</td>\n",
|
|
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>12,May,2004</td>\n",
|
|
" <td>In het kader van kernfusie op aarde...</td>\n",
|
|
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>2059027</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>15</td>\n",
|
|
" <td>Student</td>\n",
|
|
" <td>Leo</td>\n",
|
|
" <td>12,May,2004</td>\n",
|
|
" <td>testing!!! testing!!!</td>\n",
|
|
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>3581210</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>33</td>\n",
|
|
" <td>InvestmentBanking</td>\n",
|
|
" <td>Aquarius</td>\n",
|
|
" <td>11,June,2004</td>\n",
|
|
" <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
|
|
" <td>[0.0, 0.0, 1.0, 0.0]</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" id gender age topic sign date \\\n",
|
|
"0 2059027 male 15 Student Leo 14,May,2004 \n",
|
|
"1 2059027 male 15 Student Leo 13,May,2004 \n",
|
|
"2 2059027 male 15 Student Leo 12,May,2004 \n",
|
|
"3 2059027 male 15 Student Leo 12,May,2004 \n",
|
|
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
|
|
"\n",
|
|
" text label \n",
|
|
"0 Info has been found (+/- 100 pages,... [1.0, 0.0, 0.0, 0.0] \n",
|
|
"1 These are the team members: Drewe... [1.0, 0.0, 0.0, 0.0] \n",
|
|
"2 In het kader van kernfusie op aarde... [1.0, 0.0, 0.0, 0.0] \n",
|
|
"3 testing!!! testing!!! [1.0, 0.0, 0.0, 0.0] \n",
|
|
"4 Thanks to Yahoo!'s Toolbar I can ... [0.0, 0.0, 1.0, 0.0] "
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Age buckets used as classification targets (upper bounds are inclusive):\n",
"#   age <= 22        -> class 1\n",
"#   22 < age <= 31   -> class 2\n",
"#   31 < age <= 39   -> class 3\n",
"#   anything else    -> class 4\n",
"\n",
"def mapAgeToClass(value: pd.Series) -> int:\n",
"    \"\"\"Return the 1-based age class for a single record.\n",
"\n",
"    value: anything indexable by 'age' (e.g. a DataFrame row or a dict).\n",
"    Ages that match none of the first three buckets fall into class 4.\n",
"    \"\"\"\n",
"    age = value['age']\n",
"    # Buckets are checked in ascending order, so each test only needs the\n",
"    # upper bound; the lower bound is implied by the previous test failing.\n",
"    if age <= 22:\n",
"        return 1\n",
"    if age <= 31:\n",
"        return 2\n",
"    if age <= 39:\n",
"        return 3\n",
"    return 4\n",
"\n",
"def mapAgeToClass2(value: pd.Series) -> list:\n",
"    \"\"\"Return the age class of a single record as a one-hot list of four floats.\"\"\"\n",
"    # Derive the vector from mapAgeToClass so the thresholds live in one place.\n",
"    one_hot = [0.0, 0.0, 0.0, 0.0]\n",
"    one_hot[mapAgeToClass(value) - 1] = 1.0\n",
"    return one_hot\n",
|
|
" \n",
|
|
"# Attach the one-hot age-class vector to every row as the training target.\n",
"data['label'] = data.apply(mapAgeToClass2, axis=1)\n",
"data.head()\n",
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X = list(data['text'])\n",
"Y = list(data['label'])\n",
"\n",
"# The notebook runs on the CPU: the earlier CUDA probe was always overridden\n",
"# by this assignment, so the dead branch has been removed.\n",
"device = \"cpu\"\n",
"\n",
"# Fixed random_state so the train/validation split is identical on every run.\n",
"X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.2, random_state=0)\n",
"\n",
"# Pad to the longest text in each split and cut anything beyond 512 tokens.\n",
"X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)\n",
"X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Dataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset over tokenizer output, with optional labels.\n",
"\n",
"    encodings: mapping from feature name to a per-example sequence; it must\n",
"        contain \"input_ids\", whose length is reported as the dataset size.\n",
"    labels: optional per-example labels; leave as None for unlabeled data.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encodings, labels=None):\n",
"        self.encodings = encodings\n",
"        self.labels = labels\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return example `idx` as a dict of tensors (plus \"labels\" if present).\"\"\"\n",
"        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
"        # Check presence explicitly: truth-testing the container itself raises\n",
"        # for multi-element NumPy arrays and tensors. An empty container is\n",
"        # still treated as \"no labels\".\n",
"        if self.labels is not None and len(self.labels) > 0:\n",
"            item[\"labels\"] = torch.tensor(self.labels[idx])\n",
"        return item\n",
"\n",
"    def __len__(self):\n",
"        return len(self.encodings[\"input_ids\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Pair each tokenized split with its labels.\n",
"train_dataset = Dataset(encodings=X_train_tokenized, labels=y_train)\n",
"val_dataset = Dataset(encodings=X_val_tokenized, labels=y_val)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def compute_metrics(p):\n",
"    \"\"\"Compute accuracy and macro-averaged precision / recall / F1.\n",
"\n",
"    p: a (predictions, labels) pair. Predictions hold one score per class\n",
"       for every example; labels may be class indices or one-hot rows.\n",
"    Returns a dict with the keys \"accuracy\", \"precision\", \"recall\", \"f1\".\n",
"    \"\"\"\n",
"    pred, labels = p\n",
"    pred = np.argmax(pred, axis=1)\n",
"\n",
"    # One-hot targets have to be reduced to class indices too; otherwise\n",
"    # scikit-learn is handed a multilabel/multiclass mix and raises.\n",
"    labels = np.asarray(labels)\n",
"    if labels.ndim > 1:\n",
"        labels = np.argmax(labels, axis=1)\n",
"\n",
"    accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
"    # The default average='binary' is only valid for two classes; macro\n",
"    # averaging treats every class equally. zero_division=0 keeps classes\n",
"    # that are absent from a small validation split from emitting warnings.\n",
"    recall = recall_score(y_true=labels, y_pred=pred, average=\"macro\", zero_division=0)\n",
"    precision = precision_score(y_true=labels, y_pred=pred, average=\"macro\", zero_division=0)\n",
"    f1 = f1_score(y_true=labels, y_pred=pred, average=\"macro\", zero_division=0)\n",
"\n",
"    return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PyTorch: setting up devices\n",
|
|
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Evaluate and checkpoint once per epoch. A full run on this data is only a\n",
"# few dozen optimizer steps, so a 500-step interval would never fire and\n",
"# early stopping / best-model loading would have nothing to compare.\n",
"# load_best_model_at_end requires the save and evaluation strategies to match.\n",
"args = TrainingArguments(\n",
"    output_dir=\"output\",\n",
"    evaluation_strategy=\"epoch\",\n",
"    save_strategy=\"epoch\",\n",
"    per_device_train_batch_size=8,\n",
"    per_device_eval_batch_size=8,\n",
"    num_train_epochs=3,\n",
"    seed=0,\n",
"    load_best_model_at_end=True,\n",
"    no_cuda=True\n",
")\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    compute_metrics=compute_metrics,\n",
"    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/ramon/projects/projekt_glebokie/venv/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
|
" FutureWarning,\n",
|
|
"***** Running training *****\n",
|
|
" Num examples = 80\n",
|
|
" Num Epochs = 3\n",
|
|
" Instantaneous batch size per device = 8\n",
|
|
" Total train batch size (w. parallel, distributed & accumulation) = 8\n",
|
|
" Gradient Accumulation steps = 1\n",
|
|
" Total optimization steps = 30\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Run the fine-tuning loop configured on `trainer`.\n",
"trainer.train()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Keep only the raw model outputs. Reading the named field avoids assigning\n",
"# to `_`, which IPython uses for the previous cell's result.\n",
"raw_pred = trainer.predict(val_dataset).predictions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Index of the highest-scoring entry along axis 1 of the raw predictions.\n",
"y_pred = np.argmax(raw_pred, axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model typu decoder"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "f4394274b6de412f99b9d08dfb473204abc12afd5637ebb20c9ad8dbd67e97a0"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3.10.1 64-bit ('venv': venv)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.12"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|