fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

776 lines
361 KiB
Plaintext
Raw Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 7,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
2024-01-12 20:39:14 +01:00
"from joblib import dump, load\n",
"import os"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 14,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
2024-01-12 20:39:14 +01:00
"df = pd.read_csv('final_data.csv')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 15,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
2024-01-12 20:39:14 +01:00
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
" 'y_player_teammate_Goalkeeper'],\n",
" dtype='object')"
]
},
2024-01-12 20:39:14 +01:00
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
2024-01-12 20:39:14 +01:00
" <th>shot_follows_dribble</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
2024-01-12 20:39:14 +01:00
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-12 20:39:14 +01:00
" <td>48.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>Left Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Right Center Back</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-12 20:39:14 +01:00
"<p>5 rows × 64 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 Right Center Forward Right Foot Normal \n",
"1 5 Right Center Forward Left Foot Normal \n",
"2 5 Center Midfield Right Foot Half Volley \n",
"3 5 Left Center Midfield Right Foot Normal \n",
"4 5 Right Center Back Left Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 Open Play False False False \n",
"1 Open Play False False False \n",
"2 Open Play True False False \n",
"3 Open Play False False False \n",
"4 Open Play True False False \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... 48.9 \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"\n",
"[5 rows x 64 columns]"
2023-12-12 15:22:01 +01:00
]
},
2024-01-12 20:39:14 +01:00
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 17,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# Change the type of categorical features to 'category' \n",
"df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']] = df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 18,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Splitting the dataset into features (X) and the target variable (y)\n",
"y = pd.DataFrame(df['is_goal'])\n",
"X = df.drop(['is_goal'], axis=1)\n",
"\n",
"# Splitting the data into a training set and a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
"\n",
"# Create cross-validation \n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
"Shots attempted in the training set: 27085\n",
"Goals scored in the training set: 3588\n"
2023-12-12 15:22:01 +01:00
]
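}
],
"source": [
"count_class_0, count_class_1 = y_train.value_counts()\n",
"\n",
"# Display the count of non-goal shots (the larger class) in the training set; the 'Shots attempted' label below is misleading, since this is not the total number of shots\n",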
"print('Shots attempted in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 20,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Class imbalance in training data: 7.549\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Class imbalance in training data\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 21,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 22,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XGBoost\n",
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
" 'max_depth': [3, 5, 7, 8, 9],\n",
" 'n_estimators': [100, 150, 200, 250, 300],\n",
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Perform grid search with cross-validation\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Fit the grid search on the training set (cross-validates every parameter combination)\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Take the best estimator found by the grid search (refit on the entire training set)\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 24,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-12 20:39:14 +01:00
"Best parameters: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
"Model Training Time: 912.022 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 25,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-12 20:39:14 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKbklEQVR4nO3deXwN9/7H8fdJyBEkISWJWINaUluLRiiqVKwVdFFdQtHSaEtQ1asEbXOp1r50FV20qi1VWppSclV0SZtaimutWhKxRCRIIpnfH37OdcSSOGfk4PXsY+51Zr5n5juThHc+3+/MsRiGYQgAAMCFuRV1BwAAAK6GwAIAAFwegQUAALg8AgsAAHB5BBYAAODyCCwAAMDlEVgAAIDLI7AAAACXR2ABAAAuj8CCG8aOHTvUvn17+fj4yGKxaMmSJU7d/969e2WxWBQbG+vU/d7I7r33Xt17771F3Q1T8PUGbiwEFhTKrl279Mwzz6h69eoqUaKEvL291aJFC02bNk2nT5829dgRERHatGmTXnvtNX300Udq0qSJqce7nvr06SOLxSJvb+9LXscdO3bIYrHIYrFo8uTJhd7/wYMHFR0draSkJCf01lzR0dG2c73S4opBKi8vTx9++KFCQkLk6+srLy8v1apVS08++aQ2bNhQ6P2dOnVK0dHRWrNmjfM7C9xgihV1B3DjWL58uR566CFZrVY9+eSTqlevnrKzs7Vu3TqNGDFCW7Zs0TvvvGPKsU+fPq2EhAT961//0uDBg005RtWqVXX69GkVL17clP1fTbFixXTq1Cl98803evjhh+22ffLJJypRooTOnDlzTfs+ePCgxo0bp2rVqqlRo0YFft/3339/TcdzRI8ePVSzZk3b64yMDA0aNEjdu3dXjx49bOv9/f0dOo4ZX+/nn39es2bNUrdu3fTYY4+pWLFi2r59u7777jtVr15dzZo1K9T+Tp06pXHjxkmSSwY04HoisKBA9uzZo169eqlq1apavXq1KlSoYNsWGRmpnTt3avny5aYdPzU1VZJUpkwZ045hsVhUokQJ0/Z/NVarVS1atNCnn36aL7AsWLBAnTt31pdffnld+nLq1CmVLFlSHh4e1+V4F2rQoIEaNGhge33kyBENGjRIDRo00OOPP37Z9505c0YeHh5ycytY4djZX++UlBTNnj1bAwYMyBfcp06davseBnBtGBJCgUyaNEkZGRl6//337cLKeTVr1tQLL7xge3327FlNmDBBNWrUkNVqVbVq1fTyyy8rKyvL7n3VqlVTly5dtG7dOt19990qUaKEqlevrg8//NDWJjo6WlWrVpUkjRgxQhaLRdWqVZN0bijl/J8vdH5Y4UJxcXG65557VKZMGZUuXVq1a9fWyy+/bNt+uTkNq1evVsuWLVWqVCmVKVNG3bp109atWy95vJ07d6pPnz4qU6aMfHx81LdvX506deryF/YivXv31nfffae0tDTbul9//VU7duxQ796987U/duyYhg8frvr166t06dLy9vZWx44d9eeff9rarFmzRk2bNpUk9e3b1zakcv487733XtWrV0+JiYlq1aqVSpYsabsuF89hiYiIUIkSJfKdf1hYmMqWLauDBw8W+FwdsWbNGlksFn322WcaPXq0KlasqJIlSyo9Pb1A10S69Ne7T58+Kl26tA4cOKDw8HCVLl1a5cuX1/Dhw5Wbm3vFPu3Zs0eGYahFixb5tlksFvn5+dmtS0tL05AhQ1S5cmVZrVbVrFlTEydOVF5enq1/5cuXlySNGzfO9nWLjo6+hisG3PiosKBAvvnmG1WvXl3NmzcvUPv+/ftr/vz5evDBBzVs2DD9/PPPiomJ0datW7V48WK7tjt37tSDDz6ofv36KSIiQh988IH69Omjxo0b64477lCPHj1UpkwZDR06VI8++qg6deqk0qVLF6r/W7ZsUZcuXdSgQQONHz9eVqtVO3fu1E8//XTF9/3www/q2L
GjqlevrujoaJ0+fVozZsxQixYt9Pvvv+cLSw8//LCCgoIUExOj33//Xe+99578/Pw0ceLEAvWzR48eGjhwoL766is99dRTks5VV+rUqaO77rorX/vdu3dryZIleuihhxQUFKSUlBS9/fbbat26tf766y8FBgaqbt26Gj9+vMaMGaOnn35aLVu2lCS7r+XRo0fVsWNH9erVS48//vhlh1umTZum1atXKyIiQgkJCXJ3d9fbb7+t77//Xh999JECAwMLdJ7OMmHCBHl4eGj48OHKysqSh4eH/vrrr6tekyvJzc1VWFiYQkJCNHnyZP3www968803VaNGDQ0aNOiy7zsfqhctWqSHHnpIJUuWvGzbU6dOqXXr1jpw4ICeeeYZValSRevXr9eoUaN06NAhTZ06VeXLl9ecOXPyDYddWH0CbikGcBUnTpwwJBndunUrUPukpCRDktG/f3+79cOHDzckGatXr7atq1q1qiHJiI+Pt607fPiwYbVajWHDhtnW7dmzx5BkvPHGG3b7jIiIMKpWrZqvD2PHjjUu/PaeMmWKIclITU29bL/PH2PevHm2dY0aNTL8/PyMo0eP2tb9+eefhpubm/Hkk0/mO95TTz1lt8/u3bsbt91222WPeeF5lCpVyjAMw3jwwQeNtm3bGoZhGLm5uUZAQIAxbty4S16DM2fOGLm5ufnOw2q1GuPHj7et+/XXX/Od23mtW7c2JBlz58695LbWrVvbrVu5cqUhyXj11VeN3bt3G6VLlzbCw8Oveo7XKjU11ZBkjB071rbuxx9/NCQZ1atXN06dOmXXvqDX5FJf74iICEOSXTvDMIw777zTaNy48VX7+uSTTxqSjLJlyxrdu3c3Jk+ebGzdujVfuwkTJhilSpUy/vvf/9qtf+mllwx3d3dj3759lz134FbFkBCuKj09XZLk5eVVoPbffvutJCkqKspu/bBhwyQp31yX4OBg22/9klS+fHnVrl1bu3fvvuY+X+z83Jevv/7aVnK/mkOHDikpKUl9+vSRr6+vbX2DBg10//33287zQgMHDrR73bJlSx09etR2DQuid+/eWrNmjZKTk7V69WolJydfcjhIOjfv5fycjdzcXB09etQ23PX7778X+JhWq1V9+/YtUNv27dvrmWee0fjx49WjRw+VKFFCb7/9doGP5UwRERHy9PS0W+eMa3Kpr2NBvh/nzZunmTNnKigoSIsXL9bw4cNVt25dtW3bVgcOHLC1W7RokVq2bKmyZcvqyJEjtqVdu3bKzc1VfHx8gfoJ3EoILLgqb29vSdLJkycL1P7vv/+Wm5ub3Z0ekhQQEKAyZcro77//tltfpUqVfPsoW7asjh8/fo09zu+RRx5RixYt1L9/f/n7+6tXr176/PPPrxhezvezdu3a+bbVrVtXR44cUWZmpt36i8+lbNmyklSoc+nUqZO8vLy0cOFCffLJJ2ratGm+a3leXl6epkyZottvv11Wq1XlypVT+fLltXHjRp04caLAx6xYsWKhJthOnjxZvr6+SkpK0vTp0/PNz7iU1NRUJScn25aMjIwCH+9ygoKC8q1z9JqUKFHCNnfkvIJ+P7q5uSkyMlKJiYk6cuSIvv76a3Xs2FGrV69Wr169bO127NihFStWqHz58nZLu3btJEmHDx++6rGAWw2BBVfl7e2twMBAbd68uVDvu3jS6+W4u7tfcr1hGNd8jIsnSHp6eio+Pl4//PCDnnjiCW3cuFGPPPKI7r///qtOpiwMR87lPKvVqh49emj+/PlavHjxZasrkvT6668rKipKrVq10scff6yVK1cqLi5Od9xxR4ErSZLyVSmu5o8//rD9o7pp06YCvadp06aqUKGCbbmW58lc7FL9dvSaXO5rWFi33XabHnjgAX377bdq3bq11q1bZwvBeXl5uv/++xUXF3fJpWfPnk7pA3AzYdItCqRLly565513lJCQoNDQ0Cu2rVq1qvLy8rRjxw7VrVvXtj4lJUVpaWm2yY
nOULZsWbs7as67uIojnfvtt23btmrbtq3eeustvf766/rXv/6lH3/80fab7cXnIUnbt2/Pt23btm0qV66cSpUq5fh
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on training set\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 26,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-12 20:39:14 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAHHCAYAAACPy0PBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJnElEQVR4nO3deVxU1f/H8fegMCKrqICkIuZKbomm5L4kGpprZVniVukXLXelxbXEbDHNLVvUSkuzNNNy+WpqJZpRlLnlVlQKbiGJCij394c/5+sEOmBzHcXXs8d9PJx7z9x77uTy5nPOuWMxDMMQAACAC7m5ugMAAAAEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEty09u3bpzZt2sjPz08Wi0XLly936vl//fVXWSwWzZ8/36nnvZk1b95czZs3d3U3ABRCBBL8KwcOHNATTzyhihUrqlixYvL19VWjRo00bdo0nT171tRrx8TEaMeOHXrhhRf03nvvqV69eqZe73rq1auXLBaLfH198/wc9+3bJ4vFIovFopdffrnA5z98+LDGjRunpKQkJ/TWXOPGjbPd69U2ZwWlzz//XOPGjct3+5ycHL377rtq0KCBAgIC5OPjoypVqqhnz57aunVrga9/5swZjRs3Ths3bizwe4GbWVFXdwA3r1WrVun++++X1WpVz549VaNGDWVlZenrr7/WiBEjtHPnTs2dO9eUa589e1YJCQl65plnNHDgQFOuERoaqrNnz8rd3d2U8ztStGhRnTlzRp999pkeeOABu2MLFy5UsWLFdO7cuWs69+HDhzV+/HhVqFBBderUyff71q5de03X+ze6dOmiSpUq2V6fPn1aAwYMUOfOndWlSxfb/qCgIKdc7/PPP9fMmTPzHUqefPJJzZw5Ux07dlSPHj1UtGhR7d27V1988YUqVqyohg0bFuj6Z86c0fjx4yWJahRuKQQSXJNDhw6pe/fuCg0N1YYNG1SmTBnbsdjYWO3fv1+rVq0y7frHjh2TJPn7+5t2DYvFomLFipl2fkesVqsaNWqkDz74IFcgWbRokaKjo/Xxxx9fl76cOXNGxYsXl4eHx3W53uVq1aqlWrVq2V4fP35cAwYMUK1atfTII49c9/5cLjU1VbNmzdJjjz2WK3y/9tprtt+nABxjyAbXZMqUKTp9+rTefvttuzBySaVKlfTUU0/ZXp8/f14TJ07U7bffLqvVqgoVKujpp59WZmam3fsqVKig9u3b6+uvv9Zdd92lYsWKqWLFinr33XdtbcaNG6fQ0FBJ0ogRI2SxWFShQgVJF4c6Lv36cpfK/pdbt26dGjduLH9/f3l7e6tq1ap6+umnbcevNIdkw4YNatKkiby8vOTv76+OHTtq9+7deV5v//796tWrl/z9/eXn56fevXvrzJkzV/5g/+Hhhx/WF198obS0NNu+7du3a9++fXr44YdztT958qSGDx+umjVrytvbW76+vmrXrp1+/PFHW5uNGzeqfv36kqTevXvbhjwu3Wfz5s1Vo0YNJSYmqmnTpipevLjtc/nnHJKYmBgVK1Ys1/1HRUWpRIkSOnz4cL7v9d/as2ePunXrpoCAABUrVkz16tXTihUr7NpkZ2dr/Pjxqly5sooVK6aSJUuqcePGWrdunaSLv39mzpwpSXbDQVdy6NAhGYahRo0a5TpmsVgUGBhoty8tLU2DBw9WuXLlZLVaValSJb344ovKycmRdPH3XOnSpSVJ48ePt12/IENIwM2KCgmuyWeffaaKFSvq7rvvzlf7fv36acGCBerWrZuGDRumbdu2KT4+Xrt379ayZcvs2u7fv1/dunVT3759FRMTo3feeUe9evVSRESE7rjjDnXp0kX+/v4aMmSIHnroId17773y9vYuUP937typ9u3bq1atWpowYYKsVqv279+vb7755qrv++9//6t27dqpYsWKGjdunM
6ePavXX39djRo10vfff58rDD3wwAMKCwtTfHy8vv/+e7311lsKDAzUiy++mK9+dunSRf3799cnn3yiPn36SLpYHalWrZrq1q2bq/3Bgwe1fPly3X///QoLC1NqaqreeOMNNWvWTLt27VJISIiqV6+uCRMmaMyYMXr88cfVpEkTSbL7f3nixAm1a9dO3bt31yOPPHLF4ZBp06Zpw4YNiomJUUJCgooUKaI33nhDa9eu1XvvvaeQkJB83ee/tXPnTjVq1Ei33XabRo8eLS8vLy1ZskSdOnXSxx9/rM6dO0u6GBTj4+PVr18/3XXXXUpPT9d3332n77//Xvfcc4+eeOIJHT58WOvWrdN7773n8LqXgvFHH32k+++/X8WLF79i2zNnzqhZs2b6888/9cQTT6h8+fLasmWL4uLidOTIEb322msqXbq0Zs+enWtI6vIKEVBoGUABnTp1ypBkdOzYMV/tk5KSDElGv3797PYPHz7ckGRs2LDBti80NNSQZGzevNm27+jRo4bVajWGDRtm23fo0CFDkvHSSy/ZnTMmJsYIDQ3N1YexY8cal/92nzp1qiHJOHbs2BX7feka8+bNs+2rU6eOERgYaJw4ccK278cffzTc3NyMnj175rpenz597M7ZuXNno2TJkle85uX34eXlZRiGYXTr1s1o1aqVYRiGceHCBSM4ONgYP358np/BuXPnjAsXLuS6D6vVakyYMMG2b/v27bnu7ZJmzZoZkow5c+bkeaxZs2Z2+9asWWNIMp5//nnj4MGDhre3t9GpUyeH93itjh07Zkgyxo4da9vXqlUro2bNmsa5c+ds+3Jycoy7777bqFy5sm1f7dq1jejo6KuePzY21ijIX409e/Y0JBklSpQwOnfubLz88svG7t27c7WbOHGi4eXlZfzyyy92+0ePHm0UKVLESE5OvuL9AbcChmxQYOnp6ZIkHx+ffLX//PPPJUlDhw612z9s2DBJyjXXJDw83PZTuySVLl1aVatW1cGDB6+5z/90ae7Jp59+aiuXO3LkyBElJSWpV69eCggIsO2vVauW7rnnHtt9Xq5///52r5s0aaITJ07YPsP8ePjhh7Vx40alpKRow4YNSklJyXO4Rro478TN7eIf6wsXLujEiRO24ajvv/8+39e0Wq3q3bt3vtq2adNGTzzxhCZMmKAuXbqoWLFieuONN/J9rX/r5MmT2rBhgx544AH9/fffOn78uI4fP64TJ04oKipK+/bt059//inp4v/3nTt3at++fU67/rx58zRjxgyFhYVp2bJlGj58uKpXr65WrVrZritdrKI0adJEJUqUsPXx+PHjat26tS5cuKDNmzc7rU/AzYhAggLz9fWVJP3999/5av/bb7/Jzc3NbqWEJAUHB8vf31+//fab3f7y5cvnOkeJEiX0119/XWOPc3vwwQfVqFEj9evXT0FBQerevbuWLFly1XByqZ9Vq1bNdax69eo6fvy4MjIy7Pb/815KlCghSQW6l3vvvVc+Pj5avHixFi5cqPr16+f6LC/JycnR1KlTVblyZVmtVpUqVUqlS5fWTz/9pFOnTuX7mrfddluBJrC+/PLLCggIUFJSkqZPn55r7kRejh07ppSUFNt2+vTpfF/vcvv375dhGHruuedUunRpu23s2LGSpKNHj0qSJkyYoLS0NFWpUkU1a9bUiBEj9NNPP13TdS9xc3NTbGysEhMTdfz4cX366adq166dNmzYoO7du9va7du3T6tXr87Vx9atW9v1EbhVMYcEBebr66uQkBD9/PPPBXrf1SYHXq5IkSJ57jcM45qvceHCBbvXnp6e2rx5s7788kutWrVKq1ev1uLFi9WyZUutXbv2in0oqH9zL5dYrVZ16dJFCxYs0MGDB686wXHSpEl67rnn1KdPH02cOFEBAQFyc3PT4MGD810Jki5+PgXxww8/2P5B3bFjhx566CGH76lfv75dGB07duw1Td68dF/Dhw9XVFRUnm0uBbimTZ
vqwIED+vTTT7V27Vq99dZbmjp1qubMmaN+/foV+Nr/VLJkSd13332677771Lx5c23atEm//fabQkNDlZOTo3vuuUc
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testing Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 27,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The test dataset contains 7669 shots, with 914 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 28,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-12 20:39:14 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAHHCAYAAADONqsSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxP2f8H8Nen/dOqUipaabOVrVTI0qYmlSWyZpIxdgZZ0mbLHrIvhWEsgzQjSyhLkiRLIkpESqRFpfrU5/z+6Nf9dn0+JTNUzHk+Hj3G55xzzz33fe987ufce+65HEIIAUVRFEVRFEVRFNViiDR3AyiKoiiKoiiKoig22lGjKIqiKIqiKIpqYWhHjaIoiqIoiqIoqoWhHTWKoiiKoiiKoqgWhnbUKIqiKIqiKIqiWhjaUaMoiqIoiqIoimphaEeNoiiKoiiKoiiqhaEdNYqiKIqiKIqiqBaGdtQoiqIoiqIoiqJaGNpRoyiKoqjvSHh4ODgcDp4/f97cTaEoiqK+IdpRoyiKolq02o6JsL+FCxd+k3XeuHEDAQEBKCws/Cb1/5eVlZUhICAAsbGxzd0UiqKoFk2suRtAURRFUY0RFBQEXV1dVlrnzp2/ybpu3LiBwMBAeHp6olWrVt9kHf/UuHHjMGrUKEhKSjZ3U/6RsrIyBAYGAgD69+/fvI2hKIpqwWhHjaIoivouDB48GD179mzuZvwrpaWlkJGR+Vd1iIqKQlRU9Cu1qOnw+XxUVlY2dzMoiqK+G3ToI0VRFPVDOHv2LPr27QsZGRnIycnByckJDx8+ZJW5f/8+PD09oaenBykpKaipqeHnn39Gfn4+UyYgIADz588HAOjq6jLDLJ8/f47nz5+Dw+EgPDxcYP0cDgcBAQGsejgcDlJTUzF69GgoKiqiT58+TP7vv/+OHj16gMvlQklJCaNGjcLLly8/u53CnlHT0dHBTz/9hNjYWPTs2RNcLhddunRhhheePHkSXbp0gZSUFHr06IHk5GRWnZ6enpCVlcWzZ89gb28PGRkZaGhoICgoCIQQVtnS0lL89ttv0NTUhKSkJAwNDbFu3TqBchwOB9OnT8ehQ4fQqVMnSEpKYseOHVBRUQEABAYGMrGtjVtj9k/d2KanpzN3PRUUFDBx4kSUlZUJxOz333+HmZkZpKWloaioiH79+uHChQusMo05fiiKopoSvaNGURRFfReKiorw7t07Vlrr1q0BAAcPHsSECRNgb2+P1atXo6ysDNu3b0efPn2QnJwMHR0dAEB0dDSePXuGiRMnQk1NDQ8fPsSuXbvw8OFD3Lx5ExwOB0OHDsWTJ0/wxx9/YOPGjcw6VFRU8Pbt2y9u94gRI6Cvr4+VK1cynZkVK1Zg6dKlcHd3x6RJk/D27Vts2bIF/fr1Q3Jy8j8abpmeno7Ro0fjl19+wdixY7Fu3To4Oztjx44dWLx4MaZOnQoAWLVqFdzd3ZGWlgYRkf9dr62uroaDgwN69+6NNWvW4Ny5c/D390dVVRWCgoIAAIQQDBkyBDExMfDy8oKpqSnOnz+P+fPnIzs7Gxs3bmS16fLlyzh27BimT5+O1q1bw8TEBNu3b8evv/4KNzc3DB06FADQtWtXAI3bP3W5u7tDV1cXq1atwp07d7Bnzx6oqqpi9erVTJnAwEAEBATA0tISQUFBkJCQQEJCAi5fvgw7OzsAjT9+KIqimhShKIqiqBYsLCyMABD6RwghHz58IK1atSLe3t6s5XJzc4mCggIrvaysTKD+P/74gwAgV69eZdLWrl1LAJDMzExW2czMTAKAhIWFCdQDgPj7+zOf/f39CQDi4eHBKvf8+XMiKipKVqxYwUp/8OABERMTE0ivLx5126atrU0AkBs3bjBp58+fJwAIl8slL168YNJ37txJAJCYmBgmbcKECQQAmTFjBpPG5/OJk5MTkZCQIG/fviWEEBIREUEAkOXLl7PaNHz4cMLhcEh6ejorHiIiIu
Thw4essm/fvhWIVa3G7p/a2P7888+ssm5ubkRZWZn5/PTpUyIiIkLc3NxIdXU1qyyfzyeEfNnxQ1EU1ZTo0EeKoijqu7B161ZER0ez/oCauzCFhYXw8PDAu3fvmD9RUVGYm5sjJiaGqYPL5TL/Li8vx7t379C7d28AwJ07d75Ju6dMmcL6fPLkSfD5fLi7u7Paq6amBn19fVZ7v0THjh1hYWHBfDY3NwcADBw4EFpaWgLpz549E6hj+vTpzL9rhy5WVlbi4sWLAICoqCiIiopi5syZrOV+++03EEJw9uxZVrq1tTU6duzY6G340v3zaWz79u2L/Px8FBcXAwAiIiLA5/Ph5+fHuntYu33Alx0/FEVRTYkOfaQoiqK+C2ZmZkInE3n69CmAmg6JMPLy8sy/379/j8DAQBw5cgR5eXmsckVFRV+xtf/z6UyVT58+BSEE+vr6QsuLi4v/o/XU7YwBgIKCAgBAU1NTaHpBQQErXUREBHp6eqw0AwMDAGCeh3vx4gU0NDQgJyfHKmdsbMzk1/Xptn/Ol+6fT7dZUVERQM22ycvLIyMjAyIiIg12Fr/k+KEoimpKtKNGURRFfdf4fD6AmueM1NTUBPLFxP53qnN3d8eNGzcwf/58mJqaQlZWFnw+Hw4ODkw9Dfn0Gala1dXV9S5T9y5RbXs5HA7Onj0rdPZGWVnZz7ZDmPpmgqwvnXwy+ce38Om2f86X7p+vsW1fcvxQFEU1JfrtQ1EURX3X2rdvDwBQVVWFjY1NveUKCgpw6dIlBAYGws/Pj0mvvaNSV30dsto7Np++CPvTO0mfay8hBLq6uswdq5aAz+fj2bNnrDY9efIEAJjJNLS1tXHx4kV8+PCBdVft8ePHTP7n1BfbL9k/jdW+fXvw+XykpqbC1NS03jLA548fiqKopkafUaMoiqK+a/b29pCXl8fKlSvB4/EE8mtnaqy9+/Lp3ZaQkBCBZWrfdfZph0xeXh6tW7fG1atXWenbtm1rdHuHDh0KUVFRBAYGCrSFECIwFX1TCg0NZbUlNDQU4uLiGDRoEADA0dER1dXVrHIAsHHjRnA4HAwePPiz65CWlgYgGNsv2T+N5erqChEREQQFBQnckatdT2OPH4qiqKZG76hRFEVR3zV5eXls374d48aNQ/fu3TFq1CioqKggKysLZ86cgZWVFUJDQyEvL49+/fphzZo14PF4aNu2LS5cuIDMzEyBOnv06AEAWLJkCUaNGgVxcXE4OztDRkYGkyZNQnBwMCZNmoSePXvi6tWrzJ2nxmjfvj2WL1+ORYsW4fnz53B1dYWcnBwyMzNx6tQpTJ48GfPmzftq8WksKSkpnDt3DhMmTIC5uTnOnj2LM2fOYPHixcy7z5ydnTFgwAAsWbIEz58/h4mJCS5cuIDTp09j9uzZzN2phnC5XHTs2BFHjx6FgYEBlJSU0LlzZ3Tu3LnR+6exOnTogCVLlmDZsmXo27cvhg4dCklJSSQmJkJDQwOrVq1q9PFDURTV1GhHjaIoivrujR49GhoaGggODsbatWtRUVGBtm3bom/fvpg4cSJT7vDhw5gxYwa2bt0KQgjs7Oxw9uxZaGhosOrr1asXli1bhh07duDcuXPg8/nIzMyEjIwM/Pz88PbtW/z55584duwYBg8ejLNnz0JVVbXR7V24cCEMDAywceNGBAYGAqiZ9MPOzg5Dhgz5OkH5QqKiojh37hx+/fVXzJ8/H3JycvD392cNQxQREUFkZCT8/Pxw9OhRhIWFQUdHB2vXrsVvv/3W6HXt2bMHM2bMwJw5c1BZWQl/f3907ty50fvnSwQFBUFXVxdbtmzBkiVLIC0tja5du2LcuHFMmcYePxRFUU2JQ5riaWKKoiiKolosT09P/PnnnygpKWnuplAURVH/jz6jRlEURVEURVEU1cLQjhpFURRFURRFUVQLQztqFEVRFEVRFEVRLQx9Ro2iKIqiKIqiKKqFoXfUKIqiKIqiKIqiWhjaUaMoiq
IoiqIoimph6HvUKOo/js/n4/Xr15CTkwOHw2nu5lAURVEU1QiEEHz48AEaGhoQEaH3Xn5EtKNGUf9xr1+/hqamZnM
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 29,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-12 20:39:14 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvwAAAHHCAYAAADDDYx8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde1yP5//A8deng45KJSmLHFIhMYRyCFGa5lwOQ6bMIdYMMaSczxptbDaFMd99tzn8JiatHGKNHIZaiGQOc0qp6Hj//ujR/fXRwSdzKK7n49FjPvd93dd93e/PZ3V9rvu6r7dCkiQJQRAEQRAEQRDeSGqvuwGCIAiCIAiCILw8osMvCIIgCIIgCG8w0eEXBEEQBEEQhDeY6PALgiAIgiAIwhtMdPgFQRAEQRAE4Q0mOvyCIAiCIAiC8AYTHX5BEARBEARBeIOJDr8gCIIgCIIgvMFEh18QBEEQBEEQ3mCiwy8IgiAI1UhERAQKhYLU1NTX3RRBEKoJ0eEXBEEQqrSSDm5ZPzNmzHgp5zx69CjBwcE8ePDgpdT/NsvJySE4OJjY2NjX3RRBeGtovO4GCIIgCIIq5s2bR8OGDZW2tWjR4qWc6+jRo4SEhODj40OtWrVeyjme14gRIxgyZAhaWlqvuynPJScnh5CQEABcXFxeb2ME4S0hOvyCIAhCtdC7d2/atm37upvxr2RnZ6Onp/ev6lBXV0ddXf0FtejVKSoqIi8v73U3QxDeSmJKjyAIgvBG2Lt3L507d0ZPT4+aNWvy3nvvcf78eaUyf/75Jz4+PjRq1AhtbW3q1q3Lhx9+yL179+QywcHBTJs2DYCGDRvK04dSU1NJTU1FoVAQERFR6vwKhYLg4GClehQKBYmJiQwbNgwjIyM6deok7//uu+9o06YNOjo6GBsbM2TIEK5du/bM6yxrDr+VlRV9+vQhNjaWtm3boqOjg729vTxt5ueff8be3h5tbW3atGnDqVOnlOr08fFBX1+fy5cv4+bmhp6eHhYWFsybNw9JkpTKZmdn8+mnn2JpaYmWlhY2NjasWLGiVDmFQoG/vz9bt26lefPmaGlpsX79ekxNTQEICQmRY1sSN1Xenydje+nSJfkujKGhIaNHjyYnJ6dUzL777jscHR3R1dXFyMiILl26sH//fqUyqnx+BKG6EiP8giAIQrWQkZHB3bt3lbbVrl0bgC1btjBq1Cjc3NxYunQpOTk5rFu3jk6dOnHq1CmsrKwAiIqK4vLly4wePZq6dety/vx5vv76a86fP8/vv/+OQqFgwIABXLhwge+//57Vq1fL5zA1NeXOnTuVbvfgwYOxtrZm0aJFcqd44cKFzJkzBy8vL3x9fblz5w5r166lS5cunDp16rmmEV26dIlhw4bx0Ucf8cEHH7BixQo8PT1Zv349n332GRMmTABg8eLFeHl5kZycjJra/8b9CgsLcXd3p0OHDixbtox9+/Yxd+5cCgoKmDdvHgCSJPH+++8TExPDmDFjaNWqFb/++ivTpk3j+vXrrF69WqlNv/32Gz/88AP+/v7Url0bBwcH1q1bx/jx4+nfvz8DBgwAoGXLloBq78+TvLy8aNiwIYsXL+bkyZN888031KlTh6VLl8plQkJCCA4OxsnJiXnz5lGjRg3i4+P57bff6NWrF6D650cQqi1JEARBEKqw8PBwCSjzR5Ik6eHDh1KtWrUkPz8/peNu3bolGRoaKm3PyckpVf/3338vAdKhQ4fkbcuXL5cA6cqVK0plr1y5IgFSeHh4qXoAae7cufLruXPnSoA0dOhQpXKpqamSurq6tHDhQqXtZ8+elTQ0NEptLy8eT7atQYMGEiAdPXpU3vbrr79KgKSjoyNdvXpV3v7VV19JgBQTEyNvGzVqlARIkyZNkrcVFRVJ7733nlSjRg3pzp07kiRJ0s6dOyVAWrBggVKbBg0aJCkUCunSpUtK8VBTU5POnz+vVPbOnTulYlVC1fenJL
YffvihUtn+/ftLJiYm8uuLFy9KampqUv/+/aXCwkKlskVFRZIkVe7zIwjVlZjSIwiCIFQLX3zxBVFRUUo/UDwq/ODBA4YOHcrdu3flH3V1ddq3b09MTIxch46Ojvzvx48fc/fuXTp06ADAyZMnX0q7x40bp/T6559/pqioCC8vL6X21q1bF2tra6X2VkazZs3o2LGj/Lp9+/YAdO/enfr165fafvny5VJ1+Pv7y/8umZKTl5fHgQMHAIiMjERdXZ3JkycrHffpp58iSRJ79+5V2t61a1eaNWum8jVU9v15OradO3fm3r17ZGZmArBz506KiooICgpSuptRcn1Quc+PIFRXYkqPIAiCUC04OjqW+dDuxYsXgeKObVkMDAzkf9+/f5+QkBC2b9/O7du3lcplZGS8wNb+z9MrC128eBFJkrC2ti6zvKam5nOd58lOPYChoSEAlpaWZW5PT09X2q6mpkajRo2UtjVt2hRAfl7g6tWrWFhYULNmTaVydnZ28v4nPX3tz1LZ9+fpazYyMgKKr83AwICUlBTU1NQq/NJRmc+PIFRXosMvCIIgVGtFRUVA8TzsunXrltqvofG/P3VeXl4cPXqUadOm0apVK/T19SkqKsLd3V2upyJPzyEvUVhYWO4xT45al7RXoVCwd+/eMlfb0dfXf2Y7ylLeyj3lbZeeesj2ZXj62p+lsu/Pi7i2ynx+BKG6Ep9iQRAEoVpr3LgxAHXq1MHV1bXccunp6URHRxMSEkJQUJC8vWSE90nldexLRpCfTsj19Mj2s9orSRINGzaUR9CrgqKiIi5fvqzUpgsXLgDID602aNCAAwcO8PDhQ6VR/r/++kve/yzlxbYy74+qGjduTFFREYmJibRq1arcMvDsz48gVGdiDr8gCIJQrbm5uWFgYMCiRYvIz88vtb9kZZ2S0eCnR39DQ0NLHVOyVv7THXsDAwNq167NoUOHlLZ/+eWXKrd3wIABqKurExISUqotkiSVWoLyVQoLC1NqS1hYGJqamvTo0QMADw8PCgsLlcoBrF69GoVCQe/evZ95Dl1dXaB0bCvz/qiqX79+qKmpMW/evFJ3CErOo+rnRxCqMzHCLwiCIFRrBgYGrFu3jhEjRvDuu+8yZMgQTE1NSUtLY8+ePTg7OxMWFoaBgQFdunRh2bJl5OfnU69ePfbv38+VK1dK1dmmTRsAZs2axZAhQ9DU1MTT0xM9PT18fX1ZsmQJvr6+tG3blkOHDskj4apo3LgxCxYsYObMmaSmptKvXz9q1qzJlStX2LFjB2PHjmXq1KkvLD6q0tbWZt++fYwaNYr27duzd+9e9uzZw2effSavne/p6Um3bt2YNWsWqampODg4sH//fnbt2kVAQIA8Wl4RHR0dmjVrxn/+8x+aNm2KsbExLVq0oEWLFiq/P6pq0qQJs2bNYv78+XTu3JkBAwagpaXF8ePHsbCwYPHixSp/fgShWntNqwMJgiAIgkpKlqE8fvx4heViYmIkNzc3ydDQUNLW1pYaN24s+fj4SCdOnJDL/P3331L//v2lWrVqSYaGhtLgwYOlGzdulLlM5Pz586V69epJampqSstg5uTkSGPGjJEMDQ2lmjVrSl5eXtLt27fLXZazZEnLp/30009Sp06dJD09PUlPT0+ytbWVJk6cKCUnJ6sUj6eX5XzvvfdKlQWkiRMnKm0rWVp0+fLl8rZRo0ZJenp6UkpKitSrVy9JV1dXMjMzk+bOnVtqOcuHDx9Kn3zyiWRhYSFpampK1tbW0vLly+VlLis6d4mjR49Kbdq0kWrUqKEUN1Xfn/JiW1ZsJEmSNm7cKLVu3VrS0tKSjIyMpK5du0pRUVFKZVT5/AhCdaWQpFfw1I4gCIIgCFWWj48PP/74I1lZWa+7KYIgvARiDr8gCIIgCIIgvMFEh18QBEEQBEEQ3mCiwy8IgiAIgiAIbzAxh18QBEEQBEEQ3mBihF8QBEEQBEEQ3mCiwy8IgiAIgiAIbzCReE
sQ3nJFRUXcuHGDmjVrlpvyXhAEQRCEqkWSJB4+fIiFhQVqahWP4YsOvyC85W7cuIGlpeXrboYgCIIgCM/h2rVrvPP
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 30,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Calculating MAE, RMSE and R2 for training and test sets \n",
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
"r2_train = r2_score(y_train, y_pred_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
"r2_test = r2_score(y_test, y_pred_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 31,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-01-12 20:39:14 +01:00
"#T_57229_row0_col0, #T_57229_row0_col1, #T_57229_row0_col2, #T_57229_row0_col3, #T_57229_row0_col4, #T_57229_row0_col5, #T_57229_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
2024-01-12 20:39:14 +01:00
"<table id=\"T_57229\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-01-12 20:39:14 +01:00
" <th id=\"T_57229_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th id=\"T_57229_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th id=\"T_57229_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th id=\"T_57229_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th id=\"T_57229_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th id=\"T_57229_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th id=\"T_57229_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-01-12 20:39:14 +01:00
" <th id=\"T_57229_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_57229_row0_col0\" class=\"data row0 col0\" >0.10305</td>\n",
" <td id=\"T_57229_row0_col1\" class=\"data row0 col1\" >0.32102</td>\n",
" <td id=\"T_57229_row0_col2\" class=\"data row0 col2\" >0.00230</td>\n",
" <td id=\"T_57229_row0_col3\" class=\"data row0 col3\" >0.10497</td>\n",
" <td id=\"T_57229_row0_col4\" class=\"data row0 col4\" >0.32399</td>\n",
" <td id=\"T_57229_row0_col5\" class=\"data row0 col5\" >0.00009</td>\n",
" <td id=\"T_57229_row0_col6\" class=\"data row0 col6\" >15.20037</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-01-12 20:39:14 +01:00
"<pandas.io.formats.style.Styler at 0x28e464b10>"
2023-12-12 15:22:01 +01:00
]
},
2024-01-12 20:39:14 +01:00
"execution_count": 31,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating of dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-12 20:39:14 +01:00
"execution_count": 32,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# Load the model\n",
"model = load('xgboost.joblib')"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-01-12 20:39:14 +01:00
"version": "3.11.6"
2023-12-12 15:22:01 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}