ium_452487/train.ipynb

608 lines
83 KiB
Plaintext
Raw Normal View History

2024-04-14 17:30:10 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import zipfile\n",
"with zipfile.ZipFile(\"dataset_cleaned.zip\", 'r') as zip_ref:\n",
" zip_ref.extractall(\"dataset_cleaned_extracted\")\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
2024-04-14 17:30:10 +02:00
"outputs": [],
"source": [
"import pandas as pd\n",
"train = pd.read_csv(\"dataset_cleaned_extracted/train.csv\")\n",
"test = pd.read_csv(\"dataset_cleaned_extracted/test.csv\")\n",
"valid = pd.read_csv(\"dataset_cleaned_extracted/valid.csv\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays',\n",
" 'PhysicalActivities', 'SleepHours', 'RemovedTeeth', 'HadHeartAttack',\n",
" 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
" 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms',\n",
" 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver',\n",
" 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos'],\n",
" dtype='object')\n"
]
}
],
"source": [
"num_columns = train.select_dtypes(['float64']).columns\n",
"print(num_columns)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"data": {
"text/plain": "36"
},
"execution_count": 3,
2024-04-14 17:30:10 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(num_columns)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 4,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays', 'PhysicalActivities', 'SleepHours', 'RemovedTeeth', 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD', 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis', 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty', 'DifficultyConcentrating', 'DifficultyWalking', 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus', 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms', 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver', 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos']\n"
]
}
],
"source": [
"x_columns = ['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays',\n",
" 'PhysicalActivities', 'SleepHours', 'RemovedTeeth',\n",
" 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
" 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms',\n",
" 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver',\n",
" 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos']\n",
"print(x_columns)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 5,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"data": {
"text/plain": "35"
},
"execution_count": 5,
2024-04-14 17:30:10 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(x_columns)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
2024-04-14 17:30:10 +02:00
"outputs": [],
"source": [
"y_column = 'HadHeartAttack'"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
2024-04-14 17:30:10 +02:00
"outputs": [],
"source": [
"train_x = train[x_columns]\n",
"train_y = train[y_column]\n",
"\n",
"test_x = test[x_columns]\n",
"test_y = test[y_column]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 676617 entries, 0 to 676616\n",
"Data columns (total 41 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Unnamed: 0 676617 non-null int64 \n",
" 1 State 676617 non-null object \n",
" 2 Male 676617 non-null float64\n",
" 3 GeneralHealth 676617 non-null float64\n",
" 4 PhysicalHealthDays 676617 non-null float64\n",
" 5 MentalHealthDays 676617 non-null float64\n",
" 6 LastCheckupTime 676617 non-null object \n",
" 7 PhysicalActivities 676617 non-null float64\n",
" 8 SleepHours 676617 non-null float64\n",
" 9 RemovedTeeth 676617 non-null float64\n",
" 10 HadHeartAttack 676617 non-null float64\n",
" 11 HadAngina 676617 non-null float64\n",
" 12 HadStroke 676617 non-null float64\n",
" 13 HadAsthma 676617 non-null float64\n",
" 14 HadSkinCancer 676617 non-null float64\n",
" 15 HadCOPD 676617 non-null float64\n",
" 16 HadDepressiveDisorder 676617 non-null float64\n",
" 17 HadKidneyDisease 676617 non-null float64\n",
" 18 HadArthritis 676617 non-null float64\n",
" 19 HadDiabetes 676617 non-null float64\n",
" 20 DeafOrHardOfHearing 676617 non-null float64\n",
" 21 BlindOrVisionDifficulty 676617 non-null float64\n",
" 22 DifficultyConcentrating 676617 non-null float64\n",
" 23 DifficultyWalking 676617 non-null float64\n",
" 24 DifficultyDressingBathing 676617 non-null float64\n",
" 25 DifficultyErrands 676617 non-null float64\n",
" 26 SmokerStatus 676617 non-null float64\n",
" 27 ECigaretteUsage 676617 non-null float64\n",
" 28 ChestScan 676617 non-null float64\n",
" 29 RaceEthnicityCategory 676617 non-null object \n",
" 30 AgeCategory 676617 non-null object \n",
" 31 HeightInMeters 676617 non-null float64\n",
" 32 WeightInKilograms 676617 non-null float64\n",
" 33 BMI 676617 non-null float64\n",
" 34 AlcoholDrinkers 676617 non-null float64\n",
" 35 HIVTesting 676617 non-null float64\n",
" 36 FluVaxLast12 676617 non-null float64\n",
" 37 PneumoVaxEver 676617 non-null float64\n",
" 38 TetanusLast10Tdap 676617 non-null float64\n",
" 39 HighRiskLastYear 676617 non-null float64\n",
" 40 CovidPos 676617 non-null float64\n",
"dtypes: float64(36), int64(1), object(4)\n",
"memory usage: 211.6+ MB\n"
]
}
],
"source": [
"train.info()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Definiowanie modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 9,
2024-04-14 17:30:10 +02:00
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from keras import layers\n",
"from keras.optimizers import Adam\n",
"def create_model():\n",
" inputs = keras.Input(shape=(35,))\n",
" dense1 = layers.Dense(64, activation=\"relu\")(inputs)\n",
" dropout1 = layers.Dropout(0.2)(dense1)\n",
" dense2 = layers.Dense(32, activation=\"relu\")(dropout1)\n",
" dropout2 = layers.Dropout(0.2)(dense2)\n",
" output = layers.Dense(1, activation=\"sigmoid\")(dropout2)\n",
" model = keras.Model(inputs=inputs, outputs=output)\n",
"\n",
" model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])\n",
" return model\n",
"\n",
"model = create_model()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 10,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
2024-04-14 17:30:10 +02:00
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 35)] 0 \n",
2024-04-14 17:30:10 +02:00
" \n",
" dense (Dense) (None, 64) 2304 \n",
2024-04-14 17:30:10 +02:00
" \n",
" dropout (Dropout) (None, 64) 0 \n",
2024-04-14 17:30:10 +02:00
" \n",
" dense_1 (Dense) (None, 32) 2080 \n",
2024-04-14 17:30:10 +02:00
" \n",
" dropout_1 (Dropout) (None, 32) 0 \n",
2024-04-14 17:30:10 +02:00
" \n",
" dense_2 (Dense) (None, 1) 33 \n",
2024-04-14 17:30:10 +02:00
" \n",
"=================================================================\n",
"Total params: 4,417\n",
"Trainable params: 4,417\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Trenowanie modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 11,
2024-04-14 17:30:10 +02:00
"outputs": [],
"source": [
"# Early stopping dla regularyzacji\n",
2024-04-14 17:30:10 +02:00
"callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3, restore_best_weights=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 12,
2024-04-14 17:30:10 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/11\n",
"21145/21145 [==============================] - 22s 994us/step - loss: 0.4334 - accuracy: 0.7998 - val_loss: 0.3714 - val_accuracy: 0.8448\n",
"Epoch 2/11\n",
"21145/21145 [==============================] - 21s 972us/step - loss: 0.4257 - accuracy: 0.8038 - val_loss: 0.4273 - val_accuracy: 0.8249\n",
"Epoch 3/11\n",
"21145/21145 [==============================] - 21s 992us/step - loss: 0.4224 - accuracy: 0.8056 - val_loss: 0.4245 - val_accuracy: 0.8219\n",
"Epoch 4/11\n",
"21145/21145 [==============================] - 20s 962us/step - loss: 0.4201 - accuracy: 0.8074 - val_loss: 0.4108 - val_accuracy: 0.8234\n"
2024-04-14 17:30:10 +02:00
]
}
],
"source": [
"history = model.fit(train_x, train_y, validation_data=(test_x, test_y), epochs=11, callbacks=[callback])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Historia treningu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"##### Loss"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [
{
"data": {
"text/plain": "<matplotlib.legend.Legend at 0x226e7b95760>"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAHFCAYAAAD2eiPWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAABXEElEQVR4nO3dd3wUdf7H8deW9E0PQVSqCIRQjEFRiZ4NjiIKqHggxYKgJzYQFFRA8cADPO/nYUGKAnKccDQ9KxyWO2wYpQocTQRjIEASUjfZ3fn9EbJkSZBNSLLJ5v18PPLQnZ2d/cwnA779zndmTIZhGIiIiIj4MbOvCxARERGpaQo8IiIi4vcUeERERMTvKfCIiIiI31PgEREREb+nwCMiIiJ+T4FHRERE/J4Cj4iIiPg9BR4RERHxewo8IlKrDh06RNu2bVm5cmWNfkZEpCwFHhEREfF7CjwiIiLi9xR4RBq466+/ntmzZzNt2jS6du1KUlISY8eOJS8vjzfeeINrrrmG5ORkHnroITIzM92fczqdLFmyhL59+9KpUyeuvfZaZs2ahd1u99j+J598ws0330ynTp3o378/O3fuLFdDVlYWkyZN4qqrrqJjx44MHDiQr776qlL74XQ6eeONN7jpppvo1KkTl1xyCX/4wx/4+uuvPdbbtGkT99xzD5deeilXXHEFY8aM4fDhw+73jxw5whNPPMGVV15JUlISQ4YM4YcffgDOfGrtySef5Prrr3e/Hjp0KI8//jgPP/wwl1xyCXfffbf78+PHjyclJYXExESuvPJKxo8f79FXwzB466236NWrF506daJ79+7Mnz8fwzD47LPPaNu2Lf/97389vv+7776jbdu2pKamVqpnIg2J1dcFiIjvLViwgG7duvHSSy+xbds2XnzxRbZv3058fDxTp07l0KFD/OlPfyIuLo7JkycDMGnSJNasWcN9991Hly5d+PHHH3nllVfYsWMH8+bNw2QysX79eh5++GH69u3LuHHj2LFjB+PGjfP4brvdzvDhwzl69CiPPfYY8fHxrFixghEjRjBv3jyuvPJKr/Zh1qxZLF26lLFjx9K2bVsOHz7MK6+8wiOPPMJnn31GSEgIP/74I0OGDKFz587MmDEDp9PJiy++yL333svq1aux2+0MGjQIp9PJuHHjaNy4MQsWLOCee+5h1apVWK3e/5X54YcfcvPNN/Paa6/hcrkoKChg2LBhREdHM3nyZMLDw/nhhx+YPXs2wcHBPPfccwDMmDGDhQsXcvfdd9OtWze2bt3KrFmzcDgcjBgxgvj4eNasWUNKSor7u1avXk2LFi1ITk72uj6RhkaBR0Sw2Wy89NJLWK1WrrrqKlatWsXhw4dZvnw54eHhAPznP//h+++/B2DPnj3885//ZOzYsYwcORKAbt26ER8fz/jx4/niiy/43e9+xyuvvEKnTp2YOXMmAFdffTUAL774ovu716xZw86dO1m2bBmdO3cG4JprrmHo0KHMmjWLFStWeLUPR44c4bHHHmPo0KHuZUFBQTz00EPs2rWLSy65hNdff52oqCgWLFhAUFAQAPHx8YwdO5bdu3eTmprKL7/8wqpVq0hISADg0ksvpV+/fmzcuNHr8AUQEBDAs88+S2BgIAA7duzgvPPO489//jNNmzYF4IorrmDz5s18++23AJw4cYJFixYxZMgQdzC86qqryMjIYOPGjYwaNYr+/fuzePFi8vLyCAsLo7CwkA8//ND9exCRiumUlojQqVMnj9GLuLg4WrZs6Q47AFFRUeTk5AC4/wPdp08fj+306dMHi8XCN998Q2FhIdu3b+e6667zWKdXr14er7/66isaNWpEYmIiDocDh8OB0+nkuuuuY9u2bWRnZ3u1Dy+++CLDhw/n+PHjfPfdd6xYsYJ3330XgKKiIgBSU1O55ppr3GEHICkpifXr15OQkEBqaioXXnihO+wAhISE8PHHH3P77bd7VUepVq1aucMOQEJCAn//+9+54IIL+Omnn/j888+ZP38++/btc9e3adMmHA4HPXr08NjW008/zbx58wC49dZbyc/PZ+3atQCsXbuW/Px8+vXrV6n6RBoajfCICDabrdyy0NDQM65fGkIaNWrksdxqtRIdHU1OTg7Z2dkYhkF0dLTHOvHx8R6vs7KyyMjIIDExscLvysjIIDg4+Kz7sHXrVp599lm2bt1KSEgIrVu35vzzzwdK5sWUfldsbOwZt3G29ysjLCys3LI333yT119/naysLOLi4ujQoQMhISHuIJmVlQVATEzMGbfbvHlzLr/8clavXk2/fv1YvXo1V111FY0bN66WukX8lQKPiFRaZGQkUBJGLrjgAvfy4uJiMjMziY6OJioqCrPZzNGjRz0+W/of9VLh4eG0aNGCWbNmVfhdF154YbltnC43N5cRI0bQtm1b3n//fVq1aoXZbObzzz/n448/9viu48ePl/v8559/TkJCAuHh4Rw6dKjc+99//z2RkZHu4OV0Oj3ez8/P/836AN577z1eeOEFxo0bx4ABA9yh5pFHHmHr1q0AREREAHD8+HFatWrl/mxaWho///wzycnJBAQEcOuttzJx4kT27t3LV199dcbeicgpOqUlIpV2+eWXA/D+++97LH///fdxOp0kJycTFBREUlISn3zyiXuEBWD9+vXltvXrr78SGxtLx44d3T8bNmxg3rx5WCyWs9azb98+srKyGDZsGK1bt8ZsLvmr7YsvvgDA5XIB0KVLFzZs2OA+hQTw448/MnLkSLZv306XLl04ePAgu3fvdr9vt9t56KGH+Oc//+keCSt7VVdxcTFbtmw5a42pqalEREQwYsQId9jJy8sjNTXVXV+nTp0ICAjg008/9fjsggULGDNmjLsXv//97wkJCWHKlCmEhYVx4403nvX7RRo6jfCISKW1bt2a/v378/LLL1NQUMBll13Gjh07mD17Nl27dnVPTh4zZgzDhw9n9OjR3HHHHezfv5/XX3/dY1sDBgzg7bff5u677+b++++nSZMmfPnll8ydO5chQ4YQEBBw1npatmyJzWbj9ddfx2q1YrVa+fjjj/nnP/8JQEFBAQB//OMfueOOOxg1ahTDhg2jsLCQv/71r3Tq1Ilu3bpRVFTE4sWLeeCBB3j44YeJjo5m0aJFFBcXM3jwYCIjI0lKSmLx4sU0b96cyMhIFi1aRGFh4W+eAoSSMLN06VJeeOEFrrvuOo4cOcL8+fM5evSoe8QsJiaGYcOG8dZbbxEYGMjll1/O5s2bWbp0KePHj3cHuZCQEPr06cM777zDoEGDPOYKiUjFNMIjIlXypz/9iQcffJD33nuPkSNHsmTJEoYNG8bcuXPd/2Hu0qULc+fO5fDhw4wePZp33nmHadOmeWwnNDSUJUuWkJyczMyZM7nvvvv45JNPGDt2LBMmTPCqlvDwcF599VUMw+CRRx5h/PjxpKWl8fbbbxMWFsZ3330HQPv27Vm8eDEOh4NHH32U559/nuTkZObMmUNgYCA2m423336bzp07M3XqVB599FFcLheLFi1yX1n1wgsv0KFDB55++mkmTJhAYmIiw4cPP2uN/fv358EHH+TDDz/kvvvu4+WXX6ZLly4899xzZGVlsXfvXgDGjRvHmDFj+Ne//sXIkSNZs2YNzzzzTLnvuPbaa4GSwCgiZ2cyyo41i4hIvTB58mQ2b97M6tWrfV2KSL2gU1oiIvXIokWL2LdvH8uWLXPf30hEzk6BR0SkHvnuu+/4z3/+w/Dhw7npppt8XY5IvaFTWiIiIuL3NGlZRERE/J4Cj4iIiPg9BR4RERHxewo8IiIi4vcUeERERMTv6bL0Mo4dy6G6r1kzmSA2NrxGtu1v1CvvqVfeU6+8p15VjvrlvZrqVel2vaHAU4ZhUGMHbU1u29+oV95Tr7ynXnlPvaoc9ct7vuyVTmmJiIiI31PgEREREb+nwCMiIiJ+T3N4KsHlcuF0Oir1GZMJCgsLKS4uajDneK3WAEwmk6/LEBERcVPg8YJhGJw4cZyCgtwqff74cTMul6uaq6q7TCYzsbHnYbUG+LoUERERQIHHK6Vhx2a
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.plot(history.history['accuracy'])\n",
"plt.plot(history.history['val_accuracy'])\n",
"plt.title('model accuracy')\n",
"plt.ylabel('accuracy')\n",
"plt.xlabel('epoch')\n",
"plt.legend(['train', 'test'], loc='upper left')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Test loss < Train loss ze względu na warstwy dropout"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"##### Accuracy"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAHFCAYAAAD2eiPWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAABo7UlEQVR4nO3dd3wUdf7H8deWbEsvEKp0kIuUUFWwoWBDBc9TLIB64qlg/1nAE0E99cSzonKinKicKFbsHrY7FRQQiCAoRZBuQhJIstlNdnd+f4QshIBsYjaTbN7PxyMPspPZ2c98GODNd74zYzEMw0BEREQkhlnNLkBEREQk2hR4REREJOYp8IiIiEjMU+ARERGRmKfAIyIiIjFPgUdERERingKPiIiIxDwFHhEREYl5CjwiIiIS8xR4RKTR2rJlC926deONN96o0/cMGTKE22+/vS5KFJEGQoFHREREYp4Cj4iIiMQ8BR4RqTNDhgxh+vTp3HfffQwcOJDs7GxuvvlmSkpKeOaZZzj++OPp27cv1157LQUFBeH3BYNB5syZw1lnnUXPnj058cQTeeihh/D7/VW2//HHH3P22WfTs2dPRo4cyZo1a6rVUFhYyOTJkzn22GPp0aMH559/PgsXLvxd+1VUVMT999/PKaecQo8ePRg+fDivvfZalXVWrlzJ2LFj6du3L9nZ2Vx66aUsX748/PP8/HxuvvlmBg0aRI8ePTjnnHN46623flddIhI5u9kFiEhsmTVrFoMGDeKRRx5h5cqV/OMf/2DVqlU0b96ce+65hy1btvC3v/2NjIwM7rrrLgAmT57M22+/zbhx4+jXrx8//PADTz75JKtXr+bZZ5/FYrHw6aefct1113HWWWdxyy23sHr1am655ZYqn+33+xk7dix5eXnceOONNG/enNdff50rrriCZ599lmOOOabG++Pz+bjooovYtWsX1113Ha1bt2bBggXccccd5OXlcdVVV1FcXMwVV1zB0UcfzRNPPEFZWRlPP/00f/7zn/n8889JTEzklltuYdeuXUydOpWEhATefvttbrvtNlq0aMHRRx9dJ70XkUNT4BGROpWQkMAjjzyC3W7n2GOP5c0332Tnzp3MmzePxMREAP73v//x3XffAbBu3Tpee+01br75Zq688koABg0aRPPmzbn11lv573//ywknnMCTTz5Jz549mTZtGgDHHXccAP/4xz/Cn/3222+zZs0aXn31VXr16gXA8ccfz+jRo3nooYd4/fXXa7w/b7zxBj/99BNz584lOzs7/NmBQICnnnqKUaNGsXHjRgoKChgzZgx9+vQBoGPHjrzyyiuUlJSQmJjIt99+y/jx4znllFMAGDBgACkpKTgcjhrXJCI1p1NaIlKnevbsid2+7/9SGRkZdOjQIRx2AFJSUigqKgLg22+/BeDMM8+ssp0zzzwTm83GN998g8/nY9WqVZx00klV1jn99NOrvF64cCHNmjUjKyuLQCBAIBAgGAxy0kknsXLlSnbv3l3j/fn2229p3bp1OOxUOvvss/H7/axYsYIuXbqQlpbGVVddxeTJk/nPf/5DRkYGt9xyCy1atABg4MCBPPHEE1x33XXMmzePvLw8brvttnBAEpHo0giPiNSphISEass8Hs8h168MIc2aNauy3G63k5qaSlFREbt378YwDFJTU6us07x58yqvCwsLyc3NJSsr66CflZubi8vlimg/9q/vwNqgIsgB7Nmzh/j4eObMmcPTTz/NBx98wCuvvILL5eKcc87hr3/9Kw6Hg0ceeYQZM2bwwQcf8NFHH2G1Wjn22GO5++67ad26dY1qEpGaU+AREVMlJycDFWFk/3/4y8vLKSgoIDU1lZSUFKxWK3l5eVXeW1hYWOV1YmIi7du356GHHjroZ7Vp06baNiKpb9OmTdWW5+bmAoRDWMeOHZk2bRrBYJCcnBzefvttXn75ZY444giuuOKK8DyeW265hQ0bNvDJJ5/w1FNPMXXqVJ555pka1SQiNadTWiJiqgEDBgDw3nvvVVn+3nvvEQwG6du3L06nk+zsbD7++GMMwwiv8+mnn1bb1vbt20lPT6dHjx7hr6+++opnn30Wm81W4/r69+/P1q1bWbZsWZXl8+fPJy4ujp49e/Lhhx9y9NFHk5ubi81mIzs7mylTppCUlMS2bdvYunUrJ5xwAh9++CFQEY7GjRvHsccey7Zt22pck4jUnEZ4RMRUnTt3ZuTIkTz++OOUlpbSv39/Vq9ezfTp0xk4cGB4cvJNN93E2LFjmTBhAhdccAE///wzM2bMqLKtc889l5deeonLLruMq666ipYtW/L1118zc+ZMLrnkEuLi4mpc37nnnsu///1vxo8fz3XXXUebNm349NNPef3115kwYQJJSUn06dOHUCjE+PHjufLKK4mPj+eDDz6gqKiIYcOG0bp1a1q0aMG9995LcXExRxxxBCtXruSLL77gL3/5S530UUR+mwKPiJjub3/7G+3ateP1119n5syZNG/enDFjxnDNNddgtVYMRPfr14+ZM2fy8MMPM2HCBNq0acN9993HVVddFd6Ox+Nhzpw5/OMf/2DatGkUFRXRunVrbr75Zi6//PJa1eZ2u3nxxRf5xz/+wWOPPUZxcTEdO3bkb3/7G+eddx5QMZfo2Wef5bHHHuOOO+6gtLSULl268MQTT4QvOZ8+fToPP/wwjz32GAUFBbRs2ZIJEyaEr0wTkeiyGPuPD4uIiIjEIM3hERERkZinwCMiIiIxT4FHREREYp4Cj4iIiMQ8BR4RERGJeQo8IiIiEvMUeERERCTmKfCIiIhIzNOdlveza1cRdX0bRosF0tMTo7LtWKNeRU69ipx6VTPqV+TUq8hFq1eV242EAs9+DIOoHbTR3HasUa8ip15FTr2qGfUrcupV5MzslU5piYiISMxT4BEREZGYp8AjIiIiMU9zeGogFAoRDAZq9B6LBXw+H+XlZU3mHK/dHofFYjG7DBERkTAFnggYhsGePfmUlhbX6v35+VZCoVAdV9VwWSxW0tNbYLfHmV2KiIgIoMATkcqwk5CQisPhrPHohc1mIRhsGsM7hhGisHAXu3fnk5bWXCM9IiLSICjwHEYoFAyHnYSEpFptw263Egg0nRGexMQUdu/OIxQKYrPpEBMREfNp0vJhBINBABwOp8mVNB6VIacpncYTEZGGTYEnQjo1Ezn1SkREGhoFHhEREYl5CjwxbO3aH/n++xW1eu95553F+++/U8cViYiImEOBJ4ZNmnQLmzf/Uqv3zpz5AiefPLSOKxIRETGHLqGJMsMwCJl0RbrxO+50mJqaWoeViIiImEuBJ4oMw2B9XgmBkIE7zkaC0068w4bTbo36xN4JE65kx47t3HffVGbNegaAo48+lv/850NGj76MCy64mBkznuCTT/5DQUE+zZo1Z/ToyzjnnHOBilNal19+JWeccRYTJlxJ//4DWbFiGcuXL6N580xuvPEWBg48Jqr7ICIiUld0SquWDMOgtDx42K9AqGK9fG8ZvxR4Wb2ziJXb97Ahr4Qde3wU+csj2k5pebBGIzb33TeN5s0zue66m7n++pvZsWM7ZWVlPPfcS5xyymm8+OK/+PrrL7n33gf5979f5/TTh/PIIw+Sn7/roNt74YVZnHLKqbz44it06dKVv//9Xl12LiIijYZGeGrBMAyumLuCnG176vVze7VKYuaoXhGNDiUlJWO1WklISCA+PgGAiy8eS5s2bQHo3LkrffsO4KijegAwevRl/OtfM9m8+RfS0tKrbe+YYwZzxhlnATB27J+59NILyc/fRUZGs7raPRERkahR4KmlxninmRYtWoa/P/74E1m8eBFPPPEIv/yykZ9+WgPsu9Higdq2PSL8fXx8PACBQM0epCoiImIWBZ5asFgszBzVC1+Ej4uw26wEggdftzwYosQfpKQsQElZkNABp60cNivxDhvxTjtp7t/3FHKnc9/dop955ineeectzjjjLE477Uxuvvl2zjvvrEPvg736ofJ7JkWLiIjUJwWeWrJYLLjjbBGtW/EsrYMHFXecjSRXxVPFDcPAVx6iuCx
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(history.history['loss'])\n",
"plt.plot(history.history['val_loss'])\n",
"plt.title('model loss')\n",
"plt.ylabel('loss')\n",
"plt.xlabel('epoch')\n",
"plt.legend(['train', 'test'], loc='upper left')\n",
"plt.show()"
2024-04-14 17:30:10 +02:00
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Zapisywanie modelu"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"model.save(\"model_v1.keras\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Testowanie na zbiorze walidacyjnym"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"valid_x = valid[x_columns]\n",
"valid_y = valid[y_column]"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 36,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1392/1392 [==============================] - 1s 569us/step\n",
"Poprawność na zbiorze walidacyjnym: 86.15%\n"
]
}
],
"source": [
"import numpy as np\n",
"predictions = model.predict(valid_x)[:,0]\n",
"true_answers = valid_y.to_numpy()\n",
"validation_accuracy = np.sum(np.rint(predictions) == true_answers)/len(true_answers)\n",
"print(f\"Poprawność na zbiorze walidacyjnym: {validation_accuracy:.2%}\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 37,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.08692811 0.12067404 0.31880796 0.64843357 0.15188715 0.06517262\n",
" 0.03407578 0.49311596 0.00781232 0.2089161 0.46056542 0.45341685\n",
" 0.4294767 0.25619727 0.20345858 0.2302334 0.38631877 0.36519188\n",
" 0.04014764 0.23888215 0.27519897 0.08928084 0.05204074 0.42043713\n",
" 0.19055638 0.29787344 0.23068897 0.88435644 0.03139259 0.95048493\n",
" 0.2457671 0.5858893 0.02678488 0.06240147 0.52132165 0.01431455\n",
" 0.02444405 0.07804424 0.11274771 0.12714393 0.35450152 0.01294624\n",
" 0.190797 0.07512036 0.48486376 0.06140704 0.9019506 0.08810509\n",
" 0.61831665 0.15642735 0.03310075 0.04532438 0.10763614 0.4277772\n",
" 0.20325996 0.8980398 0.7491019 0.38502344 0.03970775 0.0401529\n",
" 0.03046079 0.10123587 0.04993626 0.05702 0.18049946 0.1223311\n",
" 0.731555 0.40104443 0.18443953 0.1265702 0.07467585 0.03895461\n",
" 0.35271063 0.38039213 0.4450048 0.03670818 0.05534125 0.91664517\n",
" 0.413391 0.12545326 0.11306539 0.4350903 0.48778924 0.40804324\n",
" 0.33885244 0.21948677 0.01242744 0.02531701 0.6693964 0.15393472\n",
" 0.9307252 0.09181138 0.05571133 0.1261858 0.02687709 0.27069062\n",
" 0.22613294 0.20686075 0.47390068 0.40349996]\n"
]
}
],
"source": [
"print(predictions[:100])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 38,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.\n",
" 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0.]\n"
]
}
],
"source": [
"print(np.rint(predictions)[:100])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0.]\n"
]
}
],
"source": [
"print(true_answers[:100])"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}