fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb
2023-12-12 15:22:01 +01:00

1221 lines
188 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Importy"
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score\n",
"from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
"from sklearn.metrics import precision_score, recall_score, accuracy_score\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wczytanie danych"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('data4.csv')"
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {},
"outputs": [],
"source": [
"y = pd.DataFrame(df['isGoal'])\n",
"X = df.drop(['isGoal'], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>match_minute</th>\n",
" <th>match_second</th>\n",
" <th>position_x</th>\n",
" <th>position_y</th>\n",
" <th>play_type</th>\n",
" <th>BodyPart</th>\n",
" <th>Number_Intervening_Opponents</th>\n",
" <th>Number_Intervening_Teammates</th>\n",
" <th>Interference_on_Shooter</th>\n",
" <th>outcome</th>\n",
" <th>...</th>\n",
" <th>Interference_on_Shooter_Code</th>\n",
" <th>distance_to_goalM</th>\n",
" <th>distance_to_centerM</th>\n",
" <th>angle</th>\n",
" <th>isFoot</th>\n",
" <th>isHead</th>\n",
" <th>header_distance_to_goalM</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Medium</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>29</td>\n",
" <td>54</td>\n",
" <td>23.69</td>\n",
" <td>4.99</td>\n",
" <td>Open Play</td>\n",
" <td>Left</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>Medium</td>\n",
" <td>Missed</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>24.212265</td>\n",
" <td>5.001769</td>\n",
" <td>11.922004</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" <td>33</td>\n",
" <td>28.93</td>\n",
" <td>-11.22</td>\n",
" <td>Open Play</td>\n",
" <td>Left</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>Low</td>\n",
" <td>Missed</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>31.039134</td>\n",
" <td>11.246462</td>\n",
" <td>21.243463</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>61</td>\n",
" <td>25</td>\n",
" <td>9.98</td>\n",
" <td>-5.24</td>\n",
" <td>Open Play</td>\n",
" <td>Head</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>High</td>\n",
" <td>Missed</td>\n",
" <td>...</td>\n",
" <td>3</td>\n",
" <td>11.277751</td>\n",
" <td>5.252358</td>\n",
" <td>27.757313</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>11.277751</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>73</td>\n",
" <td>45</td>\n",
" <td>4.49</td>\n",
" <td>-5.74</td>\n",
" <td>Open Play</td>\n",
" <td>Right</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>Low</td>\n",
" <td>Missed</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>7.298171</td>\n",
" <td>5.753538</td>\n",
" <td>52.031899</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>44</td>\n",
" <td>40</td>\n",
" <td>7.98</td>\n",
" <td>-12.97</td>\n",
" <td>Open Play</td>\n",
" <td>Right</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Medium</td>\n",
" <td>Saved</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>15.254368</td>\n",
" <td>13.000590</td>\n",
" <td>58.457635</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 29 columns</p>\n",
"</div>"
],
"text/plain": [
" match_minute match_second position_x position_y play_type BodyPart \\\n",
"0 29 54 23.69 4.99 Open Play Left \n",
"1 11 33 28.93 -11.22 Open Play Left \n",
"2 61 25 9.98 -5.24 Open Play Head \n",
"3 73 45 4.49 -5.74 Open Play Right \n",
"4 44 40 7.98 -12.97 Open Play Right \n",
"\n",
" Number_Intervening_Opponents Number_Intervening_Teammates \\\n",
"0 4 2 \n",
"1 4 1 \n",
"2 3 1 \n",
"3 2 0 \n",
"4 1 0 \n",
"\n",
" Interference_on_Shooter outcome ... Interference_on_Shooter_Code \\\n",
"0 Medium Missed ... 2 \n",
"1 Low Missed ... 1 \n",
"2 High Missed ... 3 \n",
"3 Low Missed ... 1 \n",
"4 Medium Saved ... 2 \n",
"\n",
" distance_to_goalM distance_to_centerM angle isFoot isHead \\\n",
"0 24.212265 5.001769 11.922004 1 0 \n",
"1 31.039134 11.246462 21.243463 1 0 \n",
"2 11.277751 5.252358 27.757313 0 1 \n",
"3 7.298171 5.753538 52.031899 1 0 \n",
"4 15.254368 13.000590 58.457635 1 0 \n",
"\n",
" header_distance_to_goalM High Low Medium \n",
"0 0.000000 0 0 1 \n",
"1 0.000000 0 1 0 \n",
"2 11.277751 1 0 0 \n",
"3 0.000000 0 1 0 \n",
"4 0.000000 0 0 1 \n",
"\n",
"[5 rows x 29 columns]"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.head()"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>isGoal</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" isGoal\n",
"0 0\n",
"1 0\n",
"2 0\n",
"3 0\n",
"4 0"
]
},
"execution_count": 168,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Przygotowanie danych"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Uwzględnienie wybranych cech: \n",
"- Współrzędna x strzelającego,\n",
"- Współrzędna y strzelającego,\n",
"- Dystans do bramki,\n",
"- Kąt do bramki,\n",
"- Minuta meczu,\n",
"- Liczba przeciwników przed piłką,\n",
"- Liczba zawodników ze swojej drużyny przed piłką,\n",
"- Część ciała."
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['match_minute', 'match_second', 'position_x', 'position_y', 'play_type',\n",
" 'BodyPart', 'Number_Intervening_Opponents',\n",
" 'Number_Intervening_Teammates', 'Interference_on_Shooter', 'outcome',\n",
" 'position_xM', 'position_yM', 'position_xM_r', 'position_yM_r',\n",
" 'position_xM_std', 'position_yM_std', 'position_xM_std_r',\n",
" 'position_yM_std_r', 'BodyPartCode', 'Interference_on_Shooter_Code',\n",
" 'distance_to_goalM', 'distance_to_centerM', 'angle', 'isFoot', 'isHead',\n",
" 'header_distance_to_goalM', 'High', 'Low', 'Medium'],\n",
" dtype='object')"
]
},
"execution_count": 169,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.columns"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [],
"source": [
"X_extracted = X[['position_x', \n",
" 'position_y',\n",
" 'distance_to_goalM', \n",
" 'angle', \n",
" 'match_minute', \n",
" 'Number_Intervening_Opponents', \n",
" 'Number_Intervening_Teammates', \n",
" 'isFoot', \n",
" 'isHead']]"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\s478991\\AppData\\Local\\temp\\ipykernel_3956\\2392787789.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" X_extracted['isFoot'] = X_extracted['isFoot'].astype('category')\n",
"C:\\Users\\s478991\\AppData\\Local\\temp\\ipykernel_3956\\2392787789.py:2: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" X_extracted['isHead'] = X_extracted['isHead'].astype('category')\n"
]
}
],
"source": [
"X_extracted['isFoot'] = X_extracted['isFoot'].astype('category')\n",
"X_extracted['isHead'] = X_extracted['isHead'].astype('category')"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>position_x</th>\n",
" <th>position_y</th>\n",
" <th>distance_to_goalM</th>\n",
" <th>angle</th>\n",
" <th>match_minute</th>\n",
" <th>Number_Intervening_Opponents</th>\n",
" <th>Number_Intervening_Teammates</th>\n",
" <th>isFoot</th>\n",
" <th>isHead</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>23.69</td>\n",
" <td>4.99</td>\n",
" <td>24.212265</td>\n",
" <td>11.922004</td>\n",
" <td>29</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>28.93</td>\n",
" <td>-11.22</td>\n",
" <td>31.039134</td>\n",
" <td>21.243463</td>\n",
" <td>11</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>9.98</td>\n",
" <td>-5.24</td>\n",
" <td>11.277751</td>\n",
" <td>27.757313</td>\n",
" <td>61</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.49</td>\n",
" <td>-5.74</td>\n",
" <td>7.298171</td>\n",
" <td>52.031899</td>\n",
" <td>73</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.98</td>\n",
" <td>-12.97</td>\n",
" <td>15.254368</td>\n",
" <td>58.457635</td>\n",
" <td>44</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" position_x position_y distance_to_goalM angle match_minute \\\n",
"0 23.69 4.99 24.212265 11.922004 29 \n",
"1 28.93 -11.22 31.039134 21.243463 11 \n",
"2 9.98 -5.24 11.277751 27.757313 61 \n",
"3 4.49 -5.74 7.298171 52.031899 73 \n",
"4 7.98 -12.97 15.254368 58.457635 44 \n",
"\n",
" Number_Intervening_Opponents Number_Intervening_Teammates isFoot isHead \n",
"0 4 2 1 0 \n",
"1 4 1 1 0 \n",
"2 3 1 0 1 \n",
"3 2 0 1 0 \n",
"4 1 0 1 0 "
]
},
"execution_count": 172,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_extracted.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Podział danych na zbiór treningowy oraz zbiór testowy"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X_extracted, y, test_size=0.2, random_state=1)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [],
"source": [
"cv_outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n",
"cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Oddane strzały w zbiorze danych: 7226\n",
"Gole trafione w zbiorze danych: 906\n"
]
}
],
"source": [
"count_class_0, count_class_1 = y_train.value_counts()\n",
"print ('Oddane strzały w zbiorze danych: ', count_class_0)\n",
"print ('Gole trafione w zbiorze danych: ', count_class_1)"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7.975717439293598"
]
},
"execution_count": 176,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Class imbalance in training data\n",
"\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"scale_pos_weight"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Trening danych"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [],
"source": [
"from xgboost import XGBClassifier"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
]
},
{
"cell_type": "code",
"execution_count": 179,
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XG Boost\n",
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
" 'max_depth': [3, 5, 7, 8, 9],\n",
" 'n_estimators': [100, 150, 200, 250, 300],\n",
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
"execution_count": 180,
"metadata": {},
"outputs": [],
"source": [
"start_time = time.time()"
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [],
"source": [
"# Perform nested cross-validation with grid search\n",
"\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv_inner, scoring='f1', n_jobs=-1)\n",
"scores_xg = cross_val_score(grid_xg, X_train, y_train, cv=cv_outer, scoring='f1', n_jobs=-1)"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [],
"source": [
"# Fit the best model on the entire training set\n",
"grid_xg.fit(X_train, y_train)\n",
"best_xgb_model = grid_xg.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 183,
"metadata": {},
"outputs": [],
"source": [
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
"execution_count": 184,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'learning_rate': 0.001, 'max_depth': 3, 'n_estimators': 100, 'scale_pos_weight': 7.975717439293598}\n",
"Model Training Time: 677.443 seconds\n"
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ewaluacja modelu"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dane treningowe"
]
},
{
"cell_type": "code",
"execution_count": 185,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Confusion Matrix - Train Set')"
]
},
"execution_count": 185,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, best_xgb_model.predict(X_train))\n",
"\n",
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')"
]
},
{
"cell_type": "code",
"execution_count": 186,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.94 0.84 0.88 7226\n",
" 1 0.30 0.56 0.39 906\n",
"\n",
" accuracy 0.81 8132\n",
" macro avg 0.62 0.70 0.64 8132\n",
"weighted avg 0.87 0.81 0.83 8132\n",
"\n"
]
}
],
"source": [
"# Classfication report for training data\n",
"print (classification_report(y_train, best_xgb_model.predict(X_train)))"
]
},
{
"cell_type": "code",
"execution_count": 187,
"metadata": {},
"outputs": [],
"source": [
"# xgb.to_graphviz(best_xgb_model, num_trees=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dane testowe"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Confusion Matrix - Test Set')"
]
},
"execution_count": 188,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the performance of the best model on the testing set\n",
"y_pred_xgb = best_xgb_model.predict(X_test)\n",
"\n",
"# Confusion Matrix for Testig Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_xgb)\n",
"\n",
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.93 0.84 0.88 1797\n",
" 1 0.30 0.50 0.37 236\n",
"\n",
" accuracy 0.80 2033\n",
" macro avg 0.61 0.67 0.63 2033\n",
"weighted avg 0.85 0.80 0.82 2033\n",
"\n"
]
}
],
"source": [
"# Classfication report for testing data\n",
"print (classification_report(y_test, y_pred_xgb))"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Zbiór danych testowych zawiera 2033 oddane strzały, gdzie 236 to strzały trafione.\n",
"Dokładność klasyfikacji, czy strzał jest bramką, czy nie, wynosi 0.80%.\n",
"klasyfikator uzyskał ROC-AUC na poziomie 0.75%.\n"
]
}
],
"source": [
"print(f'Zbiór danych testowych zawiera {len(y_test)} oddane strzały, gdzie {y_test.sum()[\"isGoal\"]} to strzały trafione.')\n",
"print(f'Dokładność klasyfikacji, czy strzał jest bramką, czy nie, wynosi {best_xgb_model.score(X_test, y_test):.2f}%.')\n",
"print(f'klasyfikator uzyskał ROC-AUC na poziomie {roc_auc_score(y_test, best_xgb_model.predict_proba(X_test)[:, 1]):.2f}%.')"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance\n",
"xgb.plot_importance(best_xgb_model)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"xgb.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"xgb.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Podsumowanie"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"prec_xgb_train = precision_score(y_train, best_xgb_model.predict(X_train))\n",
"prec_xgb_test = precision_score(y_test, y_pred_xgb)\n",
"rec_xgb_train = recall_score(y_train, best_xgb_model.predict(X_train))\n",
"rec_xgb_test = recall_score(y_test, y_pred_xgb)\n",
"acc_xgb_train = accuracy_score(y_train, best_xgb_model.predict(X_train))\n",
"acc_xgb_test = accuracy_score(y_test, y_pred_xgb)\n",
"train_time = xgb_training_time/60"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_c3a1d_row0_col0, #T_c3a1d_row0_col1, #T_c3a1d_row0_col2, #T_c3a1d_row0_col3, #T_c3a1d_row0_col4, #T_c3a1d_row0_col5, #T_c3a1d_row0_col6 {\n",
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
"<table id=\"T_c3a1d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_c3a1d_level0_col0\" class=\"col_heading level0 col0\" >Training Accuracy</th>\n",
" <th id=\"T_c3a1d_level0_col1\" class=\"col_heading level0 col1\" >Training Precision</th>\n",
" <th id=\"T_c3a1d_level0_col2\" class=\"col_heading level0 col2\" >Training Recall</th>\n",
" <th id=\"T_c3a1d_level0_col3\" class=\"col_heading level0 col3\" >Testing Accuracy</th>\n",
" <th id=\"T_c3a1d_level0_col4\" class=\"col_heading level0 col4\" >Testing Precision</th>\n",
" <th id=\"T_c3a1d_level0_col5\" class=\"col_heading level0 col5\" >Testing Recall</th>\n",
" <th id=\"T_c3a1d_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_c3a1d_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_c3a1d_row0_col0\" class=\"data row0 col0\" >0.806</td>\n",
" <td id=\"T_c3a1d_row0_col1\" class=\"data row0 col1\" >0.300</td>\n",
" <td id=\"T_c3a1d_row0_col2\" class=\"data row0 col2\" >0.556</td>\n",
" <td id=\"T_c3a1d_row0_col3\" class=\"data row0 col3\" >0.803</td>\n",
" <td id=\"T_c3a1d_row0_col4\" class=\"data row0 col4\" >0.296</td>\n",
" <td id=\"T_c3a1d_row0_col5\" class=\"data row0 col5\" >0.504</td>\n",
" <td id=\"T_c3a1d_row0_col6\" class=\"data row0 col6\" >11.291</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x163c80a91b0>"
]
},
"execution_count": 196,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating of dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training Accuracy': acc_xgb_train, \n",
" 'Training Precision': prec_xgb_train,\n",
" 'Training Recall':rec_xgb_train,\n",
" 'Testing Accuracy': acc_xgb_test, \n",
" 'Testing Precision': prec_xgb_test,\n",
" 'Testing Recall':rec_xgb_test,\n",
" 'Training Time (mins)': train_time})\n",
"\n",
"summary_df.set_index('Model Name', inplace=True)\n",
"# Displaying summary of results\n",
"summary_df.style.format(precision =3).set_properties(**{'font-weight': 'bold',\n",
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Zapisywanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['xgboost.joblib']"
]
},
"execution_count": 197,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from joblib import dump\n",
"dump(best_xgb_model, 'xgboost.joblib') "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Wczytywanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [],
"source": [
"from joblib import load\n",
"\n",
"model2 = load('xgboost.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'objective': 'binary:logistic',\n",
" 'base_score': None,\n",
" 'booster': None,\n",
" 'callbacks': None,\n",
" 'colsample_bylevel': None,\n",
" 'colsample_bynode': None,\n",
" 'colsample_bytree': None,\n",
" 'device': None,\n",
" 'early_stopping_rounds': None,\n",
" 'enable_categorical': True,\n",
" 'eval_metric': None,\n",
" 'feature_types': None,\n",
" 'gamma': None,\n",
" 'grow_policy': None,\n",
" 'importance_type': None,\n",
" 'interaction_constraints': None,\n",
" 'learning_rate': 0.001,\n",
" 'max_bin': None,\n",
" 'max_cat_threshold': None,\n",
" 'max_cat_to_onehot': None,\n",
" 'max_delta_step': None,\n",
" 'max_depth': 3,\n",
" 'max_leaves': None,\n",
" 'min_child_weight': None,\n",
" 'missing': nan,\n",
" 'monotone_constraints': None,\n",
" 'multi_strategy': None,\n",
" 'n_estimators': 100,\n",
" 'n_jobs': None,\n",
" 'num_parallel_tree': None,\n",
" 'random_state': None,\n",
" 'reg_alpha': None,\n",
" 'reg_lambda': None,\n",
" 'sampling_method': None,\n",
" 'scale_pos_weight': 7.975717439293598,\n",
" 'subsample': None,\n",
" 'tree_method': 'hist',\n",
" 'validate_parameters': None,\n",
" 'verbosity': None}"
]
},
"execution_count": 199,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.get_params()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}