dl_projekt/lstm.ipynb

837 lines
97 KiB
Plaintext
Raw Normal View History

2024-06-03 16:32:11 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# LSTM"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"train = pd.read_csv(\"train.csv\")\n",
"test = pd.read_csv(\"test.csv\")\n",
"valid = pd.read_csv(\"valid.csv\")\n",
"\n",
"train.loc[train[\"review_score\"]==-1, \"review_score\"]=0\n",
"test.loc[test[\"review_score\"]==-1, \"review_score\"]=0\n",
"valid.loc[valid[\"review_score\"]==-1, \"review_score\"]=0"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Sprawdzanie długości najdłuższej recenzji (teoretycznie Steam zezwala na max 8000 znaków)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"train[\"seq_length\"] = train[\"review_text\"].apply(lambda x : len(x.split()))"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 5,
2024-06-03 16:32:11 +02:00
"outputs": [
{
2024-06-03 17:09:59 +02:00
"data": {
"text/plain": "count 43230.000000\nmean 74.154962\nstd 127.088261\nmin 0.000000\n25% 12.000000\n50% 31.000000\n75% 80.000000\nmax 1570.000000\nName: seq_length, dtype: float64"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train[\"seq_length\"].describe()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Niektóre recenzje są bardzo długie ale większość jest poniżej 100 słów. W celu przyspieszenia treningu usunę z zestawu treningowego te przykłady, które są dłuższe.\n",
"\n",
"*Notka: najpierw próbowałem wytrenować model na sekwencjach długości 1600 tokenów (większych niż najdłuższa recenzja). Model się bardzo długo i bardzo źle trenował.*"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"#train.drop(train[\"seq_length\"]>200, inplace=True)\n",
"train.drop(train[train.seq_length > 200].index, inplace=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "count 39571.000000\nmean 44.135124\nstd 44.780534\nmin 0.000000\n25% 11.000000\n50% 27.000000\n75% 62.000000\nmax 200.000000\nName: seq_length, dtype: float64"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
2024-06-03 16:32:11 +02:00
}
],
"source": [
2024-06-03 17:09:59 +02:00
"train[\"seq_length\"].describe()"
2024-06-03 16:32:11 +02:00
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 9,
2024-06-03 16:32:11 +02:00
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
2024-06-03 17:09:59 +02:00
"SEQ_PADDED_LENGTH = 200\n",
2024-06-03 16:32:11 +02:00
"VOCABULARY_SIZE = 4000\n",
"vectorizer = tf.keras.layers.TextVectorization(output_sequence_length=SEQ_PADDED_LENGTH, max_tokens=VOCABULARY_SIZE)\n",
"vectorizer.adapt(train[\"review_text\"])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 10,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
"text/plain": "4000"
},
2024-06-03 17:09:59 +02:00
"execution_count": 10,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vectorizer.get_vocabulary())"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 11,
2024-06-03 16:32:11 +02:00
"outputs": [],
"source": [
"train[\"vectorized\"] = train[\"review_text\"].apply(vectorizer)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 12,
2024-06-03 16:32:11 +02:00
"outputs": [],
"source": [
"test[\"vectorized\"] = test[\"review_text\"].apply(vectorizer)\n",
"valid[\"vectorized\"] = valid[\"review_text\"].apply(vectorizer)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 13,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-06-03 17:09:59 +02:00
"Model: \"model\"\n",
2024-06-03 16:32:11 +02:00
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
2024-06-03 17:09:59 +02:00
" input_1 (InputLayer) [(None, 200)] 0 \n",
" \n",
" embedding (Embedding) (None, 200, 128) 512128 \n",
" \n",
" bidirectional (Bidirectiona (None, 200, 128) 98816 \n",
" l) \n",
2024-06-03 16:32:11 +02:00
" \n",
2024-06-03 17:09:59 +02:00
" dropout (Dropout) (None, 200, 128) 0 \n",
2024-06-03 16:32:11 +02:00
" \n",
2024-06-03 17:09:59 +02:00
" bidirectional_1 (Bidirectio (None, 128) 98816 \n",
" nal) \n",
2024-06-03 16:32:11 +02:00
" \n",
2024-06-03 17:09:59 +02:00
" dense (Dense) (None, 1) 129 \n",
2024-06-03 16:32:11 +02:00
" \n",
"=================================================================\n",
2024-06-03 17:09:59 +02:00
"Total params: 709,889\n",
"Trainable params: 709,889\n",
2024-06-03 16:32:11 +02:00
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"from keras.optimizers import Adam\n",
"import keras.layers as layers\n",
"import keras\n",
"\n",
"\n",
"def create_model():\n",
" input_layer = layers.Input(shape=(SEQ_PADDED_LENGTH,))\n",
2024-06-03 17:09:59 +02:00
" embedding_layer = layers.Embedding(input_dim=VOCABULARY_SIZE+1, output_dim=128, input_length=SEQ_PADDED_LENGTH)(input_layer)\n",
" #lstm_layer = layers.LSTM(64)(embedding_layer)\n",
" lstm_layer = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(embedding_layer)\n",
" dropout_layer = layers.Dropout(0.5)(lstm_layer)\n",
" lstm_layer_2 = layers.Bidirectional(layers.LSTM(64))(dropout_layer)\n",
" output_layer = layers.Dense(1,activation=\"sigmoid\")(lstm_layer_2)\n",
2024-06-03 16:32:11 +02:00
" model = keras.Model(inputs=input_layer, outputs=output_layer)\n",
" model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-3), metrics=['accuracy'])\n",
" return model\n",
"model = create_model()\n",
"model.summary()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 14,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "TensorShape([200])"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 14,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.iloc[120][\"vectorized\"].shape"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 15,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "[200]"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 15,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.iloc[120][\"vectorized\"].get_shape().as_list()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 16,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "<tf.Tensor: shape=(200,), dtype=int64, numpy=\narray([ 225, 1120, 2, 113, 1, 1816, 3, 108, 97, 1417, 23,\n 12, 52, 19, 257, 10, 3, 52, 34, 8, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0], dtype=int64)>"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 16,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.iloc[120][\"vectorized\"]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Część recenzji nie zawierała tekstu więc po usunięciu interpunkcji i znaków specjalnych były puste, teksty te trzeba usunąć z materiału treningowego"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 17,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "shapes\n200 39452\n0 119\nName: count, dtype: int64"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 17,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train[\"shapes\"] = train[\"vectorized\"].apply(lambda x : x.get_shape().as_list()[0])\n",
"train[\"shapes\"].value_counts()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 18,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "shapes\n200 39452\nName: count, dtype: int64"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 18,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.drop(train[train[\"vectorized\"].map(lambda x : x.get_shape().as_list()[0])!=SEQ_PADDED_LENGTH].index, inplace=True)\n",
"train[\"shapes\"].value_counts()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 19,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
"text/plain": " Unnamed: 0 review_text review_score vectorized\n42 4552590 !!! 1 ()\n124 5286261 . 1 ()\n259 4934066 ........ 1 ()\n468 5584357 . 1 ()\n717 2172088 =] 1 ()",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>Unnamed: 0</th>\n <th>review_text</th>\n <th>review_score</th>\n <th>vectorized</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>42</th>\n <td>4552590</td>\n <td>!!!</td>\n <td>1</td>\n <td>()</td>\n </tr>\n <tr>\n <th>124</th>\n <td>5286261</td>\n <td>.</td>\n <td>1</td>\n <td>()</td>\n </tr>\n <tr>\n <th>259</th>\n <td>4934066</td>\n <td>........</td>\n <td>1</td>\n <td>()</td>\n </tr>\n <tr>\n <th>468</th>\n <td>5584357</td>\n <td>.</td>\n <td>1</td>\n <td>()</td>\n </tr>\n <tr>\n <th>717</th>\n <td>2172088</td>\n <td>=]</td>\n <td>1</td>\n <td>()</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
2024-06-03 17:09:59 +02:00
"execution_count": 19,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#valid.drop(valid[valid[\"vectorized\"].map(lambda x : x.get_shape().as_list()[0])!=1600].index, inplace=True)\n",
"\n",
"empty_valid = valid[valid[\"vectorized\"].map(lambda x : x.get_shape().as_list()[0])==0]\n",
"empty_valid.head()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"\"0\" to maskowane pozycje, puste dane w zbiorze testowym można nimi uzupełnić"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 20,
2024-06-03 16:32:11 +02:00
"outputs": [],
"source": [
"#test.loc[test[\"vectorized\"].map(lambda x : x.get_shape().as_list()[0])!=SEQ_PADDED_LENGTH,\"vectorized\"] = tf.zeros((SEQ_PADDED_LENGTH,), dtype=tf.dtypes.int64)\n",
"#valid.loc[valid[\"vectorized\"].map(lambda x : x.get_shape().as_list()[0])!=SEQ_PADDED_LENGTH,\"vectorized\"] = tf.zeros((SEQ_PADDED_LENGTH,), dtype=tf.dtypes.int64)\n",
"#empty_valid[\"vectorized\"] = tf.zeros((len(empty_valid.index),1600), dtype=tf.dtypes.int64)\n",
"#empty_test[\"vectorized\"] = tf.zeros((len(empty_test.index),1600), dtype=tf.dtypes.int64)\n",
"\n",
"#empty_valid[\"vectorized\"].iloc[0]\n",
"\n",
"def vector_fix(x):\n",
" if x.get_shape().as_list()[0]==SEQ_PADDED_LENGTH:\n",
" return x\n",
2024-06-03 17:09:59 +02:00
" return tf.zeros((SEQ_PADDED_LENGTH,), dtype=tf.dtypes.int64)\n",
2024-06-03 16:32:11 +02:00
"\n",
"test[\"vectorized\"] = test[\"vectorized\"].apply(vector_fix)\n",
"valid[\"vectorized\"] = valid[\"vectorized\"].apply(vector_fix)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 21,
2024-06-03 16:32:11 +02:00
"outputs": [],
"source": [
"#train[\"vectorized\"] = train[\"vectorized\"].apply(lambda x : x.numpy())\n",
"#valid[\"vectorized\"] = valid[\"vectorized\"].apply(lambda x : x.numpy())\n",
"#test[\"vectorized\"] = test[\"vectorized\"].apply(lambda x : x.numpy())"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 22,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"data": {
2024-06-03 17:09:59 +02:00
"text/plain": "<tf.Tensor: shape=(200,), dtype=int64, numpy=\narray([ 41, 50, 1864, 20, 2, 201, 3, 90, 27, 98, 47,\n 4, 243, 50, 381, 184, 7, 139, 408, 71, 10, 5,\n 120, 14, 2, 688, 2, 3, 9, 48, 1, 30, 85,\n 31, 7, 314, 87, 12, 577, 6, 494, 10, 3, 63,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0], dtype=int64)>"
2024-06-03 16:32:11 +02:00
},
2024-06-03 17:09:59 +02:00
"execution_count": 22,
2024-06-03 16:32:11 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.iloc[0][\"vectorized\"]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-03 17:09:59 +02:00
"execution_count": 23,
2024-06-03 16:32:11 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
2024-06-03 17:09:59 +02:00
"1233/1233 [==============================] - 288s 230ms/step - loss: 0.4453 - accuracy: 0.7923 - val_loss: 0.3532 - val_accuracy: 0.8514\n",
2024-06-03 16:32:11 +02:00
"Epoch 2/3\n",
2024-06-03 17:09:59 +02:00
"1233/1233 [==============================] - 289s 235ms/step - loss: 0.3145 - accuracy: 0.8669 - val_loss: 0.3272 - val_accuracy: 0.8519\n",
"Epoch 3/3\n",
"1233/1233 [==============================] - 289s 234ms/step - loss: 0.2684 - accuracy: 0.8875 - val_loss: 0.3216 - val_accuracy: 0.8635\n"
2024-06-03 16:32:11 +02:00
]
}
],
"source": [
"#train_y = np.stack(train[\"review_score\"].values)\n",
"train_y = np.stack(train[\"review_score\"].values)\n",
"valid_y = np.stack(valid[\"review_score\"].values)\n",
"\n",
"test_y = np.stack(test[\"review_score\"].values)\n",
"\n",
"###\n",
"#train_x = np.stack(train[\"vectorized\"].values)\n",
"train_x = np.stack(train[\"vectorized\"].values)\n",
"\n",
"test_x = np.stack(test[\"vectorized\"].values)\n",
"valid_x = np.stack(valid[\"vectorized\"].values)\n",
"\n",
"\n",
"#callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3, restore_best_weights=True)\n",
"history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=3)"
],
"metadata": {
"collapsed": false
}
},
2024-06-03 17:09:59 +02:00
{
"cell_type": "code",
"execution_count": 24,
"outputs": [],
"source": [
"model.save(\"lstm_model.keras\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [
{
"data": {
"text/plain": "<matplotlib.legend.Legend at 0x1bf88e819a0>"
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAHHCAYAAABEEKc/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB420lEQVR4nO3deVhUdfsG8HtmYIZ933EUUHMXlC3NNTG3LE0Tzdy3Fu1V8lfa4lqh1Vu2mFq5ZYtLqWX6Uoa7IQKKoikpooCyiuwywMz5/TE6OiwKyHBY7s91zZWc8z1nni+DcneW50gEQRBARERERDpSsQsgIiIiamgYkIiIiIjKYUAiIiIiKocBiYiIiKgcBiQiIiKichiQiIiIiMphQCIiIiIqhwGJiIiIqBwGJCIS3bVr17BkyRLExcVVOUatVuPDDz/Er7/+Wo+VEVFzxYBERKIqKyvD2LFjcebMGXTq1KnKce+88w7WrFmDxx9//KH7TE9Px+jRo2Fvbw+JRIJVq1bVYcX3bNq0CRKJBNHR0XW2r6tXr+qW9evXD/369XvkfRNRzTEgETUi27dvh0Qiwa5duyqs8/b2hkQiwcGDByusa9myJXr27FlndXzwwQfYvXt3nezrnXfeAQD8+OOPkEor/ydp3759+PbbbxEWFgZnZ+eH7nPevHn4448/sHDhQmzZsgWDBw+uk1obqxs3bmDJkiWIjY01yP5//PFHg4VQIrEwIBE1Ir169QIAHDt2TG95Xl4ezp07ByMjIxw/flxvXXJyMpKTk3Xb1oW6CkgFBQUwNzfHnj17YGpqWuW4hIQE7N27F+3atavWfg8cOIBnn30W8+fPx4svvoj27ds/cq2GNmHCBNy+fRutWrXSLfvzzz/x559/PvK+b9y4gaVLlzIgEdWAkdgFEFH1ubm5wdPTs0JAioiIgCAIeP755yusu/v1owYkQRBQXFz8wCBTUxYWFnj33XcfOm7OnDk12m9GRgZsbGxqWZU4ZDIZZDKZ3jK5XC5KLUVFRTAzMxPlvYkaCh5BImpkevXqhdOnT+P27du6ZcePH0enTp0wZMgQnDhxAhqNRm+dRCLBE088AQDYuHEjnnzySTg5OUGhUKBjx45Ys2ZNhffx8PDA008/jT/++AN+fn4wNTXFunXrIJFIUFhYiM2bN0MikUAikWDy5Mm67U6fPo0hQ4bAysoKFhYWGDBgAE6cOKG379LSUixduhRt27aFiYkJ7O3t0atXL+zfv19v3MWLFzFmzBg4OjrC1NQU7dq1w9tvv13l9+budTyCIGD16tW6+gBgyZIluj9Xts391/7cnfuxY8cQEBAAExMTeHl54bvvvqvyve+6desWAgIC0KJFC8THx1d7Lo9yDdL+/fvRq1cv2NjYwMLCAu3atcNbb70FADh06BD8/f0BAFOmTNF9TzZt2qR7j86dOyMmJgZ9+vSBmZmZbttff/0Vw4YNg5ubGxQKBVq3bo3ly5dDrVbr1bh3715cu3ZNt28PDw/d0cH//Oc/FepNSUmBTCZDaGjoQ+dGJBYeQSJqZHr16oUtW7YgMjJS98vz+PHj6NmzJ3r27Inc3FycO3cOXbt21a1r37497O3tAQBr1qxBp06d8Mwzz8DIyAh79uzBK6+8Ao1Gg1dffVXvveLj4zFu3DjMmjULM2bMQLt27bBlyxZMnz4dAQEBmDlzJgCgdevWAIDz58+jd+/esLKywhtvvAFjY2OsW7cO/fr1w+HDhxEYGAhAG1ZCQ0N1+8nLy0N0dDROnTqFgQMHAgDOnj2L3r17w9jYGDNnzoSHhwcSEhKwZ88evP/++5V+b/r06YMtW7ZgwoQJGDhwICZOnFjr7/Ply5cxevRoTJs2DZMmTcKGDRswefJk+Pr6VnkxeVZWFgYOHIjs7GwcPnxY932pzVyq6/z583j66afRtWtXLFu2DAqFApcvX9adau3QoQOWLVuGRYsWYebMmejduzcA6F2TdvPmTQwZMgRjx47Fiy++qLvOa9OmTbCwsEBISAgsLCxw4MABLFq0CHl5efjoo48AAG+//TZyc3ORkpKCTz/9FID2yKCFhQVGjhyJbdu24ZNPPtE7OvbTTz9BEASMHz/+keZOZFACETUq58+fFwAIy5cvFwRBEEpLSwVzc3Nh8+bNgiAIgrOzs7B69WpBEAQhLy9PkMlkwowZM3TbFxUVVdjnoEGDBC8vL71lrVq1EgAIYWFhFcabm5sLkyZNqrB8xIgRglwuFxISEnTLbty4IVhaWgp9+vTRLfP29haGDRv2wHn26dNHsLS0FK5du6a3XKPRPHA7QRAEAMKrr76qt2zx4sVCZf/kbdy4UQAgJCYm6pbdnfuRI0d0yzIyMgSFQiG8/vrrFbaNiooSUlNThU6dOgleXl7C1atXazyXyuro27ev0Ldv3wfO9dNPPxUACJmZmVWOiYqKEgAIGzdurLCub9++AgBh7dq1FdZV9rMya9YswczMTCguLtYtGzZsmNCqVasKY//44w8BgPC///1Pb3nXrl0fOi8isfEUG1Ej06FDB9jb2+uuLTpz5gwKCwt1RwR69uypO3oQEREBtVqtd/3R/dcQ5ebmIisrC3379sWVK1eQm5ur916enp4YNGhQtepSq9X4888/MWLECHh5eemWu7q64oUXXsCxY8eQl5cHALCxscH58+dx6dKlSveVmZmJI0eOYOrUqWjZsqXeuspOkxlCx44ddUdbAMDR0RHt2rXDlStXKoxNSUlB3759UVpaiiNHjuhdaG3oudy91urXX3/VO7VaEwqFAlOmTKmw/P6flfz8fGRlZaF3794oKirCxYsXH7rfoKAguLm54YcfftAtO3fuHM6ePYsXX3yxVrUS1RcGJKJGRiKRoGfPnrprjY4fPw4nJye0adMGgH5Auvvf+wPS8ePHERQUBHNzc9jY2MDR0VF3zUllAam6MjMzUVRUVOmdZh06dIBGo0FycjIAYNmyZcjJycFjjz2GLl264P/+7/9w9uxZ3fi7IaRz587Vfv+6Vj7MAICtrS1u3bpVYfmECROQkZGBw4cPw93dXW+doecSHByMJ554AtOnT4ezszPGjh2L7du31ygsubu7V3pB+Pnz5zFy5EhYW1vDysoKjo6OumBT/melMlKpFOPHj8fu3btRVFQEAPjhhx9gYmKC559/vtr1EYmBAYmoEerVqxdyc3MRFxenu/7orp49e+LatWu4fv06jh07Bjc3N90RnYSEBAwYMABZWVn45JNPsHfvXuzfvx/z5s0DgAq/VOvyjrX79enTBwkJCdiwYQM6d+6Mb7/9Ft27d8e3335rkPcDqj5ac/8Fx/crf0fZXYIgVFj23HPPIScnB5999lntC6wlU1NTHDlyBH/99RcmTJiAs2fPIjg4GAMHDqxybpXto7ycnBz07dsXZ86cwbJly7Bnzx7s378fK1euBFDxZ6UqEydOREFBAXbv3g1BEPDjjz/i6aefhrW1dfUnSSQCXqRN1Ajd3w/p+PHjmDt3rm6dr68vFAoFDh06hMjISAwdOlS3bs+ePVCpVPjtt9/0jpBU1lzyQSoLG46OjjAzM9O7c+uuixcvQiqVQqlU6pbZ2dlhypQpmDJlCgoKCtCnTx8sWbIE06dP1wW6c+fO1aiuB7G1tQWg/cV/fwuAa9euPfK+58yZgzZt2mDRokWwtrbGggULdOsMMZfypFIpBgwYgAEDBuCTTz7BBx98gLfffhsHDx5EUFBQrU7lHTp0CDdv3sTOnTvRp08f3fLExMQKYx+0/86dO6Nbt2744Ycf0KJFCyQlJeGLL76ocT1E9Y1HkIgaIT8/P5iYmOCHH37A9evX9Y4gKRQKdO/eHatXr0ZhYaHe6bW7R0XuPwqSm5uLjRs31uj9zc3NkZOTo7dMJpPhqaeewq+//qp3q3p6ejp+/PFH9OrVC1ZWVgC0d03dz8LCAm3atIFKpQKgDVt9+vTBhg0bkJSUpDe2siM41XH3jrIjR47olt1tV1AX3n33XcyfPx8LFy7Ua5tgiLncLzs7u8IyHx8fANB9P83NzQGgwmf2IJX9rJSUlOCrr76qMNb
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.plot(history.history['loss'])\n",
"plt.plot(history.history['val_loss'])\n",
"plt.title('Wartość funkcji straty')\n",
"plt.ylabel('Strata')\n",
"plt.xlabel('Epoka')\n",
"plt.legend(['train', 'test'], loc='upper left')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [
{
"data": {
"text/plain": "<matplotlib.legend.Legend at 0x1bf8b19f490>"
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHHCAYAAABXx+fLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAABblElEQVR4nO3deVhUZf8G8HvYQTaRZRARcEMFXFESNa1Q0OKnLW6VgqZWaplmiSaaWdJiRqlp9bpk7rlkb24pir4uoeKSKKggihubyr7PPL8/gIERVBgHZmDuz3VxJec8c+b7MOjcnTnf80iEEAJEREREOkRP0wUQERER1TcGICIiItI5DEBERESkcxiAiIiISOcwABEREZHOYQAiIiIincMARERERDqHAYiIiIh0DgMQERER6RwGICKqV9evX4dEIsGaNWtq/djIyEhIJBJERkaqvS4i0i0MQERERKRzGICIiIhI5zAAERFpWG5urqZLINI5DEBEOubTTz+FRCLBlStX8Oabb8LKygp2dnYIDQ2FEAI3b97EkCFDYGlpCalUim+//bbKMVJTU/HWW2/BwcEBJiYm6Ny5M3799dcq4zIyMhAcHAwrKytYW1sjKCgIGRkZ1dYVFxeH1157DTY2NjAxMYG3tzf+/PNPleZ448YNTJo0Ce7u7jA1NUWzZs0wbNgwXL9+vdoap02bBldXVxgbG6NFixYYM2YM0tPTFWMKCgrw6aefol27djAxMYGjoyNeeeUVJCQkAHj0tUnVXe8UHBwMc3NzJCQkYPDgwbCwsMAbb7wBAPjf//6HYcOGoWXLljA2NoazszOmTZuG/Pz8an9ew4cPh52dHUxNTeHu7o5PPvkEAHDo0CFIJBLs2LGjyuM2bNgAiUSCEydO1PbHStSoGGi6ACLSjBEjRqBDhw748ssvsWvXLnz++eewsbHBTz/9hOeffx5fffUV1q9fjxkzZqBHjx549tlnAQD5+fno378/4uPjMWXKFLi5ueH3339HcHAwMjIyMHXqVACAEAJDhgzB0aNH8c4776BDhw7YsWMHgoKCqtRy8eJF9O7dG05OTggJCUGTJk2wZcsWDB06FNu2bcPLL79cq7mdOnUKx48fx8iRI9GiRQtcv34dy5cvR//+/XHp0iWYmZkBAHJyctC3b1/ExsZi3Lhx6NatG9LT0/Hnn3/i1q1bsLW1hUwmw0svvYSIiAiMHDkSU6dORXZ2Nvbv34+YmBi0bt261j/7kpIS+Pv7o0+fPli0aJGint9//x15eXl499130axZM5w8eRJLlizBrVu38Pvvvyse/++//6Jv374wNDTExIkT4erqioSEBPz3v//FF198gf79+8PZ2Rnr16+v8rNbv349WrdujV69etW6bqJGRRCRTpk3b54AICZOnKjYVlJSIlq0aCEkEon48ssvFdsfPHggTE1NRVBQkGJbeHi4ACDWrVun2FZUVCR69eolzM3NRVZWlhBCiD/++EMAEF9//bXS8/Tt21cAEKtXr1Zsf+GFF4SXl5coKChQbJPL5cLX11e0bdtWse3QoUMCgDh06NBj55iXl1dl24kTJwQAsXbtWsW2uXPnCgBi+/btVcbL5XIhhBCrVq0SAMTixYsfOeZRdSUmJlaZa1BQkAAgQkJCalR3WFiYkEgk4saNG4ptzz77rLCwsFDaVrkeIYSYNWuWMDY2FhkZGYptqampwsDAQMybN6/K8xDpGn4ERqSjxo8fr/izvr4+vL29IYTAW2+9pdhubW0Nd3d3XLt2TbFt9+7dkEqlGDVqlGKboaEh3n//feTk5ODw4cOKcQYGBnj33XeVnue9995TquP+/fs4ePAghg8fjuzsbKSnpyM9PR337t2Dv78/rl69itu3b9dqbqampoo/FxcX4969e2jTpg2sra1x5swZxb5t27ahc+fO1Z5hkkgkijG2trZV6q48RhWVfy7V1Z2bm4v09HT4+vpCCIGzZ88CANLS0nDkyBGMGzcOLVu2fGQ9Y8aMQWFhIbZu3arYtnnzZpSUlODNN99UuW6ixoIBiEhHPfzmaWVlBRMTE9ja2lbZ/uDBA8X3N27cQNu2baGnp/zPR4cOHRT7y//r6OgIc3NzpXHu7u5K38fHx0MIgdDQUNjZ2Sl9zZs3D0DpNUe1kZ+fj7lz58LZ2RnGxsawtbWFnZ0dMjIykJmZqRiXkJAAT0/Pxx4rISEB7u7uMDBQ3xUDBgYGaNGiRZXtSUlJCA4Oho2NDczNzWFnZ4d+/foBgKLu8jD6pLrbt2+PHj16YP369Ypt69evxzPPPIM2bdqoaypEDRavASLSUfr6+jXaBpRez1NX5HI5AGDGjBnw9/evdkxt37Dfe+89rF69Gh988AF69eoFKysrSCQSjBw5UvF86vSoM0Eymaza7cbGxlUCpEwmw4ABA3D//n3MnDkT7du3R5MmTXD79m0EBwerVPeYMWMwdepU3Lp1C4WFhfjnn3+wdOnSWh+HqDFiACKiWnFxccG///4LuVyu9CYeFxen2F/+34iICOTk5CidBbp8+bLS8Vq1agWg9GM0Pz8/tdS4detWBAUFKXWwFRQUVOlAa926NWJiYh57rNatWyMqKgrFxcUwNDSsdkzTpk0BoMrxy8+G1cSFCxdw5coV/PrrrxgzZoxi+/79+5XGlf+8nlQ3AIwcORLTp0/Hxo0bkZ+fD0NDQ4wYMaLGNRE1ZvwIjIhqZfDgwUhOTsbmzZsV20pKSrBkyRKYm5srPrIZPHgwSkpKsHz5csU4mUyGJUuWKB3P3t4e/fv3x08//YS7d+9Web60tLRa16ivr1/lrNWSJUuqnJF59dVXcf78+Wrbxcsf/+qrryI9Pb3aMyflY1xcXKCvr48jR44o7f/xxx9rVXPlY5b/+fvvv1caZ2dnh2effRarVq1CUlJStfWUs7W1xaBBg7Bu3TqsX78eAQEBVT7iJNJVPANERLUyceJE/PTTTwgODkZ0dDRcXV2xdetWHDt2DOHh4bCwsAAABAYGonfv3ggJCcH169fRsWNHbN++XekanHLLli1Dnz594OXlhQkTJqBVq1ZISUnBiRMncOvWLZw/f75WNb700kv47bffYGVlhY4dO+LEiRM4cOAAmjVrpjTuo48+wtatWzFs2DCMGzcO3bt3x/379/Hnn39ixYoV6Ny5M8aMGYO1a9di+vTpOHnyJPr27Yvc3FwcOHAAkyZNwpAhQ2BlZYVhw4ZhyZIlkEgkaN26Nf76669aXbvUvn17tG7dGjNmzMDt27dhaWmJbdu2KV1/Ve6HH35Anz590K1bN0ycOBFubm64fv06du3ahXPnzimNHTNmDF577TUAwIIFC2r1cyRq1DTVfkZEmlHeBp+Wlqa0PSgoSDRp0qTK+H79+gkPDw+lbSkpKWLs2LHC1tZWGBkZCS8vL6VW73L37t0To0ePFpaWlsLKykqMHj1anD17tkpruBBCJCQkiDFjxgipVCoMDQ2Fk5OTeOmll8TWrVsVY2raBv/gwQNFfebm5sLf31/ExcUJFxcXpZb+8hqnTJkinJychJGRkWjRooUICgoS6enpijF5eXnik08+EW5ubsLQ0FBIpVLx2muviYSEBMWYtLQ08eqrrwozMzPRtGlT8fbbb4uYmJhq2+Cr+zkLIcSlS5eEn5+fMDc3F7a2tmLChAni/Pnz1f68YmJixMsvvyysra2FiYmJcHd3F6GhoVWOWVhYKJo2bSqsrKxEfn7+Y39uRLpEIkQdXt1IREQaVVJSgubNmyMwMBArV67UdDlEWoPXABERNWJ//PEH0tLSlC6sJiKAZ4CIiBqhqKgo/Pvvv1iwYAFsbW2VbgBJRDwDRETUKC1fvhzvvvsu7O3tsXbtWk2XQ6R1eAaIiIiIdA7PABEREZHOYQAiIiIincMbIVZDLpfjzp07sLCweKrVnomIiKj+CCGQnZ2N5s2bV1lv72EMQNW4c+cOnJ2dNV0GERERqeDmzZto0aLFY8cwAFWj/Fb+N2/ehKWlpYarISIioprIysqCs7Oz4n38cRiAqlH+sZelpSUDEBERUQNTk8tXeBE0ERER6RwGICI
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.plot(history.history['accuracy'])\n",
"plt.plot(history.history['val_accuracy'])\n",
"plt.title('model accuracy')\n",
"plt.ylabel('accuracy')\n",
"plt.xlabel('epoch')\n",
"plt.legend(['train', 'test'], loc='upper left')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Dodatkowy trening modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1233/1233 [==============================] - 273s 222ms/step - loss: 0.2408 - accuracy: 0.9019 - val_loss: 0.3459 - val_accuracy: 0.8605\n",
"Epoch 2/5\n",
"1233/1233 [==============================] - 272s 221ms/step - loss: 0.2180 - accuracy: 0.9105 - val_loss: 0.3498 - val_accuracy: 0.8656\n"
]
}
],
"source": [
"callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=1, restore_best_weights=True)\n",
"history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=5, callbacks=[callback])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [],
"source": [
"model.save(\"lstm_model_v2.keras\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Testowanie i ewaluacja modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 31,
"outputs": [],
"source": [
"import tensorflow as tf\n",
"def test_review_text(sentence):\n",
" vectorized = vectorizer(sentence)\n",
" reshaped = tf.reshape(vectorized,shape=(1,200))\n",
" #print(vectorized.shape)\n",
" score = float(model(reshaped))\n",
" score_rounded = round(score)\n",
" print(score)\n",
" if score_rounded==0:\n",
" print(\"Negative review\")\n",
" else:\n",
" print(\"Positive review\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.02259424328804016\n",
"Negative review\n"
]
}
],
"source": [
"test_review_text(\"A buggy, uninspired mess\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 33,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.066298708319664\n",
"Negative review\n"
]
}
],
"source": [
"test_review_text(\"This game is bad\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 34,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9277510643005371\n",
"Positive review\n"
]
}
],
"source": [
"test_review_text(\"This game destroyed my life\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 35,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.990617036819458\n",
"Positive review\n"
]
}
],
"source": [
"test_review_text(\"Best game I've ever played\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 36,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9053470492362976\n",
"Positive review\n"
]
}
],
"source": [
"test_review_text(\"Fun cooperative play with scalable difficulty. Rapid path to get into a game with friends or open public games. \")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 37,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.3265230357646942\n",
"Negative review\n"
]
}
],
"source": [
"test_review_text(\"Deliriously buggy. Fun if/when it works properly. Wait and see if they actually QA the next few patches before you play.\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 38,
"outputs": [],
"source": [
"test[\"model_predictions\"] = model(np.stack(test[\"vectorized\"].values))"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [],
"source": [
"test[\"model_predictions\"] = test[\"model_predictions\"].apply(lambda x : round(float(x)))"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 40,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.86\n",
"Precision: 0.97\n",
"Recall: 0.86\n",
"F1 Score: 0.91\n"
]
}
],
"source": [
"def get_metrics():\n",
" df = test\n",
" predictions = df[\"model_predictions\"].to_numpy()\n",
" true_values = df[\"review_score\"].to_numpy()\n",
" accuracy = np.sum(np.rint(predictions) == true_values)/len(true_values)\n",
" TN_count = len(df.query(\"`review_score`==0 and `model_predictions`==0\").index)\n",
" TP_count = len(df.query(\"`review_score`==1 and `model_predictions`==1\").index)\n",
" FP_count = len(df.query(\"`review_score`==0 and `model_predictions`==1\").index)\n",
" FN_count = len(df.query(\"`review_score`==1 and `model_predictions`==0\").index)\n",
" precision = TP_count/(TP_count+FP_count)\n",
" recall = TP_count/(TP_count+FN_count)\n",
" F1_score = (2*precision*recall)/(precision+recall)\n",
" print(f\"Accuracy: {accuracy:.2f}\")\n",
" print(f\"Precision: {precision:.2f}\")\n",
" print(f\"Recall: {recall:.2f}\")\n",
" print(f\"F1 Score: {F1_score:.2f}\")\n",
"get_metrics()"
],
"metadata": {
"collapsed": false
}
},
2024-06-03 16:32:11 +02:00
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}