2024-05-20 00:43:18 +02:00
{
"cells": [
{
"cell_type": "code",
2024-05-20 03:14:08 +02:00
"execution_count": 23,
2024-05-20 00:43:18 +02:00
"id": "36c02fac",
"metadata": {},
2024-05-20 03:14:08 +02:00
"outputs": [],
2024-05-20 00:43:18 +02:00
"source": [
"import os\n",
"import pandas as pd\n",
"import gensim\n",
"from gensim.models import KeyedVectors\n",
"import numpy as np\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"import matplotlib.pyplot as plt\n",
"from keras.regularizers import l2"
]
},
{
"cell_type": "markdown",
"id": "db84429c",
"metadata": {},
"source": [
"### Declare paths"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fc2539ae",
"metadata": {},
"outputs": [],
"source": [
"data_dir_path = 'sport-text-classification-ball-isi-public'\n",
"# Pass every path component separately so os.path.join inserts the separator of\n",
"# the current OS; hard-coded Windows separators produce wrong paths on Linux/macOS.\n",
"train_path = os.path.join(data_dir_path, 'train', 'train.tsv')\n",
"dev_texts_path = os.path.join(data_dir_path, 'dev-0', 'in.tsv')\n",
"dev_labels_path = os.path.join(data_dir_path, 'dev-0', 'expected.tsv')\n",
"dev_predicted_path = os.path.join(data_dir_path, 'dev-0', 'out.tsv')\n",
"test_texts_path = os.path.join(data_dir_path, 'test-A', 'in.tsv')\n",
"test_predicted_path = os.path.join(data_dir_path, 'test-A', 'out.tsv')\n",
"word2vec_file_path = 'word2vec_100_3_polish.bin'"
]
},
{
"cell_type": "markdown",
"id": "e4ea0458",
"metadata": {},
"source": [
"### Load files"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4e038df7",
"metadata": {},
"outputs": [],
"source": [
"# Headerless TSV files: only the leading column(s) named below are read.\n",
"train_data = pd.read_csv(train_path, sep='\\t', header=None, usecols=[0, 1], names=['label', 'text'])\n",
"dev_texts_data = pd.read_csv(dev_texts_path, sep='\\t', header=None, usecols=[0], names=['text'])\n",
"dev_labels_data = pd.read_csv(dev_labels_path, sep='\\t', header=None, usecols=[0], names=['label'])\n",
"test_texts_data = pd.read_csv(test_texts_path, sep='\\t', header=None, usecols=[0], names=['text'])"
]
},
{
"cell_type": "markdown",
"id": "80bcbe49",
"metadata": {},
"source": [
"### Load word2vec"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3d2e114b",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): KeyedVectors.load reads gensim's own saved format; if this .bin file\n",
"# were in the original word2vec C binary format, KeyedVectors.load_word2vec_format(...,\n",
"# binary=True) would be needed instead -- confirm where the file came from.\n",
"word2vec = KeyedVectors.load(word2vec_file_path)"
]
},
{
"cell_type": "markdown",
"id": "4ed6fe85",
"metadata": {},
"source": [
"### Preprocess data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "149c6b1f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\numpy\\core\\fromnumeric.py:3464: RuntimeWarning: Mean of empty slice.\n",
" return _methods._mean(a, axis=axis, dtype=dtype,\n"
]
}
],
"source": [
"def text_to_word2vec(text):\n",
"    \"\"\"Return the mean word2vec embedding of the tokens in `text`.\n",
"\n",
"    Tokens missing from the word2vec vocabulary are skipped. If no token is\n",
"    known (including an empty token list) a zero vector is returned directly,\n",
"    instead of averaging an empty list, which makes np.mean emit a\n",
"    RuntimeWarning and return NaN.\n",
"\n",
"    Args:\n",
"        text: list of tokens, e.g. the output of gensim's simple_preprocess.\n",
"\n",
"    Returns:\n",
"        1-D numpy array of length word2vec.vector_size.\n",
"    \"\"\"\n",
"    known_vectors = [word2vec[word] for word in text if word in word2vec]\n",
"    if not known_vectors:\n",
"        return np.zeros(word2vec.vector_size)\n",
"    return np.mean(known_vectors, axis=0)\n",
"\n",
"def fit_data(column):\n",
"    \"\"\"Stack the per-row values of a pandas Series into one numpy array.\"\"\"\n",
"    return np.array(column.tolist())\n",
"\n",
"def fit_data_X(text_column):\n",
"    \"\"\"Tokenize each text and embed it as its mean word2vec vector.\n",
"\n",
"    Returns a 2-D array with one row per text and word2vec.vector_size columns.\n",
"    \"\"\"\n",
"    tokens = text_column.apply(gensim.utils.simple_preprocess)\n",
"    vectors = tokens.apply(text_to_word2vec)\n",
"    return fit_data(vectors)\n",
"\n",
"train_X = fit_data_X(train_data['text'])\n",
"train_Y = fit_data(train_data['label'])\n",
"dev_X = fit_data_X(dev_texts_data['text'])\n",
"dev_Y = fit_data(dev_labels_data['label'])\n",
"test_X = fit_data_X(test_texts_data['text'])"
]
},
{
"cell_type": "markdown",
"id": "1fa44315",
"metadata": {},
"source": [
"### Create model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1eeecf36",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\backend.py:873: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
"\n"
]
}
],
"source": [
"# Feed-forward classifier over the averaged word2vec features; the last layer\n",
"# is a single sigmoid unit.\n",
"model = Sequential([\n",
"    Dense(128, input_dim=train_X.shape[1], activation='relu'),\n",
"    Dense(64, activation='relu', kernel_regularizer=l2(0.001)),\n",
"    Dense(32, activation='relu'),\n",
"    Dense(1, activation='sigmoid'),\n",
"])"
]
},
{
"cell_type": "markdown",
"id": "c84111a9",
"metadata": {},
"source": [
"### Compile model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a6e56c53",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\optimizers\\__init__.py:309: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
"\n"
]
}
],
"source": [
"model.compile(\n",
"    loss='binary_crossentropy',  # pairs with the single sigmoid output unit\n",
"    optimizer='adam',\n",
"    metrics=['accuracy'],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ec76b0f6",
"metadata": {},
"source": [
"### Train model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e72a055c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\utils\\tf_utils.py:492: The name tf.ragged.RaggedTensorValue is deprecated. Please use tf.compat.v1.ragged.RaggedTensorValue instead.\n",
"\n",
"WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\engine\\base_layer_utils.py:384: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n",
"\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 17s 4ms/step - loss: 0.1955 - accuracy: 0.9319 - val_loss: 0.1569 - val_accuracy: 0.9404\n",
2024-05-20 00:43:18 +02:00
"Epoch 2/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 11s 4ms/step - loss: 0.1393 - accuracy: 0.9471 - val_loss: 0.1337 - val_accuracy: 0.9450\n",
2024-05-20 00:43:18 +02:00
"Epoch 3/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 10s 3ms/step - loss: 0.1264 - accuracy: 0.9523 - val_loss: 0.1410 - val_accuracy: 0.9426\n",
2024-05-20 00:43:18 +02:00
"Epoch 4/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 15s 5ms/step - loss: 0.1189 - accuracy: 0.9543 - val_loss: 0.1231 - val_accuracy: 0.9516\n",
2024-05-20 00:43:18 +02:00
"Epoch 5/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 16s 5ms/step - loss: 0.1133 - accuracy: 0.9570 - val_loss: 0.1206 - val_accuracy: 0.9490\n",
2024-05-20 00:43:18 +02:00
"Epoch 6/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 17s 6ms/step - loss: 0.1076 - accuracy: 0.9588 - val_loss: 0.1220 - val_accuracy: 0.9481\n",
2024-05-20 00:43:18 +02:00
"Epoch 7/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 11s 4ms/step - loss: 0.1039 - accuracy: 0.9605 - val_loss: 0.1125 - val_accuracy: 0.9541\n",
2024-05-20 00:43:18 +02:00
"Epoch 8/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 18s 6ms/step - loss: 0.0997 - accuracy: 0.9620 - val_loss: 0.1123 - val_accuracy: 0.9536\n",
2024-05-20 00:43:18 +02:00
"Epoch 9/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 17s 6ms/step - loss: 0.0964 - accuracy: 0.9639 - val_loss: 0.1092 - val_accuracy: 0.9547\n",
2024-05-20 00:43:18 +02:00
"Epoch 10/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 11s 4ms/step - loss: 0.0936 - accuracy: 0.9645 - val_loss: 0.1120 - val_accuracy: 0.9567\n",
2024-05-20 00:43:18 +02:00
"Epoch 11/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 17s 5ms/step - loss: 0.0906 - accuracy: 0.9656 - val_loss: 0.1170 - val_accuracy: 0.9527\n",
2024-05-20 00:43:18 +02:00
"Epoch 12/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 12s 4ms/step - loss: 0.0882 - accuracy: 0.9670 - val_loss: 0.1171 - val_accuracy: 0.9549\n",
2024-05-20 00:43:18 +02:00
"Epoch 13/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 18s 6ms/step - loss: 0.0854 - accuracy: 0.9681 - val_loss: 0.1120 - val_accuracy: 0.9567\n",
2024-05-20 00:43:18 +02:00
"Epoch 14/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 13s 4ms/step - loss: 0.0830 - accuracy: 0.9688 - val_loss: 0.1171 - val_accuracy: 0.9562\n",
2024-05-20 00:43:18 +02:00
"Epoch 15/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 12s 4ms/step - loss: 0.0810 - accuracy: 0.9695 - val_loss: 0.1226 - val_accuracy: 0.9510\n",
2024-05-20 00:43:18 +02:00
"Epoch 16/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 16s 5ms/step - loss: 0.0791 - accuracy: 0.9704 - val_loss: 0.1167 - val_accuracy: 0.9567\n",
2024-05-20 00:43:18 +02:00
"Epoch 17/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 14s 4ms/step - loss: 0.0776 - accuracy: 0.9709 - val_loss: 0.1264 - val_accuracy: 0.9532\n",
2024-05-20 00:43:18 +02:00
"Epoch 18/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 14s 4ms/step - loss: 0.0755 - accuracy: 0.9714 - val_loss: 0.1191 - val_accuracy: 0.9519\n",
2024-05-20 00:43:18 +02:00
"Epoch 19/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 13s 4ms/step - loss: 0.0742 - accuracy: 0.9722 - val_loss: 0.1190 - val_accuracy: 0.9545\n",
2024-05-20 00:43:18 +02:00
"Epoch 20/20\n",
2024-05-20 03:14:08 +02:00
"3067/3067 [==============================] - 13s 4ms/step - loss: 0.0725 - accuracy: 0.9732 - val_loss: 0.1295 - val_accuracy: 0.9552\n"
2024-05-20 00:43:18 +02:00
]
}
],
"source": [
"# dev-0 doubles as the validation set, so the val_* metrics are dev-set scores.\n",
"history = model.fit(\n",
"    train_X,\n",
"    train_Y,\n",
"    validation_data=(dev_X, dev_Y),\n",
"    epochs=20,\n",
"    batch_size=32,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "561f4db3",
"metadata": {},
"outputs": [
{
"data": {
2024-05-20 03:14:08 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHFCAYAAAAaD0bAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAACLDUlEQVR4nOzdd3QU5dfA8e+mF9JIpyShh94JvUjvCAoo0kFRQRB9VfSHig0UARuglFCVJkUUpBfpobdA6CSEhJBAOmm78/4xZGFJIYEkm3I/5+xhM/vMzJ3Mhr37VI2iKApCCCGEECWIibEDEEIIIYQoaJIACSGEEKLEkQRICCGEECWOJEBCCCGEKHEkARJCCCFEiSMJkBBCCCFKHEmAhBBCCFHiSAIkhBBCiBJHEiAhhBBClDiSAAmRhcWLF6PRaNBoNOzZsyfD64qiULlyZTQaDW3bts3Tc2s0Gj7//PNc73fjxg00Gg2LFy/O8T5nz55Fo9Fgbm5OWFhYrs9Z0iUkJDBt2jTq169PqVKlsLW1pV69enzzzTckJCQYO7wMhg0bpn9fZ/YwtvS/u2PHjhk7FFHMmRk7ACEKOzs7OxYuXJghydm7dy9Xr17Fzs7OOIHlkQULFgCQlpbG0qVL+fDDD40cUdFx584dOnTowNWrV3nnnXf47rvvANi1axdfffUVK1asYMeOHbi7uxs5UkPW1tbs2rXL2GEIYVSSAAnxFAMGDOD3339n9uzZ2Nvb67cvXLiQZs2aERsba8Tonk9ycjK///47devWJTIyEn9//0KbAD148AArK6tCUUuRbsiQIVy8eJHdu3fTsmVL/faOHTvSvXt32rVrx9ChQ9myZUuBxvXgwQOsra2zfN3ExISmTZsWYERCFD7SBCbEU7zyyisArFixQr8tJiaGtWvXMmLEiEz3uXfvHm+99RZly5bFwsKCihUr8sknn5CcnGxQLjY2ltGjR+Ps7EypUqXo0qULly5dyvSYly9f5tVXX8XNzQ1LS0uqV6/O7Nmzn+vaNmzYQFRUFKNGjWLo0KFcunSJ/fv3ZyiXnJzMF198QfXq1bGyssLZ2Zl27dpx8OBBfRmdTsfPP/9MvXr1sLa2xtHRkaZNm7Jx40Z9maya9nx8fBg2bJj+5/RmkG3btjFixAhcXV2xsbEhOTmZK1euMHz4cKpUqYKNjQ1ly5alZ8+enD17NsNxo6Ojee+996hYsSKWlpa4ubnRrVs3Ll68iKIoVKlShc6dO2fYLz4+HgcHB95+++0sf3fHjh1j27ZtjBw50iD5SdeyZUtGjBjB1q1bOX78OAD169enVatWGcpqtVrKli1L37599dtSUlL46quv8PX1xdLSEldXV4YPH87du3cz/O569OjBunXrqF+/PlZWVkyZMiXLuHNqz549aDQali9fzsSJE/Hw8MDa2po2bdpw8uTJDOU3btxIs2bNsLGxwc7Ojo4dO3Lo0KEM5S5evMgrr7yCu7s7lpaWeHl5MWTIkAx/G3Fxcbz55pu4uLjg7OxM3759uX37tkGZXbt20bZtW5ydnbG2tsbLy4t+/fqRmJj43Ncvij9JgIR4Cnt7e1566SX8/f3121asWIGJiQkDBgzIUD4pKYl27dqxdOlSJk6cyKZNm3jttdf47rvvDD7gFEWhT58+LFu2jPfee4/169fTtGlTunbtmuGYgYGBNG7cmHPnzjFjxgz++ecfunfvzjvvvPNcH3YLFy7E0tKSQYMGMWLECDQaDQsXLjQok5aWRteuXfnyyy/p0aMH69evZ/HixTRv3pzg4GB9uWHDhjF+/HgaN27MqlWrWLlyJb169eLGjRvPHN+IESMwNzdn2bJl/Pnnn5ibm3P79m2cnZ2ZNm0aW7ZsYfbs2ZiZmeHn50dQUJB+37i4OFq2bMlvv/3G8OHD+fvvv/n111+pWrUqYWFhaDQaxo0bx/bt27l8+bLBeZcuXUpsbGy2CdD27dsB6NOnT5Zl0l
9LLzt8+HD279+f4Xzbtm3j9u3bDB8+HFCTyd69ezNt2jReffVVNm3axLRp09i+fTtt27blwYMHBvufOHGC//u//+Odd95hy5Yt9OvXL/tfLOp9ffKh0+kylPv444+5du0aCxYsYMGCBdy+fZu2bdty7do1fZk//viD3r17Y29vz4oVK1i4cCH379+nbdu2Bgn16dOnady4MYcPH+aLL77g33//ZerUqSQnJ5OSkmJw3lGjRmFubs4ff/zBd999x549e3jttdf0r9+4cYPu3btjYWGBv78/W7ZsYdq0adja2mY4lhCZUoQQmVq0aJECKEePHlV2796tAMq5c+cURVGUxo0bK8OGDVMURVFq1qyptGnTRr/fr7/+qgDK6tWrDY737bffKoCybds2RVEU5d9//1UA5ccffzQo9/XXXyuA8tlnn+m3de7cWSlXrpwSExNjUHbs2LGKlZWVcu/ePUVRFOX69esKoCxatOip13fjxg3FxMREGThwoH5bmzZtFFtbWyU2Nla/benSpQqgzJ8/P8tj/ffffwqgfPLJJ9me88nrSuft7a0MHTpU/3P6737IkCFPvY60tDQlJSVFqVKlivLuu+/qt3/xxRcKoGzfvj3LfWNjYxU7Oztl/PjxBttr1KihtGvXLtvzjhkzRgGUixcvZlnmwoULCqC8+eabiqIoSmRkpGJhYaF8/PHHBuX69++vuLu7K6mpqYqiKMqKFSsUQFm7dq1BuaNHjyqAMmfOHP02b29vxdTUVAkKCso23nRDhw5VgEwf7du315dLf883aNBA0el0+u03btxQzM3NlVGjRimKoiharVYpU6aMUrt2bUWr1erLxcXFKW5ubkrz5s3121544QXF0dFRiYiIyDK+9Hv/1ltvGWz/7rvvFEAJCwtTFEVR/vzzTwVQTp06laPrFuJJUgMkRA60adOGSpUq4e/vz9mzZzl69GiWzV+7du3C1taWl156yWB7ehPPzp07Adi9ezcAgwYNMij36quvGvyclJTEzp07efHFF7GxsTH4xt6tWzeSkpI4fPhwrq9p0aJF6HQ6g+sYMWIECQkJrFq1Sr/t33//xcrKKsvrTS8DZFtj8iwyq8lIS0vjm2++oUaNGlhYWGBmZoaFhQWXL1/mwoULBjFVrVqVDh06ZHl8Ozs7hg8fzuLFi/Ujtnbt2kVgYCBjx4597vgVRQHQ91tydnamZ8+eLFmyRF/bcv/+ff766y+GDBmCmZnaLfOff/7B0dGRnj17GtzvevXq4eHhkWFUYp06dahatWqO47K2tubo0aMZHnPmzMlQ9tVXXzXod+Xt7U3z5s3179+goCBu377N4MGDMTF59JFSqlQp+vXrx+HDh0lMTCQxMZG9e/fSv39/XF1dnxpjr169MlwjwM2bNwGoV68eFhYWvP766yxZssSgRkqInJAESIgc0Gg0DB8+nOXLl+ubUTLrywEQFRWFh4dHhs66bm5umJmZERUVpS9nZmaGs7OzQTkPD48Mx0tLS+Pnn3/G3Nzc4NGtWzcAIiMjc3U9Op2OxYsXU6ZMGRo2bEh0dDTR0dF06NABW1tbg2awu3fvUqZMGYMPtyfdvXsXU1PTDLE/L09PzwzbJk6cyOTJk+nTpw9///03R44c4ejRo9StW9egaeju3buUK1fuqecYN24ccXFx/P777wD88ssvlCtXjt69e2e7n5eXFwDXr1/Pskx681/58uX120aMGEFoaKi+WWzFihUkJycb9IG6c+cO0dHRWFhYZLjn4eHhGe53Zr+n7JiYmNCoUaMMj8ySqMzuqYeHh8H7OKsYypQpg06n4/79+9y/fx+tVpujewJk+LuwtLQE0N/jSpUqsWPHDtzc3Hj77bepVKkSlSpV4scff8zR8YWQUWBC5NCwYcP49NNP+fXXX/n666+zLOfs7MyRI0dQFMUgCYqIiCAtLQ0XFxd9ubS0NKKiogz+sw8PDzc4npOTE6ampgwePDjLGpYKFSrk6lp27N
ih/yb95AcNwOHDhwkMDKRGjRq4urqyf/9+dDpdlkmQq6srWq2W8PDwbD+MLS0tM3R2hUcfok/KbMTX8uXLGTJkCN9
2024-05-20 00:43:18 +02:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train vs. validation accuracy per epoch, from the Keras History object.\n",
"fig, ax = plt.subplots()\n",
"for metric_key, curve_label in [('accuracy', 'Train Accuracy'), ('val_accuracy', 'Validation Accuracy')]:\n",
"    ax.plot(history.history[metric_key], label=curve_label)\n",
"ax.set_xlabel('Epoch')\n",
"ax.set_ylabel('Accuracy')\n",
"ax.set_title('Model Accuracy Over Epochs')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "4d0b9315",
"metadata": {},
"source": [
"### Predict and save results"
]
},
{
"cell_type": "code",
2024-05-20 03:14:08 +02:00
"execution_count": 26,
2024-05-20 00:43:18 +02:00
"id": "54f93a9b",
"metadata": {},
"outputs": [],
"source": [
"def predict_and_save(X, filename, threshold=0.5):\n",
"    \"\"\"Predict 0/1 labels for X with `model`, save them to `filename`, and return them.\n",
"\n",
"    A sigmoid score above `threshold` becomes label 1, otherwise 0; with the\n",
"    default 0.5 this is the same as rounding the score (np.round maps 0.5 to 0).\n",
"    The labels are written one per line, tab-separated, with no header and no index.\n",
"\n",
"    Returns:\n",
"        1-D int numpy array with one predicted label per row of X.\n",
"    \"\"\"\n",
"    scores = model.predict(X).ravel()  # (n, 1) sigmoid output -> (n,)\n",
"    Y_predicted = (scores > threshold).astype(int)\n",
"    Y_predicted_df = pd.DataFrame(Y_predicted, columns=['predicted_label'])\n",
"    Y_predicted_df.to_csv(filename, sep='\\t', index=False, header=None)\n",
"    return Y_predicted"
]
},
{
"cell_type": "code",
2024-05-20 03:14:08 +02:00
"execution_count": 27,
2024-05-20 00:43:18 +02:00
"id": "9d3b3867",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-20 03:14:08 +02:00
"171/171 [==============================] - 0s 3ms/step\n",
"171/171 [==============================] - 1s 3ms/step\n"
2024-05-20 00:43:18 +02:00
]
}
],
"source": [
"# Write predictions for each split to its out.tsv file.\n",
"dev_predicted = predict_and_save(X=dev_X, filename=dev_predicted_path)\n",
"test_predicted = predict_and_save(X=test_X, filename=test_predicted_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}