From e847befae951587356483614b5e5f9ac5dbb362a Mon Sep 17 00:00:00 2001 From: wangobango Date: Sun, 30 Jan 2022 16:54:18 +0100 Subject: [PATCH] progress --- .gitignore | 3 + main.ipynb | 538 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 541 insertions(+) create mode 100644 .gitignore create mode 100644 main.ipynb diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1679c20 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +data/* +*.csv +venv/* \ No newline at end of file diff --git a/main.ipynb b/main.ipynb new file mode 100644 index 0000000..c739032 --- /dev/null +++ b/main.ipynb @@ -0,0 +1,538 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n", + "import torch\n", + "from transformers import TrainingArguments, Trainer\n", + "from transformers import BertTokenizer, BertForSequenceClassification\n", + "from transformers import EarlyStoppingCallback\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idgenderagetopicsigndatetext
02059027male15StudentLeo14,May,2004Info has been found (+/- 100 pages,...
12059027male15StudentLeo13,May,2004These are the team members: Drewe...
22059027male15StudentLeo12,May,2004In het kader van kernfusie op aarde...
32059027male15StudentLeo12,May,2004testing!!! testing!!!
43581210male33InvestmentBankingAquarius11,June,2004Thanks to Yahoo!'s Toolbar I can ...
\n", + "
" + ], + "text/plain": [ + " id gender age topic sign date \\\n", + "0 2059027 male 15 Student Leo 14,May,2004 \n", + "1 2059027 male 15 Student Leo 13,May,2004 \n", + "2 2059027 male 15 Student Leo 12,May,2004 \n", + "3 2059027 male 15 Student Leo 12,May,2004 \n", + "4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n", + "\n", + " text \n", + "0 Info has been found (+/- 100 pages,... \n", + "1 These are the team members: Drewe... \n", + "2 In het kader van kernfusie op aarde... \n", + "3 testing!!! testing!!! \n", + "4 Thanks to Yahoo!'s Toolbar I can ... " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = pd.read_csv(\"data/blogtext.csv\")\n", + "data = data[:100]\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model typu encoder (BertForSequenceClassification)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']\n", + "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model_name = 'bert-base-uncased'\n", + "tokenizer = BertTokenizer.from_pretrained(model_name)\n", + "model = BertForSequenceClassification.from_pretrained(model_name, problem_type=\"multi_label_classification\", num_labels=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbqElEQVR4nO3df5RcZZ3n8feHhAShFQhgL0KUCMnOCeMsY5oEZpXpFgcadiTObJhN88OwA5tVJ3N29Lizcd2JJLp7Dq4L4x5wNLOwMIDpZOPqZLUdQEkfRpcfIcqvBhKaiJKIsNBELBFj4Lt/3CdOTVnVVd1V1dU8+bzOqdP3x/Pc+62bW5+6/VT1jSICMzPL1yGdLsDMzNrLQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvXWcpBFJvZ2uo5Mk/YGkpyWVJP12p+uxvDjora0kPSXpvRXLLpP07QPzEXFqRAzX2c5JkkLSzDaV2mmfBVZFRFdEfK9aAxV2SXp0imuz1zkHvRkwDd5A3gaM1GlzFvBm4O2STm9/SZYLB711XPlVv6TFku6X9JKkZyVdnZrdlX7uTcMbZ0o6RNJ/kvQDSc9J+htJR5Zt9wNp3QuS/qJiP1dK2izpFkkvAZelfd8taa+kZyRdK2lW2fZC0oclPSHpp5I+JelkSf831bupvH3Fc6xaq6TZkkrADOBBSU+Oc6hWAH8LDKXp8u3Pk3RXquubkq6TdEvZ+jNSnXslPXiwD5UdbBz0Nt18DvhcRLwJOBnYlJaflX4elYY37gYuS48+4O1AF3AtgKSFwOeBi4HjgSOBEyr2tRTYDBwF3Aq8CnwEOBY4Ezgb+HBFn3OBRcAZwJ8D64FLgLnAbwIDNZ5X1Voj4hcR0ZXa/LOIOLlaZ0mHA8tSnbcCyyveVL4E3AccA1wJXFrW9wTg68CngTnAx4AvSzquRq2WGQe9TYWvpivJvZL2UgRwLb8ETpF0bESUIuKecdpeDFwdEbsiogR8nCIAZ1KE4v+JiG9HxD5gDVB5Y6e7I+KrEfFaRPw8IrZHxD0RsT8ingK+CPxuRZ/PRMRLETECPALcnvb/E+AbQK0PUsertRF/CPwCuJ0itA8F/gWApLcCpwNrImJfRHwb2FLW9xJgKCKG0nO9A7gfOL/BfdvrnIPepsL7I+KoAw9+/Sq53OXAAuBxSdsk/f44bd8C/KBs/gfATKA7rXv6wIqIeBl4oaL/0+UzkhZI+pqkH6fhnP9CcXVf7tmy6Z9Xme+iuvFqbcQKYFN6E3oF+DL/MHzzFmAsPccDyp/b24ALK95s30Xxm44dBDr9AZTZPxIRTwADkg6huIrdLOkYfv1qHOBHFCF2wFuB/RTh+wzwTw+skPQGimGNf7S7ivm/Ar4HDETETyX9GcVvBq0wXq3jknQi8B5gsaR/mRYfDhwm6ViK5zpH0uFlYT+3bBNPAzdHxL9p8jnY65Sv6G1akXSJpOMi4jVgb1r8GvD/0s+3lzXfAHwkfRDZRXEFvjEi9lOMvb9P0u+ksewrAdXZ/RuBl4CSpN8APtSip1Wv1nouBXZSvHGdlh4LgN0Ub0o/oBiKuVLSLElnAu8r638LxbE4V9IMSYdJ6k1vIHYQcNDbdNMPjKRvonwOWJ7Gz18G/jPwnTT8cAZwA3AzxTdyvg+8AvwpQBpD/1NgkOKKtwQ8RzHOXcvHgIuAnwJ/DWxs4fOqWWsDVgCfj4gflz+AL/APwzcXU3yA/ALFh64bSc81Ip6m+OD5P1K8YT4N/Hv8+j9oyP/xiB0M0lX0XmB+RHy/w+W0naSNwOMR8clO12Kd53d0y5ak90k6XNIRFH95+jDwVGerag9Jp6fv9B8iqZ/iCv6rHS7LpgkHveVsKcWHoD8C5lMMA+X6K+w/AYYphqj+O/ChWrdSsIOPh27MzDLnK3ozs8xNu+/RH3vssXHSSSd1bP8/+9nPOOKIIzq2/3pcX3NcX3NcX3PaWd/27dufj4jqt7WIiGn1WLRoUXTS1q1bO7r/elxfc1xfc1xfc9pZH3B/1MhVD92YmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWVu2t0Cwczy0dc3tfsbGIC1a6d2nxNRr76tW9uzX1/Rm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpa5hoJeUr+kHZJGJa2usv4sSd+VtF/Ssirr3yRpt6RrW1G0mZk1rm7QS5oBXAecBywEBiQtrGj2Q+Ay4Es1NvMp4K7Jl2lmZpPVyBX9YmA0InZFxD5gEFha3iAinoqIh4DXKjtLWgR0A7e3oF4zM5sgFf/V4DgNiqGY/oi4Is1fCiyJiFVV2t4IfC0iNqf5Q4A7gUuA9wI9NfqtBFYCdHd3LxocHGzmOTWlVCrR1dXVsf3X4/qa4/qaM9H6du5sYzFVzJlTYmxs+h6/evUtWDD5bff19W2PiJ5q69p9C4QPA0MRsVtSzUYRsR5YD9DT0xO9vb1tLqu24eFhOrn/elxfc1xfcyZa31TfjmBgYJgNG3qndqcTUK++dt0CoZGg3wPMLZs/MS1rxJnAuyV9GOgCZkkqRcSvfaBrZmbt0UjQbwPmS5pHEfDLgYsa2XhEXHxgWtJlFEM3DnkzsylU98PYiNgPrAJuAx4DNkXEiKR1ki4AkHS6pN3AhcAXJY20s2gzM2tcQ2P0ETEEDFUsW1M2vY1iSGe8bdwI3DjhCs3MrCn+y1gzs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLXENBL6lf0g5Jo5JWV1l/lqTvStovaVnZ8tMk3S1pRNJDkv5VK4s3M7P66ga9pBnAdcB5wEJgQNLCimY/BC4DvlSx/GXgAxFxKtAP/KWko5qs2czMJmBmA20WA6MRsQtA0iCwFHj0QIOIeCqte628Y0TsLJv+kaTngOOAvc0WbmZmjVFEjN+gGIrpj4gr0vylwJKIWFWl7Y3A1yJic5V1i4GbgFMj4rWKdSuBlQDd3d2LBgcHJ/dsWqBUKtHV1dWx/dfj+prj+poz0fp27qzfppXmzCkxNjZ9j1+9+hYsmPy2+/r6tkdET7V1jVzRN03S8cDNwIrKkAeIiPXAeoCenp7o7e2dirKqGh4eppP7r8f1Ncf1NWei9a1d275aqhkYGGbDht6p3ekE1Ktv69b27LeRD2P3AHPL5k9Myxoi6U3A14FPRMQ9EyvPzMya1UjQbwPmS5onaRawHNjSyMZT+68Af1NtOMfMzNqvbtBHxH5gFXAb8BiwKSJGJK2TdAGApNMl7QYuBL4oaSR1/yPgLOAySQ+kx2nteCJmZlZdQ2P0ETEEDFUsW1M2vY1iSKey3y3ALU3WaGZmTfBfxpqZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmGgp6Sf2SdkgalbS6yvqzJH1X0n5JyyrWrZD0RHqsaFXhZmbWmLpBL2kGcB1wHrAQGJC0sKLZD4HLgC9V9J0DfBJYAiwGPinp6ObLNjOzRjVyRb8YGI2IXRGxDxgElpY3iIinIuIh4LWKvucCd0TEWES8CNwB9LegbjMza5AiYvwGxVBMf0RckeYvBZZExKoqbW8EvhYRm9P8x4DDIuLTaf4vgJ9HxGcr+q0EVgJ0d3cvGhwcbPZ5TVqpVKKrq6tj+6/H9TXH9TVnovXt3NnGYqqYM6fE2Nj0PX716luwYPLb7uvr2x4RPdXWzZz8ZlsnItYD6wF6enqit7e3Y7UMDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3v228jQzR5gbtn8iWlZI5rpa2ZmLdBI0G8D5kuaJ2kWsBzY0uD2bwPOkXR0+hD2nLTMzMymSN2gj4j9wCqKgH4M2BQRI5LWSboAQNLpknYDFwJflDSS+o4Bn6J4s9gGrEvLzMxsijQ0Rh8RQ8BQxbI1ZdPbKIZlqvW9AbihiRrNzKwJ/stYM7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy1xDQS+pX9IOSaOSVldZP1vSxrT+XkknpeWHSrpJ0sOSHpP08RbXb2ZmddQNekkzgOuA84CFwICkhRXNLgdejIhTgGuAq9LyC4HZEfEOYBHwbw+8CZiZ2dRo5Ip+MTAaEbsiYh8wCCytaLMUuClNbwbOliQggCMkzQTeAOwDXmpJ5WZm1hBFxPgNpGVAf0RckeYvBZZExKqyNo+kNrvT/JPAEuAnwM3A2cDhwEciYn2VfawEVgJ0d3cvGhwcbMFTm5xSqURXV1fH9l+P62uO62vOROvbubONxVQxZ06JsbHpe/zq1bdgweS33dfXtz0ieqqtmzn5zTZkMfAq8BbgaODvJX0zInaVN0rhvx6gp6cnent721xWbcPDw3Ry//W4vua4vuZMtL61a9tXSzUDA8Ns2NA7tTudgHr1bd3anv02MnSzB5hbNn9iWla1TRqmORJ4AbgI+LuI+GVEPAd8B6j6jmNmZu3RSNBvA+ZLmidpFrAc2FLRZguwIk0vA+6MYkzoh8B7ACQdAZwBPN6Kws3MrDF1gz4i9gOrgNuAx4BNETEiaZ2kC1Kz64FjJI0CHwUOfAXzOqBL0gjFG8b/jIiHWv0kzMystobG6CNiCBiqWLambPoViq9SVvYrVVtuZmZTx38Za2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplrKOgl9UvaIWlU0uoq62dL2pjW3yvppLJ1vyXpbkkjkh6WdFgL6zczszrqBr2kGcB1wHnAQmBA0sKKZpcDL0bEKcA1wFWp70zgFuCDEXEq0Av8smXVm5lZXY1c0S8GRiNiV0TsAwaBpRVtlgI3penNwNmSBJwDPBQRDwJExAsR8WprSjczs0YoIsZvIC0D+iPiijR/KbAkIlaVtXkktdmd5p8ElgCXAIuANwPHAYMR8Zkq+1gJrATo7u5eNDg42IKnNjmlUomurq6O7b8e19cc19ecida3c2cbi6lizpwSY2PT9/jVq2/Bgslvu6+vb3tE9FRbN3Pym23ITOBdwOnAy8C3JG2PiG+VN4qI9cB6gJ6enujt7W1zWbUNDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3v220jQ7wHmls2fmJZVa7M7jcsfCbwA7AbuiojnASQNAe8EvkWb9PU1139gYOpPzolodX3tOrHMbPpoZIx+GzBf0jxJs4DlwJaKNluAFWl6GXBnFGNCtwHvkHR4egP4XeDR1pRuZmaNqHtFHxH7Ja2iCO0ZwA0RMSJpHXB/RGwBrgduljQKjFG8GRARL0q6muLNIoChiPh6m56LmZlV0dAYfUQMAUMVy9aUTb8CXFij7y0UX7E0M7MO8F/GmpllzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpY5B72ZWeYaCnpJ/ZJ2SBqVtLrK+tmSNqb190o6qWL9WyWVJH2sRXWbmVmD6ga9pBnAdcB5wEJgQNLCimaXAy9GxCnANcBVFeuvBr7RfLlmZjZRjVzRLwZGI2JXROwDBoGlFW2WAjel6c3A2ZIEIOn9wPeBkZZUbGZmE9JI0J8APF02vzstq9omIvYDPwGOkdQF/AdgbfOlmpnZZCgixm8gLQP6I+KKNH8psCQiVpW1eSS12Z3mnwSWAKuB+yJik6QrgVJEfLbKPlYCKwG6u7sXDQ4OTvoJ7dw56a4AzJlTYmysq7mNtFGr61uwoGWbAqBUKtHVNX2Pn+trzkTra/b1OFGv99dvM6/Hvr6+7RHRU23dzAb67wHmls2fmJZVa7Nb0kzgSOAFirBfJukzwFHAa5JeiYhryztHxHpgPUBPT0/09vY2UFZ1a5v83WFgYJgNGya//3ZrdX1bt7ZsUwAMDw/TzL9fu7m+5ky0vmZfjxP1en/9tvr1eEAjQb8NmC9pHkWgLwcuqmizBVgB3A0sA+6M4leFdx9oUHZFfy1mZjZl6gZ9ROyXtAq4DZgB3BARI5LWAfdHxBbgeuBmSaPAGMWbgZmZTQONXNETEUPAUMWyNWXTrwAX1tnGlZOoz8zMmuS/jDUzy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLXUNBL6pe0Q9KopNVV1s+WtDGtv1fSSWn570naLunh9PM9La7fzMzqqBv0kmYA1wHnAQuBAUkLK5pdDrwYEacA1wBXpeXPA++LiHcAK4CbW1W4mZk1ppEr+sXAaETsioh9wCCwtKLNUuCmNL0ZOFuSIuJ7EfGjtHwEeIOk2a0o3MzMGqOIGL+BtAzoj4gr0vylwJKIWFXW5pHUZneafzK1eb5iOx+MiPdW2cdKYCVAd3f3osHBwUk/oZ07J90VgDlzSoyNdTW3kTZqdX0LFrRsUwCUSiW6uqbv8XN9zZlofc2+Hifq9f76beb12NfXtz0ieqqtmzn5zTZO0qkUwznnVFsfEeuB9QA9PT3R29s76X2tXTvprgAMDAyzYcPk999ura5v69aWbQqA4eFhmvn3azfX15yJ1tfs63GiXu+v31a/Hg9oZOhmDzC3bP7EtKxqG0kzgSOBF9L8icBXgA9ExJPNFmxmZhPTSNBvA+ZLmidpFrAc2FLRZgvFh60Ay4A7IyIkHQV8HVgdEd9pUc1mZjYBdYM+IvYDq4DbgMeATRExImmdpAtSs+uBYySNAh8FDnwFcxVwCrBG0gPp8eaWPwszM6upoTH6iBgChiqWrSmbfgW4sEq/TwOfbrJGMzNrgv8y1swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy9yU/OfgZrno62vv9gcGpv4/1J6I6V6fVecrejOzzDnozcwy11DQS+qXtEPSqKTVVdbPlrQxrb9X0kll6z6elu+QdG4LazczswbUDXpJM4DrgPOAhcCApIUVzS4HXoyIU4BrgKtS34XAcuBUoB/4fNqemZlNkUau6BcDoxGxKyL2AYPA0oo2S4Gb0vRm4GxJSssHI+IXEfF9YDRtz8zMpkgj37o5AXi6bH43sKRWm4jYL+knwDFp+T0VfU+o3IGklcDKNFuStKOh6ttgeJhjgec7tf96Wl2f1Kot/cq0Pn5M8/oOtvOv1V7v9TX5enxbrRXT4uuVEbEeWN/pOgAk3R8RPZ2uoxbX1xzX1xzX15xO1dfI0M0eYG7Z/IlpWdU2kmYCRwIvNNjXzMzaqJGg3wbMlzRP0iyKD1e3VLTZAqxI08uAOyMi0vLl6Vs584D5wH2tKd3MzBpRd+gmjbmvAm4DZgA3RMSIpHXA/RGxBbgeuFnSKDBG8WZAarcJeBTYD/xJRLzapufSKtNiCGkcrq85rq85rq85HalPxYW3mZnlyn8Za2aWOQe9mVnmDpqgl3SDpOckPVK27EpJeyQ9kB7n1+g77i0g2ljfxrLanpL0QI2+T0l6OLW7vw21zZW0VdKjkkYk/bu0fI6kOyQ9kX4eXaP/itTmCUkrqrVpU33/VdLjkh6S9BVJR9Xo39bjV6fGjp+D49Q2Lc6/tI/DJN0n6cFU49q0fF667cpoqndWjf5tvRXLOPXdmvb5SHqNH1qj/6tlx7ryyy7Ni4iD4gGcBbwTeKRs2ZXAx+r0mwE8CbwdmAU8CCycivoq1v83YE2NdU8Bx7bx2B0PvDNNvxHYSXE7jM8Aq9Py1cBVVfrOAXaln0en6aOnqL5zgJlp+VXV6puK41enxo6fg7Vqmy7nX9qHgK40fShwL3AGsAlYnpZ/AfhQlb4L0zGbDcxLx3LGFNV3flonYEO1+lKfUjuP30FzRR8Rd1F8I2iiGrkFRNPGq0+SgD+iOFGmXEQ8ExHfTdM/BR6j+Avn8ltf3AS8v0r3c4E7ImIsIl4E7qC471Hb64uI2yNif2p2D8XfcXTEOMewEW09B+vV1unzL9UVEVFKs4emRwDvobjtCtQ+B9t+K5Za9UXEUFoXFF8t78g5eNAE/ThWpV/tb6gx9FDtFhCNvkBb5d3AsxHxRI31AdwuabuK20m0jYo7k/42xRVLd0Q8k1b9GOiu0mVKj19FfeX+GPhGjW5Tdvygao3T5hyscfymxfknaUYaPnqO4oLhSWBv2Zt5reMyJcevsr6IuLds3aHApcDf1eh+mKT7Jd0j6f2tru1gD/q/Ak4GTgOeofj1dDoaYPyrqXdFxDsp7jD6J5LOakcRkrqALwN/FhEvla9LVywd/a5urfokfYLi7zhurdF1So5fjRqnzTk4zr/vtDj/IuLViDiN4qp4MfAb7djPZFXWJ+k3y1Z/HrgrIv6+Rve3RXFrhIuAv5R0citrO6iDPiKeTf84rwF/TfVf5zp6GwcVt5T4Q2BjrTYRsSf9fA74Cm24Q2i6IvkycGtE/O+0+FlJx6f1x1NcyVSakuNXoz4kXQb8PnBxejP6NVNx/GrVOF3OwXGO37Q4/yr2txfYCpwJHJVqhNrHZUpfw2X19QNI+iRwHPDRcfocOIa7gGGK36pa5qAO+gMhlfwB8EiVZo3cAqKd3gs8HhG7q62UdISkNx6YpvgAstrzmLQ0Rns98FhEXF22qvzWFyuAv63S/TbgHElHp2GJc9KyttcnqR/4c+CCiHi5Rt+2H786NXb8HBzn3xemwfmXtn2c0remJL0B+D2KzxK2Utx2BWqfg22/FUuN+h6XdAXF51QD6c28Wt+jJc1O08cC/5zibgKt085PeqfTg+JXz2eAX1KM0V0O3Aw8DDxEcTIcn9q+BRgq63s+xTcRngQ+MVX1peU3Ah+saPur+ii+ifFgeoy0oz7gXRTDMg8BD6TH+RS3ov4W8ATwTWBOat8D/I+y/n9M8QHYKPCvp7C+UYqx2QPLvtCJ41enxo6fg7Vqmy7nX9rPbwHfSzU+QvoGUNr/fenf+n8Bs9PyC4B1Zf0/kY7dDuC8Kaxvf9rvgeN6YPmvXiPA76Rz4MH08/JW1+dbIJiZZe6gHroxMzsYOOjNzDLnoDczy5yD3swscw56M7PMOejNzDLnoDczy9z/B8C5FPdY0UE0AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "n, bins, patches = plt.hist(data['age'], 4, density=True, facecolor='b', alpha=0.75)\n", + "\n", + "plt.title('Histogram of Age')\n", + "plt.grid(True)\n", + "plt.figure(figsize=(100,100), dpi=100)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idgenderagetopicsigndatetextlabel
02059027male15StudentLeo14,May,2004Info has been found (+/- 100 pages,...[1.0, 0.0, 0.0, 0.0]
12059027male15StudentLeo13,May,2004These are the team members: Drewe...[1.0, 0.0, 0.0, 0.0]
22059027male15StudentLeo12,May,2004In het kader van kernfusie op aarde...[1.0, 0.0, 0.0, 0.0]
32059027male15StudentLeo12,May,2004testing!!! testing!!![1.0, 0.0, 0.0, 0.0]
43581210male33InvestmentBankingAquarius11,June,2004Thanks to Yahoo!'s Toolbar I can ...[0.0, 0.0, 1.0, 0.0]
\n", + "
" + ], + "text/plain": [ + " id gender age topic sign date \\\n", + "0 2059027 male 15 Student Leo 14,May,2004 \n", + "1 2059027 male 15 Student Leo 13,May,2004 \n", + "2 2059027 male 15 Student Leo 12,May,2004 \n", + "3 2059027 male 15 Student Leo 12,May,2004 \n", + "4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n", + "\n", + " text label \n", + "0 Info has been found (+/- 100 pages,... [1.0, 0.0, 0.0, 0.0] \n", + "1 These are the team members: Drewe... [1.0, 0.0, 0.0, 0.0] \n", + "2 In het kader van kernfusie op aarde... [1.0, 0.0, 0.0, 0.0] \n", + "3 testing!!! testing!!! [1.0, 0.0, 0.0, 0.0] \n", + "4 Thanks to Yahoo!'s Toolbar I can ... [0.0, 0.0, 1.0, 0.0] " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"\n", + "1 - 22 -> 1 klasa\n", + "23 - 31 -> 2 klasa\n", + "32 - 39 -> 3 klasa \n", + "40 - 48 -> 4 klasa\n", + "\"\"\"\n", + "\n", + "def mapAgeToClass(value: pd.DataFrame) -> int:\n", + " if(value['age'] <=22):\n", + " return 1\n", + " elif(value['age'] > 22 and value['age'] <= 31):\n", + " return 2\n", + " elif(value['age'] > 31 and value['age'] <= 39):\n", + " return 3\n", + " else:\n", + " return 4\n", + "\n", + "def mapAgeToClass2(value: pd.DataFrame) -> int:\n", + " if(value['age'] <=22):\n", + " return [1.0,0.0,0.0,0.0]\n", + " elif(value['age'] > 22 and value['age'] <= 31):\n", + " return [0.0,1.0,0.0,0.0]\n", + " elif(value['age'] > 31 and value['age'] <= 39):\n", + " return [0.0,0.0,1.0,0.0]\n", + " else:\n", + " return [0.0,0.0,0.0,1.0]\n", + " \n", + "data['label'] = data.apply(lambda row: mapAgeToClass2(row), axis=1)\n", + "data.head()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "X = list(data['text'])\n", + "Y = list(data['label'])\n", + "if (torch.cuda.is_available()):\n", + " device = \"cuda:0\"\n", + " torch.cuda.empty_cache()\n", + "else:\n", + " device = \"cpu\"\n", + "device = \"cpu\"\n", + "\n", + "# model = model.to(device)\n", + "\n", + "X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.2)\n", + "X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)\n", + "# .to(device)\n", + "X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)\n", + "# .to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(self, encodings, labels=None):\n", + " self.encodings = encodings\n", + " self.labels = labels\n", + "\n", + " def __getitem__(self, idx):\n", + " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n", + " if self.labels:\n", + " item[\"labels\"] = torch.tensor(self.labels[idx])\n", + " return item\n", + "\n", + " def __len__(self):\n", + " return len(self.encodings[\"input_ids\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = Dataset(X_train_tokenized, y_train)\n", + "val_dataset = Dataset(X_val_tokenized, y_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(p):\n", + " pred, labels = p\n", + " pred = np.argmax(pred, axis=1)\n", + "\n", + " accuracy = accuracy_score(y_true=labels, y_pred=pred)\n", + " recall = recall_score(y_true=labels, y_pred=pred)\n", + " precision = precision_score(y_true=labels, y_pred=pred)\n", + " f1 = f1_score(y_true=labels, y_pred=pred)\n", + "\n", + " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch: setting up devices\n", + "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n" + ] + } + ], + "source": [ + "args = TrainingArguments(\n", + " output_dir=\"output\",\n", + " evaluation_strategy=\"steps\",\n", + " eval_steps=500,\n", + " per_device_train_batch_size=8,\n", + " per_device_eval_batch_size=8,\n", + " num_train_epochs=3,\n", + " seed=0,\n", + " load_best_model_at_end=True,\n", + " no_cuda=True\n", + ")\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=val_dataset,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ramon/projects/projekt_glebokie/venv/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " FutureWarning,\n", + "***** Running training *****\n", + " Num examples = 80\n", + " Num Epochs = 3\n", + " Instantaneous batch size per device = 8\n", + " Total train batch size (w. parallel, distributed & accumulation) = 8\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 30\n" + ] + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "raw_pred, _, _ = trainer.predict(val_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = np.argmax(raw_pred, axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model typu decoder" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "f4394274b6de412f99b9d08dfb473204abc12afd5637ebb20c9ad8dbd67e97a0" + }, + "kernelspec": { + "display_name": "Python 3.10.1 64-bit ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}