projekt_glebokie/main.ipynb

539 lines
27 KiB
Plaintext
Raw Normal View History

2022-01-30 16:54:18 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
"import torch\n",
"from transformers import TrainingArguments, Trainer\n",
"from transformers import BertTokenizer, BertForSequenceClassification\n",
"from transformers import EarlyStoppingCallback\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>topic</th>\n",
" <th>sign</th>\n",
" <th>date</th>\n",
" <th>text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>14,May,2004</td>\n",
" <td>Info has been found (+/- 100 pages,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>13,May,2004</td>\n",
" <td>These are the team members: Drewe...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>12,May,2004</td>\n",
" <td>In het kader van kernfusie op aarde...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>12,May,2004</td>\n",
" <td>testing!!! testing!!!</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3581210</td>\n",
" <td>male</td>\n",
" <td>33</td>\n",
" <td>InvestmentBanking</td>\n",
" <td>Aquarius</td>\n",
" <td>11,June,2004</td>\n",
" <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age topic sign date \\\n",
"0 2059027 male 15 Student Leo 14,May,2004 \n",
"1 2059027 male 15 Student Leo 13,May,2004 \n",
"2 2059027 male 15 Student Leo 12,May,2004 \n",
"3 2059027 male 15 Student Leo 12,May,2004 \n",
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
"\n",
" text \n",
"0 Info has been found (+/- 100 pages,... \n",
"1 These are the team members: Drewe... \n",
"2 In het kader van kernfusie op aarde... \n",
"3 testing!!! testing!!! \n",
"4 Thanks to Yahoo!'s Toolbar I can ... "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv(\"data/blogtext.csv\")\n",
"data = data[:100]\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model typu encoder (BertForSequenceClassification)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model_name = 'bert-base-uncased'\n",
"tokenizer = BertTokenizer.from_pretrained(model_name)\n",
"model = BertForSequenceClassification.from_pretrained(model_name, problem_type=\"multi_label_classification\", num_labels=4)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbqElEQVR4nO3df5RcZZ3n8feHhAShFQhgL0KUCMnOCeMsY5oEZpXpFgcadiTObJhN88OwA5tVJ3N29Lizcd2JJLp7Dq4L4x5wNLOwMIDpZOPqZLUdQEkfRpcfIcqvBhKaiJKIsNBELBFj4Lt/3CdOTVnVVd1V1dU8+bzOqdP3x/Pc+62bW5+6/VT1jSICMzPL1yGdLsDMzNrLQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvXWcpBFJvZ2uo5Mk/YGkpyWVJP12p+uxvDjora0kPSXpvRXLLpP07QPzEXFqRAzX2c5JkkLSzDaV2mmfBVZFRFdEfK9aAxV2SXp0imuz1zkHvRkwDd5A3gaM1GlzFvBm4O2STm9/SZYLB711XPlVv6TFku6X9JKkZyVdnZrdlX7uTcMbZ0o6RNJ/kvQDSc9J+htJR5Zt9wNp3QuS/qJiP1dK2izpFkkvAZelfd8taa+kZyRdK2lW2fZC0oclPSHpp5I+JelkSf831bupvH3Fc6xaq6TZkkrADOBBSU+Oc6hWAH8LDKXp8u3Pk3RXquubkq6TdEvZ+jNSnXslPXiwD5UdbBz0Nt18DvhcRLwJOBnYlJaflX4elYY37gYuS48+4O1AF3AtgKSFwOeBi4HjgSOBEyr2tRTYDBwF3Aq8CnwEOBY4Ezgb+HBFn3OBRcAZwJ8D64FLgLnAbwIDNZ5X1Voj4hcR0ZXa/LOIOLlaZ0mHA8tSnbcCyyveVL4E3AccA1wJXFrW9wTg68CngTnAx4AvSzquRq2WGQe9TYWvpivJvZL2UgRwLb8ETpF0bESUIuKecdpeDFwdEbsiogR8nCIAZ1KE4v+JiG9HxD5gDVB5Y6e7I+KrEfFaRPw8IrZHxD0RsT8ingK+CPxuRZ/PRMRLETECPALcnvb/E+AbQK0PUsertRF/CPwCuJ0itA8F/gWApLcCpwNrImJfRHwb2FLW9xJgKCKG0nO9A7gfOL/BfdvrnIPepsL7I+KoAw9+/Sq53OXAAuBxSdsk/f44bd8C/KBs/gfATKA7rXv6wIqIeBl4oaL/0+UzkhZI+pqkH6fhnP9CcXVf7tmy6Z9Xme+iuvFqbcQKYFN6E3oF+DL/MHzzFmAsPccDyp/b24ALK95s30Xxm44dBDr9AZTZPxIRTwADkg6huIrdLOkYfv1qHOBHFCF2wFuB/RTh+wzwTw+skPQGimGNf7S7ivm/Ar4HDETETyX9GcVvBq0wXq3jknQi8B5gsaR/mRYfDhwm6ViK5zpH0uFlYT+3bBNPAzdHxL9p8jnY65Sv6G1akXSJpOMi4jVgb1r8GvD/0s+3lzXfAHwkfRDZRXEFvjEi9lOMvb9P0u+ksewrAdXZ/RuBl4CSpN8APtSip1Wv1nouBXZSvHGdlh4LgN0Ub0o/oBiKuVLSLElnAu8r638LxbE4V9IMSYdJ6k1vIHYQcNDbdNMPjKRvonwOWJ7Gz18G/jPwnTT8cAZwA3AzxTdyvg+8AvwpQBpD/1NgkOKKtwQ8RzHOXcvHgIuAnwJ/DWxs4fOqWWsDVgCfj4gflz+AL/APwzcXU3yA/ALFh64bSc81Ip6m+OD5P1K8YT4N/Hv8+j9oyP/xiB0M0lX0XmB+RHy/w+W0naSNwOMR8clO12Kd53d0y5ak90k6XNIRFH95+jDwVGerag9Jp6fv9B8iqZ/iCv6rHS7LpgkHveVsKcWHoD8C5lMMA+X6K+w/AYYphqj+O/ChWrdSsIOPh27MzDLnK3ozs8xNu+/RH3vssXHSSSd1bP8/+9nPOOKIIzq2/3pcX3NcX3NcX3PaWd/27dufj4jqt7WIiGn1WLRoUXTS1q1bO7r/elxfc1xfc1xfc9pZH3B/1MhVD92YmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWXOQW9mljkHvZlZ5hz0ZmaZc9CbmWVu2t0Cwczy0dc3tfsbGIC1a6d2nxNRr76tW9uzX1/Rm5llzkFvZpY5B72ZWeYc9GZmmXPQm5llzkFvZpa5hoJeUr+kHZJGJa2usv4sSd+VtF/Ssirr3yRpt6RrW1G0mZk1rm7QS5oBXAecBywEBiQtrGj2Q+Ay4Es1NvMp4K7Jl2lmZpPVyBX9YmA0InZFxD5gEFha3iAinoqIh4DXKjtLWgR0A7e3oF4zM5sgFf/V4DgNiqGY/oi4Is1fCiyJiFVV2t4IfC0iNqf5Q4A7gUuA9wI9NfqtBFYCdHd3LxocHGzmOTWlVCrR1dXVsf3X4/qa4/qaM9H6du5sYzFVzJlTYmxs+h6/evUtWDD5bff19W2PiJ5q69p9C4QPA0MRsVtSzUYRsR5YD9DT0xO9vb1tLqu24eFhOrn/elxfc1xfcyZa31TfjmBgYJgNG3qndqcTUK++dt0CoZGg3wPMLZs/MS1rxJnAuyV9GOgCZkkqRcSvfaBrZmbt0UjQbwPmS5pHEfDLgYsa2XhEXHxgWtJlFEM3DnkzsylU98PYiNgPrAJuAx4DNkXEiKR1ki4AkHS6pN3AhcAXJY20s2gzM2tcQ2P0ETEEDFUsW1M2vY1iSGe8bdwI3DjhCs3MrCn+y1gzs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLnIPezCxzDnozs8w56M3MMuegNzPLXENBL6lf0g5Jo5JWV1l/lqTvStovaVnZ8tMk3S1pRNJDkv5VK4s3M7P66ga9pBnAdcB5wEJgQNLCimY/BC4DvlSx/GXgAxFxKtAP/KWko5qs2czMJmBmA20WA6MRsQtA0iCwFHj0QIOIeCqte628Y0TsLJv+kaTngOOAvc0WbmZmjVFEjN+gGIrpj4gr0vylwJKIWFWl7Y3A1yJic5V1i4GbgFMj4rWKdSuBlQDd3d2LBgcHJ/dsWqBUKtHV1dWx/dfj+prj+poz0fp27qzfppXmzCkxNjZ9j1+9+hYsmPy2+/r6tkdET7V1jVzRN03S8cDNwIrKkAeIiPXAeoCenp7o7e2dirKqGh4eppP7r8f1Ncf1NWei9a1d275aqhkYGGbDht6p3ekE1Ktv69b27LeRD2P3AHPL5k9Myxoi6U3A14FPRMQ9EyvPzMya1UjQbwPmS5onaRawHNjSyMZT+68Af1NtOMfMzNqvbtBHxH5gFXAb8BiwKSJGJK2TdAGApNMl7QYuBL4oaSR1/yPgLOAySQ+kx2nteCJmZlZdQ2P0ETEEDFUsW1M2vY1iSKey3y3ALU3WaGZmTfBfxpqZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmHPRmZplz0JuZZc5Bb2aWOQe9mVnmGgp6Sf2SdkgalbS6yvqzJH1X0n5JyyrWrZD0RHqsaFXhZmbWmLpBL2kGcB1wHrAQGJC0sKLZD4HLgC9V9J0DfBJYAiwGPinp6ObLNjOzRjVyRb8YGI2IXRGxDxgElpY3iIinIuIh4LWKvucCd0TEWES8CNwB9LegbjMza5AiYvwGxVBMf0RckeYvBZZExKoqbW8EvhYRm9P8x4DDIuLTaf4vgJ9HxGcr+q0EVgJ0d3cvGhwcbPZ5TVqpVKKrq6tj+6/H9TXH9TVnovXt3NnGYqqYM6fE2Nj0PX716luwYPLb7uvr2x4RPdXWzZz8ZlsnItYD6wF6enqit7e3Y7UMDw/Tyf3X4/qa4/qaM9H61q5tXy3VDAwMs2FD79TudALq1bd1a3v228jQzR5gbtn8iWlZI5rpa2ZmLdBI0G8D5kuaJ2kWsBzY0uD2bwPOkXR0+hD2nLTMzMymSN2gj4j9wCqKgH4M2BQ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 10000x10000 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"n, bins, patches = plt.hist(data['age'], 4, density=True, facecolor='b', alpha=0.75)\n",
"\n",
"plt.title('Histogram of Age')\n",
"plt.grid(True)\n",
"plt.figure(figsize=(100,100), dpi=100)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>topic</th>\n",
" <th>sign</th>\n",
" <th>date</th>\n",
" <th>text</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>14,May,2004</td>\n",
" <td>Info has been found (+/- 100 pages,...</td>\n",
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>13,May,2004</td>\n",
" <td>These are the team members: Drewe...</td>\n",
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>12,May,2004</td>\n",
" <td>In het kader van kernfusie op aarde...</td>\n",
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2059027</td>\n",
" <td>male</td>\n",
" <td>15</td>\n",
" <td>Student</td>\n",
" <td>Leo</td>\n",
" <td>12,May,2004</td>\n",
" <td>testing!!! testing!!!</td>\n",
" <td>[1.0, 0.0, 0.0, 0.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3581210</td>\n",
" <td>male</td>\n",
" <td>33</td>\n",
" <td>InvestmentBanking</td>\n",
" <td>Aquarius</td>\n",
" <td>11,June,2004</td>\n",
" <td>Thanks to Yahoo!'s Toolbar I can ...</td>\n",
" <td>[0.0, 0.0, 1.0, 0.0]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age topic sign date \\\n",
"0 2059027 male 15 Student Leo 14,May,2004 \n",
"1 2059027 male 15 Student Leo 13,May,2004 \n",
"2 2059027 male 15 Student Leo 12,May,2004 \n",
"3 2059027 male 15 Student Leo 12,May,2004 \n",
"4 3581210 male 33 InvestmentBanking Aquarius 11,June,2004 \n",
"\n",
" text label \n",
"0 Info has been found (+/- 100 pages,... [1.0, 0.0, 0.0, 0.0] \n",
"1 These are the team members: Drewe... [1.0, 0.0, 0.0, 0.0] \n",
"2 In het kader van kernfusie op aarde... [1.0, 0.0, 0.0, 0.0] \n",
"3 testing!!! testing!!! [1.0, 0.0, 0.0, 0.0] \n",
"4 Thanks to Yahoo!'s Toolbar I can ... [0.0, 0.0, 1.0, 0.0] "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"1 - 22 -> 1 klasa\n",
"23 - 31 -> 2 klasa\n",
"32 - 39 -> 3 klasa \n",
"40 - 48 -> 4 klasa\n",
"\"\"\"\n",
"\n",
"def mapAgeToClass(value: pd.DataFrame) -> int:\n",
" if(value['age'] <=22):\n",
" return 1\n",
" elif(value['age'] > 22 and value['age'] <= 31):\n",
" return 2\n",
" elif(value['age'] > 31 and value['age'] <= 39):\n",
" return 3\n",
" else:\n",
" return 4\n",
"\n",
"def mapAgeToClass2(value: pd.DataFrame) -> int:\n",
" if(value['age'] <=22):\n",
" return [1.0,0.0,0.0,0.0]\n",
" elif(value['age'] > 22 and value['age'] <= 31):\n",
" return [0.0,1.0,0.0,0.0]\n",
" elif(value['age'] > 31 and value['age'] <= 39):\n",
" return [0.0,0.0,1.0,0.0]\n",
" else:\n",
" return [0.0,0.0,0.0,1.0]\n",
" \n",
"data['label'] = data.apply(lambda row: mapAgeToClass2(row), axis=1)\n",
"data.head()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"X = list(data['text'])\n",
"Y = list(data['label'])\n",
"if (torch.cuda.is_available()):\n",
" device = \"cuda:0\"\n",
" torch.cuda.empty_cache()\n",
"else:\n",
" device = \"cpu\"\n",
"device = \"cpu\"\n",
"\n",
"# model = model.to(device)\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.2)\n",
"X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)\n",
"# .to(device)\n",
"X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)\n",
"# .to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, encodings, labels=None):\n",
" self.encodings = encodings\n",
" self.labels = labels\n",
"\n",
" def __getitem__(self, idx):\n",
" item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
" if self.labels:\n",
" item[\"labels\"] = torch.tensor(self.labels[idx])\n",
" return item\n",
"\n",
" def __len__(self):\n",
" return len(self.encodings[\"input_ids\"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = Dataset(X_train_tokenized, y_train)\n",
"val_dataset = Dataset(X_val_tokenized, y_val)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(p):\n",
" pred, labels = p\n",
" pred = np.argmax(pred, axis=1)\n",
"\n",
" accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
" recall = recall_score(y_true=labels, y_pred=pred)\n",
" precision = precision_score(y_true=labels, y_pred=pred)\n",
" f1 = f1_score(y_true=labels, y_pred=pred)\n",
"\n",
" return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"PyTorch: setting up devices\n",
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
]
}
],
"source": [
"args = TrainingArguments(\n",
" output_dir=\"output\",\n",
" evaluation_strategy=\"steps\",\n",
" eval_steps=500,\n",
" per_device_train_batch_size=8,\n",
" per_device_eval_batch_size=8,\n",
" num_train_epochs=3,\n",
" seed=0,\n",
" load_best_model_at_end=True,\n",
" no_cuda=True\n",
")\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ramon/projects/projekt_glebokie/venv/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" FutureWarning,\n",
"***** Running training *****\n",
" Num examples = 80\n",
" Num Epochs = 3\n",
" Instantaneous batch size per device = 8\n",
" Total train batch size (w. parallel, distributed & accumulation) = 8\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 30\n"
]
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"raw_pred, _, _ = trainer.predict(val_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred = np.argmax(raw_pred, axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model typu decoder"
]
}
],
"metadata": {
"interpreter": {
"hash": "f4394274b6de412f99b9d08dfb473204abc12afd5637ebb20c9ad8dbd67e97a0"
},
"kernelspec": {
"display_name": "Python 3.10.1 64-bit ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}