{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
"import torch\n",
"from transformers import TrainingArguments, Trainer\n",
"from transformers import BertTokenizer, BertForSequenceClassification\n",
"from transformers import EarlyStoppingCallback\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>id</th>\n",
"      <th>gender</th>\n",
"      <th>age</th>\n",
"      <th>topic</th>\n",
"      <th>sign</th>\n",
"      <th>date</th>\n",
"      <th>text</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>14,May,2004</td>\n",
"      <td>Info has been found (+/- 100 pages,...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>13,May,2004</td>\n",
"      <td>These are the team members: Drewe...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>12,May,2004</td>\n",
"      <td>In het kader van kernfusie op aarde...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>12,May,2004</td>\n",
"      <td>testing!!! testing!!!</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>3581210</td>\n",
"      <td>male</td>\n",
"      <td>33</td>\n",
"      <td>InvestmentBanking</td>\n",
"      <td>Aquarius</td>\n",
"      <td>11,June,2004</td>\n",
"      <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age topic sign date \\\n",
"0 2059027 male 15 Student Leo 14,May,2004 \n",
"1 2059027 male 15 Student Leo 13,May,2004 \n",
"2 2059027 male 15 Student Leo 12,May,2004 \n",
"3 2059027 male 15 Student Leo 12,May,2004 \n",
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
"\n",
" text \n",
"0 Info has been found (+/- 100 pages,... \n",
"1 These are the team members: Drewe... \n",
"2 In het kader van kernfusie op aarde... \n",
"3 testing!!! testing!!! \n",
"4 Thanks to Yahoo!'s Toolbar I can ... "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv(\"data/blogtext.csv\")\n",
"data = data[:100]\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Encoder model (BertForSequenceClassification)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model_name = 'bert-base-uncased'\n",
"tokenizer = BertTokenizer.from_pretrained(model_name)\n",
"model = BertForSequenceClassification.from_pretrained(model_name, problem_type=\"multi_label_classification\", num_labels=4)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbqElEQVR4nO3df5RcZZ3n8feHhAShFQhgL0KUCMnOCeMsY5oEZpXpFgcadiTObJhN88OwA5tVJ3N29Lizcd2JJLp7Dq4L4x5wNLOwMIDpZOPqZLUdQEkfRpcfIcqvBhKaiJKIsNBELBFj4Lt/3CdOTVnVVd1V1dU8+bzOqdP3x/Pc+62bW5+6/VT1jSICMzPL1yGdLsDMzNrLQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvXWcpBFJvZ2uo5Mk/YGkpyWVJP12p+uxvDjora0kPSXpvRXLLpP07QPzEXFqRAzX2c5JkkLSzDaV2mmfBVZFRFdEfK9aAxV2SXp0imuz1zkHvRkwDd5A3gaM1GlzFvBm4O2STm9/SZYLB711XPlVv6TFku6X9JKkZyVdnZrdlX7uTcMbZ0o6RNJ/kvQDSc9J+htJR5Zt9wNp3QuS/qJiP1dK2izpFkkvAZelfd8taa+kZyRdK2lW2fZC0oclPSHpp5I+JelkSf831bupvH3Fc6xaq6TZkkrADOBBSU+Oc6hWAH8LDKXp8u3Pk3RXquubkq6TdEvZ+jNSnXslPXiwD5UdbBz0Nt18DvhcRLwJOBnYlJaflX4elYY37gYuS48+4O1AF3AtgKSFwOeBi4HjgSOBEyr2tRTYDBwF3Aq8CnwEOBY4Ezgb+HBFn3OBRcAZwJ8D64FLgLnAbwIDNZ5X1Voj4hcR0ZXa/LOIOLlaZ0mHA8tSnbcCyyveVL4E3AccA1wJXFrW9wTg68CngTnAx4AvSzquRq2WGQe9TYWvpivJvZL2UgRwLb8ETpF0bESUIuKecdpeDFwdEbsiogR8nCIAZ1KE4v+JiG9HxD5gDVB5Y6e7I+KrEfFaRPw8IrZHxD0RsT8ingK+CPxuRZ/PRMRLETECPALcnvb/E+AbQK0PUsertRF/CPwCuJ0itA8F/gWApLcCpwNrImJfRHwb2FLW9xJgKCKG0nO9A7gfOL/BfdvrnIPepsL7I+KoAw9+/Sq53OXAAuBxSdsk/f44bd8C/KBs/gfATKA7rXv6wIqIeBl4oaL/0+UzkhZI+pqkH6fhnP9CcXVf7tmy6Z9Xme+iuvFqbcQKYFN6E3oF+DL/MHzzFmAsPccDyp/b24ALK95s30Xxm44dBDr9AZTZPxIRTwADkg6huIrdLOkYfv1qHOBHFCF2wFuB/RTh+wzwTw+skPQGimGNf7S7ivm/Ar4HDETETyX9GcVvBq0wXq3jknQi8B5gsaR/mRYfDhwm6ViK5zpH0uFlYT+3bBNPAzdHxL9p8jnY65Sv6G1akXSJpOMi4jVgb1r8GvD/0s+3lzXfAHwkfRDZRXEFvjEi9lOMvb9P0u+ksewrAdXZ/RuBl4CSpN8APtSip1Wv1nouBXZSvHGdlh4LgN0Ub0o/oBiKuVLSLElnAu8r638LxbE4V9IMSYdJ6k1vIHYQcNDbdNMPjKRvonwOWJ7Gz18G/jPwnTT8cAZwA3AzxTdyvg+8AvwpQBpD/1NgkOKKtwQ8RzHOXcvHgIuAnwJ/DWxs4fOqWWsDVgCfj4gflz+AL/APwzcXU3yA/ALFh64bSc81Ip6m+OD5P1K8YT4N/Hv8+j9oyP/xiB0M0lX0XmB+RHy/w+W0naSNwOMR8clO12Kd53d0y5ak90k6XNIRFH95+jDwVGerag9Jp6fv9B8iqZ/iCv6rHS7LpgkHveVsKcWHoD8C5lMMA+X6K+w/AYYphqj+O/ChWrdSsIOPh27MzDLnK3ozs8xNu+/RH3vssXHSSSd1bP8/+9nPOOKIIzq2/3pcX3NcX3NcX3PaWd/27dufj4jqt7WIiGn1WLRoUXTS1q1bO7r/elxfc1xfc1xfc9pZH3B/1MhVD92YmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWVu2t0Cwczy0dc3tfsbGIC1a6d2nxNRr76tW9uzX1/Rm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpa5hoJeUr+kHZJGJa2usv4sSd+VtF/Ssirr3yRpt6RrW1G0mZk1rm7QS5oBXAecBywEBiQtrGj2Q+Ay4Es1NvMp4K7Jl2lmZpPVyBX9YmA0InZFxD5gEFha3iAinoqIh4DXKjtLWgR0A7e3oF4zM5sgFf/V4DgNiqGY/oi4Is1fCiyJiFVV2t4IfC0iNqf5Q4A7gUuA9wI9NfqtBFYCdHd3LxocHGzmOTWlVCrR1dXVsf3X4/qa4/qaM9H6du5sYzFVzJlTYmxs+h6/evUtWDD5bff19W2PiJ5q69p9C4QPA0MRsVtSzUYRsR5YD9DT0xO9vb1tLqu24eFhOrn/elxfc1xfcyZa31TfjmBgYJgNG3qndqcTUK++dt0CoZGg3wPMLZs/MS1rxJnAuyV9GOgCZkkqRcSvfaBrZmbt0UjQbwPmS5pHEfDLgYsa2XhEXHxgWtJlFEM3DnkzsylU98PYiNgPrAJuAx4DNkXEiKR1ki4AkHS6pN3AhcAXJY20s2gzM2tcQ2P0ETEEDFUsW1M2vY1iSGe8bdwI3DjhCs3MrCn+y1gzs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLXENBL6lf0g5Jo5JWV1l/lqTvStovaVnZ8tMk3S1pRNJDkv5VK4s3M7P66ga9pBnAdcB5wEJgQNLCimY/BC4DvlSx/GXgAxFxKtAP/KWko5qs2czMJmBmA20WA6MRsQtA0iCwFHj0QIOIeCqte628Y0TsLJv+kaTngOOAvc0WbmZmjVFEjN+gGIrpj4gr0vylwJKIWFWl7Y3A1yJic5V1i4GbgFMj4rWKdSuBlQDd3d2LBgcHJ/dsWqBUKtHV1dWx/dfj+prj+poz0fp27qzfppXmzCkxNjZ9j1+9+hYsmPy2+/r6tkdET7V1jVzRN03S8cDNwIrKkAeIiPXAeoCenp7o7e2dirKqGh4eppP7r8f1Ncf1NWei9a1d275aqhkYGGbDht6p3ekE1Ktv69b27LeRD2P3AHPL5k9Myxoi6U3A14FPRMQ9EyvPzMya1UjQbwPmS5onaRawHNjSyMZT+68Af1NtOMfMzNqvbtBHxH5gFXAb8BiwKSJGJK2TdAGApNMl7QYuBL4oaSR1/yPgLOAySQ+kx2nteCJmZlZdQ2P0ETEEDFUsW1M2vY1iSKey3y3ALU3WaGZmTfBfxpqZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmH
PRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmGgp6Sf2SdkgalbS6yvqzJH1X0n5JyyrWrZD0RHqsaFXhZmbWmLpBL2kGcB1wHrAQGJC0sKLZD4HLgC9V9J0DfBJYAiwGPinp6ObLNjOzRjVyRb8YGI2IXRGxDxgElpY3iIinIuIh4LWKvucCd0TEWES8CNwB9LegbjMza5AiYvwGxVBMf0RckeYvBZZExKoqbW8EvhYRm9P8x4DDIuLTaf4vgJ9HxGcr+q0EVgJ0d3cvGhwcbPZ5TVqpVKKrq6tj+6/H9TXH9TVnovXt3NnGYqqYM6fE2Nj0PX716luwYPLb7uvr2x4RPdXWzZz8ZlsnItYD6wF6enqit7e3Y7UMDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3v228jQzR5gbtn8iWlZI5rpa2ZmLdBI0G8D5kuaJ2kWsBzY0uD2bwPOkXR0+hD2nLTMzMymSN2gj4j9wCqKgH4M2BQ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 10000x10000 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(100,100), dpi=100)  # create the figure before plotting so the size applies to the histogram\n",
"n, bins, patches = plt.hist(data['age'], 4, density=True, facecolor='b', alpha=0.75)\n",
"\n",
"plt.title('Histogram of Age')\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>id</th>\n",
"      <th>gender</th>\n",
"      <th>age</th>\n",
"      <th>topic</th>\n",
"      <th>sign</th>\n",
"      <th>date</th>\n",
"      <th>text</th>\n",
"      <th>label</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>14,May,2004</td>\n",
"      <td>Info has been found (+/- 100 pages,...</td>\n",
"      <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>13,May,2004</td>\n",
"      <td>These are the team members: Drewe...</td>\n",
"      <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>12,May,2004</td>\n",
"      <td>In het kader van kernfusie op aarde...</td>\n",
"      <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>2059027</td>\n",
"      <td>male</td>\n",
"      <td>15</td>\n",
"      <td>Student</td>\n",
"      <td>Leo</td>\n",
"      <td>12,May,2004</td>\n",
"      <td>testing!!! testing!!!</td>\n",
"      <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>3581210</td>\n",
"      <td>male</td>\n",
"      <td>33</td>\n",
"      <td>InvestmentBanking</td>\n",
"      <td>Aquarius</td>\n",
"      <td>11,June,2004</td>\n",
"      <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
"      <td>[0.0, 0.0, 1.0, 0.0]</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age topic sign date \\\n",
"0 2059027 male 15 Student Leo 14,May,2004 \n",
"1 2059027 male 15 Student Leo 13,May,2004 \n",
"2 2059027 male 15 Student Leo 12,May,2004 \n",
"3 2059027 male 15 Student Leo 12,May,2004 \n",
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
"\n",
" text label \n",
"0 Info has been found (+/- 100 pages,... [1.0, 0.0, 0.0, 0.0] \n",
"1 These are the team members: Drewe... [1.0, 0.0, 0.0, 0.0] \n",
"2 In het kader van kernfusie op aarde... [1.0, 0.0, 0.0, 0.0] \n",
"3 testing!!! testing!!! [1.0, 0.0, 0.0, 0.0] \n",
"4 Thanks to Yahoo!'s Toolbar I can ... [0.0, 0.0, 1.0, 0.0] "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"1 - 22  -> class 1\n",
"23 - 31 -> class 2\n",
"32 - 39 -> class 3\n",
"40 - 48 -> class 4\n",
"\"\"\"\n",
"\n",
"def mapAgeToClass(value: pd.Series) -> int:\n",
"    if value['age'] <= 22:\n",
"        return 1\n",
"    elif value['age'] > 22 and value['age'] <= 31:\n",
"        return 2\n",
"    elif value['age'] > 31 and value['age'] <= 39:\n",
"        return 3\n",
"    else:\n",
"        return 4\n",
"\n",
"def mapAgeToClass2(value: pd.Series) -> list:\n",
"    # one-hot encoding of the four age classes\n",
"    if value['age'] <= 22:\n",
"        return [1.0, 0.0, 0.0, 0.0]\n",
"    elif value['age'] > 22 and value['age'] <= 31:\n",
"        return [0.0, 1.0, 0.0, 0.0]\n",
"    elif value['age'] > 31 and value['age'] <= 39:\n",
"        return [0.0, 0.0, 1.0, 0.0]\n",
"    else:\n",
"        return [0.0, 0.0, 0.0, 1.0]\n",
"\n",
"data['label'] = data.apply(lambda row: mapAgeToClass2(row), axis=1)\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"X = list(data['text'])\n",
"Y = list(data['label'])\n",
"if torch.cuda.is_available():\n",
"    device = \"cuda:0\"\n",
"    torch.cuda.empty_cache()\n",
"else:\n",
"    device = \"cpu\"\n",
"# training is forced onto the CPU here (matches no_cuda=True in TrainingArguments below)\n",
"device = \"cpu\"\n",
"\n",
"# model = model.to(device)\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.2)\n",
"X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)\n",
"# .to(device)\n",
"X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)\n",
"# .to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
"    def __init__(self, encodings, labels=None):\n",
"        self.encodings = encodings\n",
"        self.labels = labels\n",
"\n",
"    def __getitem__(self, idx):\n",
"        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
"        if self.labels:\n",
"            item[\"labels\"] = torch.tensor(self.labels[idx])\n",
"        return item\n",
"\n",
"    def __len__(self):\n",
"        return len(self.encodings[\"input_ids\"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = Dataset(X_train_tokenized, y_train)\n",
"val_dataset = Dataset(X_val_tokenized, y_val)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(p):\n",
"    pred, labels = p\n",
"    pred = np.argmax(pred, axis=1)\n",
"    # labels arrive as one-hot vectors, so convert them to class indices as well\n",
"    labels = np.argmax(labels, axis=1)\n",
"\n",
"    accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
"    # with 4 classes, precision/recall/F1 need an explicit averaging mode\n",
"    recall = recall_score(y_true=labels, y_pred=pred, average=\"macro\")\n",
"    precision = precision_score(y_true=labels, y_pred=pred, average=\"macro\")\n",
"    f1 = f1_score(y_true=labels, y_pred=pred, average=\"macro\")\n",
"\n",
"    return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"PyTorch: setting up devices\n",
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
]
}
],
"source": [
"args = TrainingArguments(\n",
"    output_dir=\"output\",\n",
"    evaluation_strategy=\"steps\",\n",
"    eval_steps=500,\n",
"    per_device_train_batch_size=8,\n",
"    per_device_eval_batch_size=8,\n",
"    num_train_epochs=3,\n",
"    seed=0,\n",
"    load_best_model_at_end=True,\n",
"    no_cuda=True\n",
")\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    compute_metrics=compute_metrics,\n",
"    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ramon/projects/projekt_glebokie/venv/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"  FutureWarning,\n",
"***** Running training *****\n",
"  Num examples = 80\n",
"  Num Epochs = 3\n",
"  Instantaneous batch size per device = 8\n",
"  Total train batch size (w. parallel, distributed & accumulation) = 8\n",
"  Gradient Accumulation steps = 1\n",
"  Total optimization steps = 30\n"
]
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"raw_pred, _, _ = trainer.predict(val_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred = np.argmax(raw_pred, axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decoder model"
]
},
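{
"cell_type": "markdown",
"metadata": {},
"source": [
"The original notebook ends at the heading above. The next cell is only a minimal, untested sketch of how a decoder-type model could be plugged into the same pipeline, assuming GPT-2 (`GPT2Tokenizer` / `GPT2ForSequenceClassification` from `transformers`); the checkpoint name and the padding-token workaround are assumptions, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (assumption: GPT-2 as the decoder-type model, same 4 age classes as above).\n",
"from transformers import GPT2Tokenizer, GPT2ForSequenceClassification\n",
"\n",
"decoder_name = 'gpt2'\n",
"decoder_tokenizer = GPT2Tokenizer.from_pretrained(decoder_name)\n",
"# GPT-2 has no padding token, so reuse the end-of-text token for padding\n",
"decoder_tokenizer.pad_token = decoder_tokenizer.eos_token\n",
"\n",
"decoder_model = GPT2ForSequenceClassification.from_pretrained(decoder_name, problem_type=\"multi_label_classification\", num_labels=4)\n",
"decoder_model.config.pad_token_id = decoder_tokenizer.pad_token_id\n",
"\n",
"# Tokenization, Dataset construction and the Trainer setup would mirror the BERT pipeline above."
]
}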
],
"metadata": {
"interpreter": {
"hash": "f4394274b6de412f99b9d08dfb473204abc12afd5637ebb20c9ad8dbd67e97a0"
},
"kernelspec": {
"display_name": "Python 3.10.1 64-bit ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}