fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

769 lines
367 KiB
Plaintext
Raw Permalink Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 1,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
"from joblib import dump, load"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 2,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('final_data.txt')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 3,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
" 'shot_aerial_won', 'shot_deflected', 'shot_open_goal',\n",
" 'shot_follows_dribble', 'shot_redirect', 'x1', 'y1',\n",
" 'number_of_players_opponents', 'number_of_players_teammates', 'is_goal',\n",
" 'angle', 'distance', 'x_player_opponent_Goalkeeper',\n",
" 'x_player_opponent_8', 'x_player_opponent_1', 'x_player_opponent_2',\n",
" 'x_player_opponent_3', 'x_player_teammate_1', 'x_player_opponent_4',\n",
" 'x_player_opponent_5', 'x_player_opponent_6', 'x_player_teammate_2',\n",
" 'x_player_opponent_9', 'x_player_opponent_10', 'x_player_opponent_11',\n",
" 'x_player_teammate_3', 'x_player_teammate_4', 'x_player_teammate_5',\n",
" 'x_player_teammate_6', 'x_player_teammate_7', 'x_player_teammate_8',\n",
" 'x_player_teammate_9', 'x_player_teammate_10',\n",
" 'y_player_opponent_Goalkeeper', 'y_player_opponent_8',\n",
" 'y_player_opponent_1', 'y_player_opponent_2', 'y_player_opponent_3',\n",
" 'y_player_teammate_1', 'y_player_opponent_4', 'y_player_opponent_5',\n",
" 'y_player_opponent_6', 'y_player_teammate_2', 'y_player_opponent_9',\n",
" 'y_player_opponent_10', 'y_player_opponent_11', 'y_player_teammate_3',\n",
" 'y_player_teammate_4', 'y_player_teammate_5', 'y_player_teammate_6',\n",
" 'y_player_teammate_7', 'y_player_teammate_8', 'y_player_teammate_9',\n",
" 'y_player_teammate_10', 'x_player_opponent_7', 'y_player_opponent_7',\n",
" 'x_player_teammate_Goalkeeper', 'y_player_teammate_Goalkeeper',\n",
" 'shot_kick_off'],\n",
" dtype='object')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_deflected</th>\n",
" <th>shot_open_goal</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
" <th>shot_kick_off</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>Left Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Right Center Back</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 66 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 Right Center Forward Right Foot Normal \n",
"1 5 Right Center Forward Left Foot Normal \n",
"2 5 Center Midfield Right Foot Half Volley \n",
"3 5 Left Center Midfield Right Foot Normal \n",
"4 5 Right Center Back Left Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 Open Play False False False \n",
"1 Open Play False False False \n",
"2 Open Play True False False \n",
"3 Open Play False False False \n",
"4 Open Play True False False \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_deflected shot_open_goal ... y_player_teammate_6 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... NaN \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
" y_player_teammate_7 y_player_teammate_8 y_player_teammate_9 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
" y_player_teammate_10 x_player_opponent_7 y_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
" x_player_teammate_Goalkeeper y_player_teammate_Goalkeeper shot_kick_off \n",
"0 NaN NaN False \n",
"1 NaN NaN False \n",
"2 NaN NaN False \n",
"3 NaN NaN False \n",
"4 NaN NaN False \n",
2023-12-12 15:22:01 +01:00
"\n",
"[5 rows x 66 columns]"
2023-12-12 15:22:01 +01:00
]
},
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 5,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# Change the type of categorical features to 'category' \n",
"df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']] = df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 6,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Splitting the dataset into features (X) and the target variable (y)\n",
"y = pd.DataFrame(df['is_goal'])\n",
"X = df.drop(['is_goal'], axis=1)\n",
"\n",
"# Splitting the data into a training set and a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
"\n",
"# Create cross-validation \n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 7,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
"Shots attempted in the training set: 27085\n",
"Goals scored in the training set: 3588\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"count_class_0, count_class_1 = y_train.value_counts()\n",
"\n",
"# Display the count of shots attempted in the training set\n",
"print('Shots attempted in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 8,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Class imbalance in training data: 7.549\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Class imbalance in training data\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 9,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 10,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XG Boost\n",
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
" 'max_depth': [3, 5, 7, 8, 9],\n",
" 'n_estimators': [100, 150, 200, 250, 300],\n",
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
"execution_count": 11,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Perform grid search with cross-validation\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Fit the best model on the entire training set\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Take the best parameters for xgboost model\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
"execution_count": 12,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
"Model Training Time: 1916.225 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 13,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABKMElEQVR4nO3deVwV9f7H8fcB5YALoCEC7ksuuJYakqmZJK6JS7ndRFNLL5aKW1ZXUSvSLJfcskWstGtWWmlppCmZWKaRS2pqmmmCuwQqIMzvD3+e6wkX8JyRo76e9zGPe893vmfmO4M+7tvP9zuDxTAMQwAAAC7MraAHAAAAcD0EFgAA4PIILAAAwOURWAAAgMsjsAAAAJdHYAEAAC6PwAIAAFwegQUAALg8AgsAAHB5BBbcMvbs2aNWrVrJx8dHFotFy5Ytc+rxDxw4IIvFori4OKce91b24IMP6sEHHyzoYZiCnzdwayGwIF/27dunp556SpUrV5anp6e8vb3VpEkTTZ8+XefOnTP13JGRkdq2bZteeuklvf/++2rYsKGp57uZ+vTpI4vFIm9v7yvexz179shischisWjKlCn5Pv5ff/2lmJgYJSUlOWG05oqJibFd67U2VwxSOTk5eu+99xQSEqKSJUuqePHiqlatmnr37q2NGzfm+3hnz55VTEyM1q5d6/zBAreYQgU9ANw6VqxYoUcffVRWq1W9e/dW7dq1lZmZqfXr12vkyJHasWOH5s2bZ8q5z507p8TERD3//PMaPHiwKeeoUKGCzp07p8KFC5ty/OspVKiQzp49qy+++EKPPfaY3b6FCxfK09NT58+fv6Fj//XXXxo/frwqVqyo+vXr5/l7X3/99Q2dzxGdO3dW1apVbZ/T0tI0aNAgderUSZ07d7a1ly5d2qHzmPHzfuaZZzRr1ix17NhRvXr1UqFChbR792599dVXqly5sho3bpyv4509e1bjx4+XJJcMaMDNRGBBnuzfv1/du3dXhQoVtGbNGgUGBtr2RUVFae/evVqxYoVp5z927JgkydfX17RzWCwWeXp6mnb867FarWrSpIk+/PDDXIFl0aJFateunT755JObMpazZ8+qSJEi8vDwuCnnu1zdunVVt25d2+fjx49r0KBBqlu3rv71r39d9Xvnz5+Xh4eH3NzyVjh29s87JSVFs2fP1oABA3IF92nTptn+DAO4MUwJIU8mT56stLQ0vfPOO3Zh5ZKqVatqyJAhts8XLlzQxIkTVaVKFVmtVlWsWFHPPfecMjIy7L5XsWJFtW/fXuvXr9d9990nT09PVa5cWe+9956tT0xMjCpUqCBJGjlypCwWiypWrCjp4lTKpf99uUvTCpeLj4/XAw88IF9fXxUrVkzVq1fXc889Z9t/tTUNa9asUdOmTVW0aFH5+vqqY8eO2rlz5xXPt3fvXvXp00e+vr7y8fFR3759dfbs2avf2H/o2bOnvvrqK50+fdrWtmnTJu3Zs0c9e/bM1f/kyZMaMWKE6tSpo2LFisnb21tt2rTRL7/8Yuuzdu1aNWrUSJLUt29f25TKpet88MEHVbt2bW3evFnNmjVTkSJFbPfln2tYIiMj5enpmev6w8PDVaJECf311195vlZHrF27VhaLRf/973/1wgsvqEyZMipSpIhSU1PzdE+kK/+8+/Tpo2LFiunw4cOKiIhQsWLFVKpUKY0YMULZ2dnXHNP+/ftlGIaaNGmSa5/FYpG/v79d2+nTpzV06FCVK1dOVqtVVatW1aRJk5STk2MbX6lSpSRJ48ePt/3cYmJibuCOAbc+KizIky+++EKVK1fW/fffn6f+/fv314IFC9S1a1cNHz5cP/zwg2JjY7Vz504tXbrUru/evXvVtWtX9evXT5GRkXr33XfVp08fNWjQQLVq1VLnzp3l6+urYcOGqUePHmrbtq2KFSuWr/Hv2LFD7du3V926dTVhwgRZrVbt3btX33///TW/980336hNmzaqXLmyYmJidO7cOb3xxhtq0qSJtmzZkissPfbYY6pUqZJiY2O1ZcsWvf322/L399ekSZPyNM7OnTtr4MCB+vTTT/XEE09IulhdqVGjhu69995c/X///XctW7ZMjz76qCpVqqSUlBS9+eabat68uX799VcFBQWpZs2amjBhgsaOHasnn3xSTZs2lSS7n+WJEyfUpk0bde/eXf/617+uOt0yffp0rVmzRpGRkUpMTJS7u7vefPNNff3113r//fcVFBSUp+t0lokTJ8rDw0MjRoxQRkaGPDw89Ouvv173nlxLdna2wsPDFRISoilTpuibb77Ra6+9pipVqmjQoEFX/d6lUL1kyRI9+uijKlKkyFX7nj17Vs2bN9fhw4f11FNPqXz58tqwYYPGjBmjI0eOaNq0aSpVqpTmzJmTazrs8uoTcEcxgOs4c+aMIcno2LFjnvonJSUZkoz+/fvbtY8YMcKQZKxZs8bWVqFCBUOSkZCQYGs7evSoYbVajeHDh9va9u/fb0gyXn31VbtjRkZGGhUqVMg1hnHjxhmX//GeOnWqIck4duzYVcd96Rzz58+3tdWvX9/w9/c3Tpw4YWv75ZdfDDc3N6N37965zvfEE0/YHbNTp07GXXfdddVzXn4dRYsWNQzDMLp27Wq0bNnSMAzDyM7ONgICAozx48df8R6cP3/eyM7OznUdVqvVmDBhgq1t06ZNua7tkubNmxuSjLlz515xX/Pmze3aVq1aZUgyXnzxReP33383ihUrZkRERFz3Gm/UsWPHDEnGuHHjbG3ffvutIcmoXLmycfbsWbv+eb0nV/p5R0ZGGpLs+hmGYdxzzz1GgwYNrjvW3r17G5KMEiVKGJ06dTKmTJli7Ny5M1e/iRMnGkWLFjV+++03u/Znn33WcHd3Nw4ePHjVawfuVEwJ4bpSU1MlScWLF89T/y+//FKSFB0dbdc+fPhwScq11iU4ONj2r35JKlWqlKpXr67ff//9hsf8T5fWvnz22We2kvv1HDlyRElJSerTp49Klixpa69bt64efvhh23VebuDAgXafmzZtqhMnTtjuYV707NlTa9euVXJystasWaPk5OQrTgdJF9e9XFqzkZ2drRMnTtimu7Zs2ZLnc1qtVvXt2zdPfVu1aqWnnnpKEyZMUOfOneXp6ak333wzz+dypsjISHl5edm1OeOeXOnnmJc/j/Pnz9fMmTNVqVIlLV26VCNGjFDNmjXVsmVLHT582NZvyZIlatq0qUqUKKHjx4/btrCwMGVnZyshISFP4wTuJAQWXJe3t7ck6e+//85T/z/++ENubm52T3pIUkBAgHx9ffXHH3/YtZcvXz7XMUqUKKFTp07d4Ihz69atm5o0aaL+/furdOnS6t69uz766KNrhpdL46xevXqufTVr1tTx48eVnp5u1/7PaylRooQk5eta2rZtq+LFi2vx4sVauHChGjVqlOteXpKTk6OpU6fq7rvvltVqlZ+fn0qVKqWtW7fqzJkzeT5nmTJl8rXAdsqUKSpZsqSSkpI0Y8aMXOszruTYsWNKTk62bWlpaXk+39VUqlQpV5uj98TT09O2duSSvP55dHNzU1RUlDZv3qzjx4/rs88+U5s2bbRmzRp1797d1m/Pnj1auXKlSpUqZbeFhYVJko4ePXrdcwF3GgILrsvb21tBQUHavn17vr73z0WvV+Pu7n7FdsMwbvgc/1wg6eXlpYSEBH3zzTd6/PHHtXXrVnXr1k0PP/zwdRdT5ocj13KJ1WpV586dtWDBAi1duvSq1RVJevnllxUdHa1mzZrpgw8+0KpVqxQfH69atWrluZIkKVeV4np+/vln2/+pbtu2LU/fadSokQIDA23bjbxP5p+uNG5H78nVfob5ddddd+mRRx7Rl19+qebNm2v9+vW2EJyTk6OHH35Y8fHxV9y6dOnilDEAtxMW3SJP2rdvr3nz5ikxMVGhoaHX7FuhQgXl5ORoz549qlmzpq09JSVFp0+fti1OdIYSJUrYPVFzyT+rONLFf/22bNlSLVu21Ouvv66XX35Zzz//vL799lvbv2z/eR2StHv37lz7du3aJT8/PxUtWtTxi7i
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on training set\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 14,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAHHCAYAAACPy0PBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABJt0lEQVR4nO3deVxU1f/H8feAMiAKiMiWG2Yu5FJiKbknSYbmmlmWuFUamuJOm6YlZYtpbu1YaWn11Uxz+2pqJW4UZW65FZqCW4grKNzfH/6cryPqgM11FF/PHvfxaO49c++54/bmc865YzEMwxAAAIALubm6AwAAAAQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQS3LC2b9+uli1bytfXVxaLRXPnznXq+f/8809ZLBYlJSU59bw3smbNmqlZs2au7gaAIohAgn9l586deuqpp1S5cmV5enrKx8dHDRs21IQJE3Tq1ClTrx0bG6uNGzfqlVde0aeffqp69eqZer1rqXv37rJYLPLx8bnk57h9+3ZZLBZZLBa98cYbhT7/vn37NGrUKKWmpjqht+YaNWqU7V6vtDkrKH333XcaNWpUgdvn5eXpk08+Uf369eXv769SpUqpatWq6tatm9asWVPo6588eVKjRo3SihUrCv1e4EZWzNUdwI1rwYIFeuihh2S1WtWtWzfVrFlTOTk5+vHHHzV06FBt2rRJ7733ninXPnXqlJKTk/Xcc8+pX79+plyjYsWKOnXqlIoXL27K+R0pVqyYTp48qW+//VadO3e2OzZjxgx5enrq9OnTV3Xuffv26aWXXlKlSpV0xx13FPh9S5Ysuarr/RsdOnRQlSpVbK+PHz+uvn37qn379urQoYNtf1BQkFOu991332ny5MkFDiXPPPOMJk+erLZt26pr164qVqyYtm3bpoULF6py5cpq0KBBoa5/8uRJvfTSS5JENQo3FQIJrsru3bvVpUsXVaxYUcuXL1dISIjtWFxcnHbs2KEFCxaYdv2DBw9Kkvz8/Ey7hsVikaenp2nnd8Rqtaphw4b6/PPP8wWSmTNnKiYmRl9//fU16cvJkydVokQJeXh4XJPrXah27dqqXbu27fWhQ4fUt29f1a5dW4899tg178+FMjIyNGXKFD3xxBP5wvfbb79t+30KwDGGbHBVxo0bp+PHj+vDDz+0CyPnValSRQMGDLC9Pnv2rMaMGaNbb71VVqtVlSpV0rPPPqvs7Gy791WqVEmtW7fWjz/+qLvvvluenp6qXLmyPvnkE1ubUaNGqWLFipKkoUOHymKxqFKlSpLODXWc//8LnS/7X2jp0qVq1KiR/Pz8VLJkSVWrVk3PPvus7fjl5pAsX75cjRs3lre3t/z8/NS2bVtt2bLlktfbsWOHunfvLj8/P/n6+qpHjx46efLk5T/Yizz66KNauHChMjMzbfvWr1+v7du369FHH83X/siRIxoyZIhq1aqlkiVLysfHR61atdKvv/5qa7NixQrdddddkqQePXrYhjzO32ezZs1Us2ZNpaSkqEmTJipRooTtc7l4DklsbKw8PT3z3X90dLRKly6tffv2Ffhe/62tW7eqU6dO8vf3l6enp+rVq6d58+bZtTlz5oxeeukl3XbbbfL09FSZMmXUqFEjLV26VNK53z+TJ0+WJLvhoMvZvXu3DMNQw4YN8x2zWCwKDAy025eZmamBAweqfPnyslqtqlKlil577TXl5eVJOvd7rmzZspKkl156yXb9wgwhATcqKiS4Kt9++60qV66se+65p0Dte/furenTp6tTp04aPHiw1q5dq8TERG3ZskVz5syxa7tjxw516tRJvXr1UmxsrD766CN1795dERERuv3229WhQwf5+fkpPj5ejzzyiB544AGVLFmyUP3ftGmTWrdurdq1a2v06NGyWq3asWOHfvrppyu+77///a9atWqlypUra9SoUTp16pTeeecdNWzYUD///HO+MNS5c2eFhYUpMTFRP//8sz744AMFBgbqtddeK1A/O3TooD59+ug///mPevbsKelcdaR69eqqW7duvva7du3S3Llz9dBDDyksLEwZGRl699131bRpU23evFmhoaGqUaOGRo8erRdffFFPPvmkGjduLEl2v5aHDx9Wq1at1KVLFz322GOXHQ6ZMGGCli9frtjYWCUnJ8vd3V3vvvuulixZok8//VShoaEFus9/a9OmTWrYsKFuueUWjRgxQt7e3po9e7batWunr7/+Wu3bt5d0LigmJiaqd+/euvvuu5WVlaUNGzbo559/1n333aennnpK+/bt09KlS/Xpp586vO75YPzll1/qoYceUokSJS7b9uTJk2ratKn+/vtvPfXUU6pQoYJWr16thIQE7d+/X2+//bbKli2rqVOn5huSurBCBBRZBlBIR48eNSQZbdu2LVD71NRUQ5LRu3dvu/1DhgwxJBnLly+37atYsaIhyVi1apVt34EDBwyr1WoMHjzYtm/37t2GJOP111+3O2dsbKxRsWLFfH0YOXKkceFv9/HjxxuSjIMHD1623+ev8fHHH9v23XHHHUZgYKBx+PBh275ff/3VcHNzM7p165bvej179rQ7Z/v27Y0yZcpc9poX3oe3t7dhGIbRqVMno0WLFoZhGEZubq4RHBxsvPTSS5f8DE6fPm3k5ubmuw+r1WqMHj3atm/9+vX57u28pk2bGpKMadOmXfJY06ZN7fYtXrzYkGS8/PLLxq5du4ySJUsa7dq1c3iPV+vgwYOGJGPkyJG2fS1atDBq1aplnD592rYvLy/PuOeee4zbbrvNtq9OnTpGTEzMFc8fFxdnFOavxm7duhmSjNKlSxvt27c33njjDWPLli352o0ZM8bw9vY2/vjjD7v9I0aMMNzd3Y20tLTL3h9wM2DIBoWWlZUlSSpVqlSB2n/33XeSpEGDBtntHzx4sCTlm2sSHh5u+6ldksqWLatq1app165dV93ni52fe/LNN9/YyuWO7N+/X6mpqerevbv8/f1t+2vXrq377rvPdp8X6tOnj93rxo0b6/Dhw7bPsCAeffRRrVixQunp6Vq+fLnS09MvOVwjnZt34uZ27o91bm6uDh8+bBuO+vnnnwt8TavVqh49ehSobcuWLfXUU09p9OjR6tChgzw9PfXuu+8W+Fr/1pEjR7R8+XJ17txZx44d06FDh3To0CEdPnxY0dHR2r59u/7++29J537dN23apO3btzvt+h9//LEmTZqksLAwzZkzR0OGDFGNGjXUokUL23Wlc1WUxo0bq3Tp0rY+Hjp0SFFRUcrNzdWqVauc1ifgRkQgQaH5+PhIko4dO1ag9n/99Zfc3NzsVkpIUnBwsPz8/PTXX3/Z7a9QoUK+c5QuXVr//PPPVfY4v4cfflgNGzZU7969FRQUpC5dumj27NlXDCfn+1mtWrV8x2rUqKFDhw7pxIkTdvsvvpfSpUtLUqHu5YEHHlCpUqU0a9YszZgxQ3fddVe+z/K8vLw8jR8/XrfddpusVqsCAgJUtmxZ/fbbbzp69GiBr3nLLbcUagLrG2+8IX9/f6WmpmrixIn55k5cysGDB5Wenm7bjh8/XuDrXWjHjh0yDEMvvPCCypYta7eNHDlSknTgwAFJ0ujRo5WZmamqVauqVq1aGjp0qH777beruu55bm5uiouLU0pKig4dOqRvvvlGrVq10vLly9WlSxdbu+3bt2vRokX5+hgVFWXXR+BmxRwSFJqPj49CQ0P1+++/F+p9V5oceCF3d/dL7jcM46qvkZuba/fay8tLq1at0vfff68FCxZo0aJFmjVrlu69914tWbLksn0orH9zL+dZrVZ16NBB06dP165du644wXHs2LF64YUX1LNnT40ZM0b+/v5yc3PTwIEDC1wJks59PoXxyy+/2P5B3bhxox555BGH77nrrrvswujIkSOvavLm+fsaMmSIoqOjL9nmfIBr0qSJdu7cqW+++UZLlizRBx98oPHjx2vatGnq3bt3oa99sTJlyujBBx/Ugw8+qGbNmmnlypX666+/VLFiReXl5em+++7
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testig Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 15,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The test dataset contains 7669 shots, with 914 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAHHCAYAAAAs6rBrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gUx//A8ffRuwIBBUVERcAG2MWCHUSxRbEr1p+9xV4oVuw9xo5GjSWxJSqKCsaKiMYuVqKxxC4CSrv5/eHDfr0cIKaJcV7Pw6M3Mzs7+9mDu9mZnVUJIQSSJEmSJEmSJElSvqTzsRsgSZIkSZIkSZIk5Ux22iRJkiRJkiRJkvIx2WmTJEmSJEmSJEnKx2SnTZIkSZIkSZIkKR+TnTZJkiRJkiRJkqR8THbaJEmSJEmSJEmS8jHZaZMkSZIkSZIkScrHZKdNkiRJkiRJkiQpH5OdNkmSJEmSJEmSpHxMdtokSZIk6RMSHh6OSqUiISHhYzdFkiRJ+pfITpskSZKUr2V1UrL7GTNmzD+yz+PHjxMSEsKLFy/+kfo/ZykpKYSEhBAdHf2xmyJJkvTJ0PvYDZAkSZKkvJg0aRJOTk4aaeXKlftH9nX8+HFCQ0MJDAykYMGC/8g+/qwuXbrQvn17DA0NP3ZT/pSUlBRCQ0MBqFu37sdtjCRJ0idCdtokSZKkT0KTJk2oXLnyx27GX5KcnIypqelfqkNXVxddXd2/qUX/HrVaTVpa2sduhiRJ0idJTo+UJEmS/hP27t1L7dq1MTU1xdzcnKZNm3Lp0iWNMufPnycwMJASJUpgZGRE4cKF6dGjB0+fPlXKhISEMHLkSACcnJyUqZgJCQkkJCSgUqkIDw/X2r9KpSIkJESjHpVKxeXLl+nYsSOWlpbUqlVLyV+/fj2VKlXC2NgYKysr2rdvz927d997nNnd01a8eHGaNWtGdHQ0lStXxtjYmPLlyytTELdt20b58uUxMjKiUqVKnD17VqPOwMBAzMzMuHXrFj4+PpiammJvb8+kSZMQQmiUTU5O5quvvsLBwQFDQ0NcXFyYPXu2VjmVSsXAgQPZsGEDZcuWxdDQkG+++QYbGxsAQkNDldhmxS0v5+fd2N64cUMZDS1QoADdu3cnJSVFK2br16+natWqmJiYYGlpSZ06ddi/f79Gmby8fyRJkj4WOdImSZIkfRJevnzJkydPNNK++OILAL799lu6deuGj48PM2bMICUlhaVLl1KrVi3Onj1L8eLFAYiMjOTWrVt0796dwoULc+nSJZYvX86lS5c4efIkKpWK1q1bc+3aNb777jvmzZun7MPGxobHjx9/cLvbtm2Ls7Mz06ZNUzo2U6dOZeLEiQQEBNCrVy8eP37MokWLqFOnDmfPnv1TUzJv3LhBx44d+b//+z86d+7M7Nmz8ff355tvvmHcuHH0798fgOnTpxMQEEB8fDw6Ov+7dpuZmYmvry/Vq1dn5syZREREEBwcTEZGBpMmTQJACEHz5s2JioqiZ8+eeHh4sG/fPkaOHMm9e/eYN2+eRpsOHTrEli1bGDhwIF988QXu7u4sXbqUfv360apVK1q3bg1AhQoVgLydn3cFBATg5OTE9OnTOXPmDCtXrsTW1pYZM2YoZUJDQwkJCcHLy4tJkyZhYGBATEwMhw4donHjxkDe3z+SJEkfjZAkSZKkfGzNmjUCyPZHCCFevXolChYsKHr37q2x3cOHD0WBAgU00lNSUrTq/+677wQgfv75ZyVt1qxZAhC3b9/WKHv79m0BiDVr1mjVA4jg4GDldXBwsABEhw4dNMolJCQIXV1dMXXqVI30CxcuCD09Pa30nOLxbtscHR0FII4fP66k7du3TwDC2NhY/Prrr0r6smXLBCCioqKUtG7duglADBo0SElTq9WiadOmwsDAQDx+/FgIIcSOHTsEIKZMmaLRpjZt2giVSiVu3LihEQ8dHR1x6dIljbKPHz/WilWWvJ6frNj26NFDo2yrVq2EtbW18vr69etCR0dHtGrVSmRmZmqUVavVQogPe/9IkiR9LHJ6pCRJkvRJWLJkCZGRkRo/8HZ05sWLF3To0IEnT54oP7q6ulSrVo2oqCilDmNjY+X/b9684cmTJ1SvXh2AM2fO/CPt7tu3r8brbdu2oVarCQgI0Ghv4cKFcXZ21mjvhyhTpgw1atRQXlerVg2A+vXrU6xYMa30W7duadUxcOBA5f9Z0xvT0tI4cOAAAHv27EFXV5fBgwdrbPfVV18hhGDv3r0a6d7e3pQpUybPx/Ch5+ePsa1duzZPnz4lMTERgB07dqBWqwkKCtIYVcw6Pviw948kSdLHIqdHSpIkSZ+EqlWrZrsQyfXr14G3nZPsWFhYKP9/9uwZoaGhbNq0iUePHmmUe/ny5d/Y2v/544qX169fRwiBs7NztuX19fX/1H7e7ZgBFChQAAAHB4ds058/f66RrqOjQ4kSJTTSSpcuDaDcP/frr79ib2+Pubm5Rjk3Nzcl/11/PPb3+dDz88djtrS0BN4em4WFBTdv3kRHRyfXjuOHvH8kSZI+FtlpkyRJkj5parUaeHtfUuHChbXy9fT+91EXEBDA8ePHGTlyJB4eHpiZmaFWq/H19VXqyc0f76nKkpmZmeM2744eZbVXpVKxd+/ebFeBNDMze287spPTipI5pYs/LBzyT/jjsb/Ph56fv+PYPuT9I0mS9LHIv0SSJEnSJ61kyZIA2Nra0rBhwxzLPX/+nIMHDxIaGkpQUJCSnjXS8q6cOmdZIzl/fOj2H0eY3tdeIQROTk7KSFZ+oFaruXXrlkabrl27BqAsxOHo6MiBAwd49eqVxmjb1atXlfz3ySm2H3J+8qpkyZKo1WouX76Mh4dHjmXg/e8fSZKkj0ne0yZJkiR90nx8fLCwsGDatGmkp6dr5Wet+Jg1KvPHUZj58+drbZP1LLU/ds4sLCz44osv+PnnnzXSv/766zy3t3Xr1ujq6hIaGqrVFiGE1vL2/6bFixdrtGXx4sXo6+vToEEDAPz8/MjMzNQoBzBv3jxUKhVNmjR57z5MTEwA7dh+yPnJq5YtW6Kjo8OkSZO0Ruqy9pPX948kSdLHJEfaJEmSpE+ahYUFS5cupUuXLlSsWJH27dtjY2PDnTt32L17NzVr1mTx4sVYWFhQp04dZs6cSXp6OkWKFGH//v3cvn1bq85KlSoBMH78eNq3b4++vj7+/v6YmprSq1cvwsLC6NWrF5UrV+bnn39WRqTyomTJkkyZMoWxY8eSkJBAy5YtMTc35/bt22zfvp0+ffowYsSIvy0+eWVkZERERATdunWjWrVq7N27l927dzNu3Djl2Wr+/v7Uq1eP8ePHk5CQgLu7O/v372fnzp0MHTpUGbXKjbGxMWXKlGHz5s2ULl0aKysrypUrR7ly5fJ8fvKqVKlSjB8/nsmTJ1O7dm1at26NoaEhsbGx2NvbM3369Dy/fyRJkj4m2WmTJEmSPnkdO3bE3t6esLAwZs2aRWpqKkWKFKF27dp0795dKbdx40YGDRrEkiVLEELQuHFj9u7di729vUZ9VapUYfLkyXzzzTdERESgVqu5ffs2pqamBAUF8fjxY77//nu2bNlCkyZN2Lt3L7a2tnlu75gxYyhdujTz5s0jNDQUeLtgSOPGjWnevPnfE5QPpKurS0REBP369WPkyJGYm5sTHBysMVVRR0eHXbt2ERQUxObNm1mzZg3Fixdn1qxZfPXVV3ne18qVKxk0aBDDhg0jLS2N4OBgypUrl+fz8yEmTZqEk5MTixYtYvz48ZiYmFChQgW6dOmilMnr+0eSJOljUYl/405kSZIkSZLyrcDAQL7//nuSkpI+dlMkSZKkbMh72iRJkiRJkiRJkvIx2WmTJEmSJEmSJEnKx2SnTZIkSZIkSZIkKR+T97RJkiRJkiRJkiTlY3KkTZIkSZIkSZIkKR+TnTZJkiRJkiRJkqR8TD6nTZI+c2q1mvv372Nubo5KpfrYzZEkSZIkKQ+EELx69Qp7e3t0dOQ4zH+d7LRJ0mfu/v37ODg4fOxmSJIkSZL0J9y9e5eiRYt+7GZI/zDZaZO
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAHHCAYAAAAMBu+WAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxO6f/48dfdvq+KMlFIhSQ72UXJGHt2MmGMQox1bGUrex/LYCxlHbPZZqwhWwgZxpCsyaDJkki0nt8f/TpftxYh6309H4/7wX3Oda5zvc9d3ec65zrXWyFJkoQgCIIgCIIgCCpF7UM3QBAEQRAEQRCE9090BARBEARBEARBBYmOgCAIgiAIgiCoINEREARBEARBEAQVJDoCgiAIgiAIgqCCREdAEARBEARBEFSQ6AgIgiAIgiAIggoSHQFBEARBEARBUEGiIyAIgiAIgiAIKkh0BARBEAThExIeHo5CoSA+Pv5DN0UQhE+c6AgIgiAIH7W8E9+CXuPGjXsn+zx27BiBgYE8evTondSvytLS0ggMDOTgwYMfuimCoPI0PnQDBEEQBKE4pk6dip2dndKyatWqvZN9HTt2jKCgIHx8fDAxMXkn+3hTffr0oXv37mhra3/opryRtLQ0goKCAGjWrNmHbYwgqDjRERAEQRA+CW3atKF27dofuhlv5enTp+jr679VHerq6qirq5dQi96fnJwcMjIyPnQzBEF4gRgaJAiCIHwWdu3aRePGjdHX18fQ0JC2bdty4cIFpTJ///03Pj4+VKhQAR0dHcqUKcPXX3/NgwcP5DKBgYGMHj0aADs7O3kYUnx8PPHx8SgUCsLDw/PtX6FQEBgYqFSPQqHg4sWL9OzZE1NTUxo1aiSvX79+PbVq1UJXVxczMzO6d+/OrVu3XhlnQc8I2Nra8uWXX3Lw4EFq166Nrq4uzs7O8vCbzZs34+zsjI6ODrVq1eKvv/5SqtPHxwcDAwOuX7+Oh4cH+vr6WFtbM3XqVCRJUir79OlTvvvuO2xsbNDW1sbBwYG5c+fmK6dQKPD392fDhg1UrVoVbW1tli1bhoWFBQBBQUHysc07bsX5fF48tlevXpXv2hgbG9O/f3/S0tLyHbP169dTt25d9PT0MDU1pUmTJuzdu1epTHF+fgThcyPuCAiCIAifhJSUFO7fv6+0rFSpUgCsW7eOfv364eHhwaxZs0hLS2Pp0qU0atSIv/76C1tbWwAiIiK4fv06/fv3p0yZMly4cIEff/yRCxcucOLECRQKBZ06deLy5cv89NNPLFiwQN6HhYUF9+7de+12d+3aFXt7e2bOnCmfLM+YMYNJkybh7e3NgAEDuHfvHosWLaJJkyb89ddfbzQc6erVq/Ts2ZNvvvmG3r17M3fuXNq1a8eyZcv4/vvvGTJkCADBwcF4e3sTFxeHmtr/XQ/Mzs7G09OT+vXrM3v2bHbv3s2UKVPIyspi6tSpAEiSxFdffUVkZCS+vr7UqFGDPXv2MHr0aG7fvs2CBQuU2nTgwAF++eUX/P39KVWqFC4uLixdupRvv/2Wjh070qlTJwCqV68OFO/zeZG3tzd2dnYEBwdz5swZVq5ciaWlJbNmzZLLBAUFERgYSMOGDZk6dSpaWlpER0dz4MABWrduDRT/50cQPjuSIAiCIHzEwsLCJKDAlyRJ0pMnTyQTExNp4MCBStslJiZKxsbGSsvT0tLy1f/TTz9JgHT48GF52Zw5cyRAunHjhlLZGzduSIAUFhaWrx5AmjJlivx+ypQpEiD16NFDqVx8fLykrq4uzZgxQ2n5+fPnJQ0NjXzLCzseL7atfPnyEiAdO3ZMXrZnzx4JkHR1daWbN2/Ky5cvXy4BUmRkpLysX79+EiANHTpUXpaTkyO1bdtW0tLSku7duydJkiRt3bpVAqTp06crtalLly6SQqGQrl69qnQ81NTUpAsXLiiVvXfvXr5jlae4n0/esf3666+Vynbs2FEyNzeX31+5ckVSU1OTOnbsKGVnZyuVzcnJkSTp9X5+BOFzI4YGCYIgCJ+EJUuWEBERofSC3KvIjx49okePHty/f19+qaurU69ePSIjI+U6dHV15f8/f/6c+/fvU79+fQDOnDnzTto9ePBgpfebN28mJycHb29vpfaWKVMGe3t7pfa+jipVqtCgQQP5fb169QBo0aIF5cqVy7f8+vXr+erw9/eX/583tCcjI4N9+/YBsHPnTtTV1Rk2bJjSdt999x2SJLFr1y6l5U2bNqVKlSrFjuF1P5+Xj23jxo158OABjx8/BmDr1q3k5OQwefJkpbsfefHB6/38CMLnRgwNEgRBED4JdevWLfBh4StXrgC5J7wFMTIykv//8OFDgoKC2LRpE0lJSUrlUlJSSrC1/+flmY6uXLmCJEnY29sXWF5TU/ON9vPiyT6AsbExADY2NgUuT05OVlqupqZGhQoVlJZVrlwZQH4e4ebNm1hbW2NoaKhUzsnJSV7/opdjf5XX/XxejtnU1BTIjc3IyIhr166hpqZWZGfkdX5+BOFzIzoCgiAIwictJycHyB3nXaZMmXzrNTT+76vO29ubY8eOMXr0aGrUqIGBgQE5OTl4enrK9RTl5THqebKzswvd5sWr3HntVSgU7Nq1q8DZfwwMDF7ZjoIUNpNQYcullx7ufRdejv1VXvfzKYnYXufnRxA+N+KnWxAEQfikVaxYEQBLS0vc3d0LLZecnMz+/fsJCgpi8uTJ8vK8K8IvKuyEP++K88uJxl6+Ev6q9kqShJ2dnXzF/WOQk5PD9evXldp0+fJlAPlh2fLly7Nv3z6ePHmidFfg0qVL8vpXKezYvs7nU1wVK1YkJyeHixcvUqNGjULLwKt/fgThcySeERAEQRA+aR4eHhgZGTFz5kwyMzPzrc+b6Sfv6vHLV4tDQ0PzbZM31//LJ/xGRkaUKlWKw4cPKy3/4Ycfit3eTp06oa6uTlBQUL62SJKUb6rM92nx4sVKbVm8eDGampq0bNkSAC8vL7Kzs5XKASxYsACFQkGbNm1euQ89PT0g/7F9nc+nuDp06ICamhpTp07Nd0chbz/F/fkRhM+RuCMgCIIgfNKMjIxYunQpffr0oWbNmnTv3h0LCwsSEhLYsWMHbm5uLF68GCMjI5o0acLs2bPJzMykbNmy7N27lxs3buSrs1atWgBMmDCB7t27o6mpSbt27dDX12fAgAGEhIQwYMAAateuzeHDh+Ur58VRsWJFpk+fzvjx44mPj6dDhw4YGhpy48YNtmzZwqBBgxg1alSJHZ/i0tHRYffu3fTr14969eqxa9cuduzYwffffy/P/d+uXTuaN2/OhAkTiI+Px8XFhb1797Jt2zYCAgLkq+tF0dXVpUqVKvz8889UrlwZMzMzqlWrRrVq1Yr9+RRXpUqVmDBhAtOmTaNx48Z06tQJbW1tTp06hbW1NcHBwcX++RGEz9IHmq1IEARBEIolb7rMU6dOFVkuMjJS8vDwkIyNjSUdHR2pYsWKko+Pj3T69Gm5zL///it17NhRMjExkYyNjaWuXbtKd+7cKXA6y2nTpklly5aV1NTUlKbrTEtLk3x9fSVjY2PJ0NBQ8vb2lpKSkgqdPjRv6s2X/f7771KjRo0kfX19SV9fX3J0dJT8/PykuLi4Yh2Pl6cPbdu2bb6ygOTn56e0LG8K1Dlz5sjL+vXrJ+nr60vXrl2TWrduLenp6UmlS5eWpkyZkm/azSdPnkgjRoyQrK2tJU1NTcne3l6aM2eOPB1nUfvOc+zYMalWrVqSlpaW0nEr7udT2LEt6NhIkiStXr1acnV1lbS1tSVTU1OpadOmUkREhFKZ4vz8CMLnRiFJ7+FpIUEQBEEQPlo+Pj789ttvpKamfuimCILwHolnBARBEARBEARBBYmOgCAIgiAIgiCoINEREARBEARBEAQVJJ4REARBEARBEAQVJO4ICIIgCIIgCIIKEh0BQRAEQRAEQVBBIqGYIKi4nJwc7ty5g6GhIQqF4kM3RxAEQRCEYpAkiSdPnmBtbY2a2ptd2xcdAUFQcXfu3MHGxuZDN0MQBEEQhDdw69YtvvjiizfaVnQ
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 18,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Calculating MAE, RMSE and R2 for training and test sets \n",
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
"r2_train = r2_score(y_train, y_pred_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
"r2_test = r2_score(y_test, y_pred_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_e5682_row0_col0, #T_e5682_row0_col1, #T_e5682_row0_col2, #T_e5682_row0_col3, #T_e5682_row0_col4, #T_e5682_row0_col5, #T_e5682_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
"<table id=\"T_e5682\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_e5682_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th id=\"T_e5682_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th id=\"T_e5682_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th id=\"T_e5682_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th id=\"T_e5682_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th id=\"T_e5682_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th id=\"T_e5682_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_e5682_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_e5682_row0_col0\" class=\"data row0 col0\" >0.09934</td>\n",
" <td id=\"T_e5682_row0_col1\" class=\"data row0 col1\" >0.31518</td>\n",
" <td id=\"T_e5682_row0_col2\" class=\"data row0 col2\" >0.03828</td>\n",
" <td id=\"T_e5682_row0_col3\" class=\"data row0 col3\" >0.10366</td>\n",
" <td id=\"T_e5682_row0_col4\" class=\"data row0 col4\" >0.32197</td>\n",
" <td id=\"T_e5682_row0_col5\" class=\"data row0 col5\" >0.01251</td>\n",
" <td id=\"T_e5682_row0_col6\" class=\"data row0 col6\" >31.93709</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x20007e22020>"
2023-12-12 15:22:01 +01:00
]
},
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating of dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 20,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# Load the model\n",
"model = load('xgboost.joblib')"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}