Upload files to "/"
This commit is contained in:
parent
16282da61d
commit
cf1fcab706
519
train.ipynb
Normal file
519
train.ipynb
Normal file
@ -0,0 +1,519 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import zipfile\n",
|
||||
"with zipfile.ZipFile(\"dataset_cleaned.zip\", 'r') as zip_ref:\n",
|
||||
" zip_ref.extractall(\"dataset_cleaned_extracted\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"# W pobranym zbiorze danych jest kilka podzbiorów więc celowo otwieram ten z NaN, żeby manualnie go oczyścić dla praktyki\n",
|
||||
"train = pd.read_csv(\"dataset_cleaned_extracted/train.csv\")\n",
|
||||
"test = pd.read_csv(\"dataset_cleaned_extracted/test.csv\")\n",
|
||||
"valid = pd.read_csv(\"dataset_cleaned_extracted/valid.csv\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Index(['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays',\n",
|
||||
" 'PhysicalActivities', 'SleepHours', 'RemovedTeeth', 'HadHeartAttack',\n",
|
||||
" 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
|
||||
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
|
||||
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
|
||||
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
|
||||
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
|
||||
" 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms',\n",
|
||||
" 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver',\n",
|
||||
" 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos'],\n",
|
||||
" dtype='object')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"num_columns = train.select_dtypes(['float64']).columns\n",
|
||||
"print(num_columns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "36"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(num_columns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays', 'PhysicalActivities', 'SleepHours', 'RemovedTeeth', 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD', 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis', 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty', 'DifficultyConcentrating', 'DifficultyWalking', 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus', 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms', 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver', 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x_columns = ['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays',\n",
|
||||
" 'PhysicalActivities', 'SleepHours', 'RemovedTeeth',\n",
|
||||
" 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
|
||||
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
|
||||
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
|
||||
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
|
||||
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
|
||||
" 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms',\n",
|
||||
" 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver',\n",
|
||||
" 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos']\n",
|
||||
"print(x_columns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "35"
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(x_columns)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_column = 'HadHeartAttack'"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_x = train[x_columns]\n",
|
||||
"train_y = train[y_column]\n",
|
||||
"\n",
|
||||
"test_x = test[x_columns]\n",
|
||||
"test_y = test[y_column]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"RangeIndex: 676617 entries, 0 to 676616\n",
|
||||
"Data columns (total 41 columns):\n",
|
||||
" # Column Non-Null Count Dtype \n",
|
||||
"--- ------ -------------- ----- \n",
|
||||
" 0 Unnamed: 0 676617 non-null int64 \n",
|
||||
" 1 State 676617 non-null object \n",
|
||||
" 2 Male 676617 non-null float64\n",
|
||||
" 3 GeneralHealth 676617 non-null float64\n",
|
||||
" 4 PhysicalHealthDays 676617 non-null float64\n",
|
||||
" 5 MentalHealthDays 676617 non-null float64\n",
|
||||
" 6 LastCheckupTime 676617 non-null object \n",
|
||||
" 7 PhysicalActivities 676617 non-null float64\n",
|
||||
" 8 SleepHours 676617 non-null float64\n",
|
||||
" 9 RemovedTeeth 676617 non-null float64\n",
|
||||
" 10 HadHeartAttack 676617 non-null float64\n",
|
||||
" 11 HadAngina 676617 non-null float64\n",
|
||||
" 12 HadStroke 676617 non-null float64\n",
|
||||
" 13 HadAsthma 676617 non-null float64\n",
|
||||
" 14 HadSkinCancer 676617 non-null float64\n",
|
||||
" 15 HadCOPD 676617 non-null float64\n",
|
||||
" 16 HadDepressiveDisorder 676617 non-null float64\n",
|
||||
" 17 HadKidneyDisease 676617 non-null float64\n",
|
||||
" 18 HadArthritis 676617 non-null float64\n",
|
||||
" 19 HadDiabetes 676617 non-null float64\n",
|
||||
" 20 DeafOrHardOfHearing 676617 non-null float64\n",
|
||||
" 21 BlindOrVisionDifficulty 676617 non-null float64\n",
|
||||
" 22 DifficultyConcentrating 676617 non-null float64\n",
|
||||
" 23 DifficultyWalking 676617 non-null float64\n",
|
||||
" 24 DifficultyDressingBathing 676617 non-null float64\n",
|
||||
" 25 DifficultyErrands 676617 non-null float64\n",
|
||||
" 26 SmokerStatus 676617 non-null float64\n",
|
||||
" 27 ECigaretteUsage 676617 non-null float64\n",
|
||||
" 28 ChestScan 676617 non-null float64\n",
|
||||
" 29 RaceEthnicityCategory 676617 non-null object \n",
|
||||
" 30 AgeCategory 676617 non-null object \n",
|
||||
" 31 HeightInMeters 676617 non-null float64\n",
|
||||
" 32 WeightInKilograms 676617 non-null float64\n",
|
||||
" 33 BMI 676617 non-null float64\n",
|
||||
" 34 AlcoholDrinkers 676617 non-null float64\n",
|
||||
" 35 HIVTesting 676617 non-null float64\n",
|
||||
" 36 FluVaxLast12 676617 non-null float64\n",
|
||||
" 37 PneumoVaxEver 676617 non-null float64\n",
|
||||
" 38 TetanusLast10Tdap 676617 non-null float64\n",
|
||||
" 39 HighRiskLastYear 676617 non-null float64\n",
|
||||
" 40 CovidPos 676617 non-null float64\n",
|
||||
"dtypes: float64(36), int64(1), object(4)\n",
|
||||
"memory usage: 211.6+ MB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train.info()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Definiowanie modelu"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from keras import layers\n",
|
||||
"from keras.optimizers import Adam\n",
|
||||
"def create_model():\n",
|
||||
" inputs = keras.Input(shape=(35,))\n",
|
||||
" dense1 = layers.Dense(64, activation=\"relu\")(inputs)\n",
|
||||
" dropout1 = layers.Dropout(0.2)(dense1)\n",
|
||||
" dense2 = layers.Dense(32, activation=\"relu\")(dropout1)\n",
|
||||
" dropout2 = layers.Dropout(0.2)(dense2)\n",
|
||||
" output = layers.Dense(1, activation=\"sigmoid\")(dropout2)\n",
|
||||
" model = keras.Model(inputs=inputs, outputs=output)\n",
|
||||
"\n",
|
||||
" model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])\n",
|
||||
" return model\n",
|
||||
"\n",
|
||||
"model = create_model()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model_1\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
" Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
" input_2 (InputLayer) [(None, 35)] 0 \n",
|
||||
" \n",
|
||||
" dense_3 (Dense) (None, 64) 2304 \n",
|
||||
" \n",
|
||||
" dropout_2 (Dropout) (None, 64) 0 \n",
|
||||
" \n",
|
||||
" dense_4 (Dense) (None, 32) 2080 \n",
|
||||
" \n",
|
||||
" dropout_3 (Dropout) (None, 32) 0 \n",
|
||||
" \n",
|
||||
" dense_5 (Dense) (None, 1) 33 \n",
|
||||
" \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 4,417\n",
|
||||
"Trainable params: 4,417\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Trenowanie modelu"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=3, restore_best_weights=True)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/1000\n",
|
||||
"21145/21145 [==============================] - 21s 963us/step - loss: 0.4343 - accuracy: 0.7991 - val_loss: 0.3911 - val_accuracy: 0.8412\n",
|
||||
"Epoch 2/1000\n",
|
||||
"21145/21145 [==============================] - 20s 957us/step - loss: 0.4262 - accuracy: 0.8043 - val_loss: 0.3980 - val_accuracy: 0.8347\n",
|
||||
"Epoch 3/1000\n",
|
||||
"21145/21145 [==============================] - 20s 959us/step - loss: 0.4227 - accuracy: 0.8057 - val_loss: 0.3904 - val_accuracy: 0.8396\n",
|
||||
"Epoch 4/1000\n",
|
||||
"21145/21145 [==============================] - 20s 950us/step - loss: 0.4202 - accuracy: 0.8073 - val_loss: 0.4032 - val_accuracy: 0.8285\n",
|
||||
"Epoch 5/1000\n",
|
||||
"21145/21145 [==============================] - 20s 962us/step - loss: 0.4184 - accuracy: 0.8083 - val_loss: 0.3639 - val_accuracy: 0.8613\n",
|
||||
"Epoch 6/1000\n",
|
||||
"21145/21145 [==============================] - 20s 965us/step - loss: 0.4172 - accuracy: 0.8086 - val_loss: 0.3897 - val_accuracy: 0.8328\n",
|
||||
"Epoch 7/1000\n",
|
||||
"21145/21145 [==============================] - 20s 954us/step - loss: 0.4155 - accuracy: 0.8094 - val_loss: 0.4143 - val_accuracy: 0.8272\n",
|
||||
"Epoch 8/1000\n",
|
||||
"21145/21145 [==============================] - 21s 970us/step - loss: 0.4145 - accuracy: 0.8102 - val_loss: 0.4026 - val_accuracy: 0.8323\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"history = model.fit(train_x, train_y, validation_data=(test_x, test_y), epochs=1000, callbacks=[callback])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Zapisywanie modelu"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"model_v1.keras\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Testowanie na zbiorze walidacyjnym"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"valid_x = valid[x_columns]\n",
|
||||
"valid_y = valid[y_column]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1392/1392 [==============================] - 1s 569us/step\n",
|
||||
"Poprawność na zbiorze walidacyjnym: 86.15%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"predictions = model.predict(valid_x)[:,0]\n",
|
||||
"true_answers = valid_y.to_numpy()\n",
|
||||
"validation_accuracy = np.sum(np.rint(predictions) == true_answers)/len(true_answers)\n",
|
||||
"print(f\"Poprawność na zbiorze walidacyjnym: {validation_accuracy:.2%}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.08692811 0.12067404 0.31880796 0.64843357 0.15188715 0.06517262\n",
|
||||
" 0.03407578 0.49311596 0.00781232 0.2089161 0.46056542 0.45341685\n",
|
||||
" 0.4294767 0.25619727 0.20345858 0.2302334 0.38631877 0.36519188\n",
|
||||
" 0.04014764 0.23888215 0.27519897 0.08928084 0.05204074 0.42043713\n",
|
||||
" 0.19055638 0.29787344 0.23068897 0.88435644 0.03139259 0.95048493\n",
|
||||
" 0.2457671 0.5858893 0.02678488 0.06240147 0.52132165 0.01431455\n",
|
||||
" 0.02444405 0.07804424 0.11274771 0.12714393 0.35450152 0.01294624\n",
|
||||
" 0.190797 0.07512036 0.48486376 0.06140704 0.9019506 0.08810509\n",
|
||||
" 0.61831665 0.15642735 0.03310075 0.04532438 0.10763614 0.4277772\n",
|
||||
" 0.20325996 0.8980398 0.7491019 0.38502344 0.03970775 0.0401529\n",
|
||||
" 0.03046079 0.10123587 0.04993626 0.05702 0.18049946 0.1223311\n",
|
||||
" 0.731555 0.40104443 0.18443953 0.1265702 0.07467585 0.03895461\n",
|
||||
" 0.35271063 0.38039213 0.4450048 0.03670818 0.05534125 0.91664517\n",
|
||||
" 0.413391 0.12545326 0.11306539 0.4350903 0.48778924 0.40804324\n",
|
||||
" 0.33885244 0.21948677 0.01242744 0.02531701 0.6693964 0.15393472\n",
|
||||
" 0.9307252 0.09181138 0.05571133 0.1261858 0.02687709 0.27069062\n",
|
||||
" 0.22613294 0.20686075 0.47390068 0.40349996]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(predictions[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.\n",
|
||||
" 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0.]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(np.rint(predictions)[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0.]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(true_answers[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
204
validate.ipynb
Normal file
204
validate.ipynb
Normal file
@ -0,0 +1,204 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import zipfile\n",
|
||||
"with zipfile.ZipFile(\"dataset_cleaned.zip\", 'r') as zip_ref:\n",
|
||||
" zip_ref.extractall(\"dataset_cleaned_extracted\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"valid = pd.read_csv(\"dataset_cleaned_extracted/valid.csv\")\n",
|
||||
"\n",
|
||||
"x_columns = ['Male', 'GeneralHealth', 'PhysicalHealthDays', 'MentalHealthDays',\n",
|
||||
" 'PhysicalActivities', 'SleepHours', 'RemovedTeeth',\n",
|
||||
" 'HadAngina', 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
|
||||
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
|
||||
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
|
||||
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
|
||||
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
|
||||
" 'ECigaretteUsage', 'ChestScan', 'HeightInMeters', 'WeightInKilograms',\n",
|
||||
" 'BMI', 'AlcoholDrinkers', 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver',\n",
|
||||
" 'TetanusLast10Tdap', 'HighRiskLastYear', 'CovidPos']\n",
|
||||
"y_column = 'HadHeartAttack'\n",
|
||||
"\n",
|
||||
"valid_x = valid[x_columns]\n",
|
||||
"valid_y = valid[y_column]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow import keras\n",
|
||||
"model = keras.models.load_model('model_v1.keras')"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1392/1392 [==============================] - 1s 566us/step\n",
|
||||
"Poprawność na zbiorze walidacyjnym: 86.15%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"predictions = model.predict(valid_x)[:,0]\n",
|
||||
"true_answers = valid_y.to_numpy()\n",
|
||||
"validation_accuracy = np.sum(np.rint(predictions) == true_answers)/len(true_answers)\n",
|
||||
"print(f\"Poprawność na zbiorze walidacyjnym: {validation_accuracy:.2%}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.08692811 0.12067404 0.31880796 0.64843357 0.15188715 0.06517262\n",
|
||||
" 0.03407578 0.49311596 0.00781232 0.2089161 0.46056542 0.45341685\n",
|
||||
" 0.4294767 0.25619727 0.20345858 0.2302334 0.38631877 0.36519188\n",
|
||||
" 0.04014764 0.23888215 0.27519897 0.08928084 0.05204074 0.42043713\n",
|
||||
" 0.19055638 0.29787344 0.23068897 0.88435644 0.03139259 0.95048493\n",
|
||||
" 0.2457671 0.5858893 0.02678488 0.06240147 0.52132165 0.01431455\n",
|
||||
" 0.02444405 0.07804424 0.11274771 0.12714393 0.35450152 0.01294624\n",
|
||||
" 0.190797 0.07512036 0.48486376 0.06140704 0.9019506 0.08810509\n",
|
||||
" 0.61831665 0.15642735 0.03310075 0.04532438 0.10763614 0.4277772\n",
|
||||
" 0.20325996 0.8980398 0.7491019 0.38502344 0.03970775 0.0401529\n",
|
||||
" 0.03046079 0.10123587 0.04993626 0.05702 0.18049946 0.1223311\n",
|
||||
" 0.731555 0.40104443 0.18443953 0.1265702 0.07467585 0.03895461\n",
|
||||
" 0.35271063 0.38039213 0.4450048 0.03670818 0.05534125 0.91664517\n",
|
||||
" 0.413391 0.12545326 0.11306539 0.4350903 0.48778924 0.40804324\n",
|
||||
" 0.33885244 0.21948677 0.01242744 0.02531701 0.6693964 0.15393472\n",
|
||||
" 0.9307252 0.09181138 0.05571133 0.1261858 0.02687709 0.27069062\n",
|
||||
" 0.22613294 0.20686075 0.47390068 0.40349996]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(predictions[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.\n",
|
||||
" 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0.]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(np.rint(predictions)[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0.]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(true_answers[:100])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.savetxt(\"predictions.txt\",predictions)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.savetxt(\"predictions_two_digits.txt\",predictions, fmt='%1.2f')"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
Loading…
Reference in New Issue
Block a user