2023-12-12 15:22:01 +01:00
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Imports"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 7,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import pandas as pd\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"import seaborn as sns\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
|
|
|
|
|
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
|
|
|
|
|
"import xgboost\n",
|
|
|
|
|
"import time\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"from joblib import dump, load\n",
|
|
|
|
|
"import os"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Load the data"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 14,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"df = pd.read_csv('final_data.csv')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 15,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
|
|
|
|
|
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
|
|
|
|
|
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
|
|
|
|
|
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
|
|
|
|
|
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
|
|
|
|
|
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
|
|
|
|
|
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
|
|
|
|
|
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
|
|
|
|
|
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
|
|
|
|
|
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
|
|
|
|
|
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
|
|
|
|
|
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
|
|
|
|
|
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
|
|
|
|
|
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
|
|
|
|
|
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
|
|
|
|
|
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
|
|
|
|
|
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
|
|
|
|
|
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
|
|
|
|
|
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
|
|
|
|
|
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
|
|
|
|
|
" 'y_player_teammate_Goalkeeper'],\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" dtype='object')"
|
|
|
|
|
]
|
|
|
|
|
},
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 15,
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"df.columns"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 16,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<div>\n",
|
|
|
|
|
"<style scoped>\n",
|
|
|
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
|
|
|
" vertical-align: middle;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe tbody tr th {\n",
|
|
|
|
|
" vertical-align: top;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe thead th {\n",
|
|
|
|
|
" text-align: right;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"</style>\n",
|
|
|
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr style=\"text-align: right;\">\n",
|
|
|
|
|
" <th></th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th>minute</th>\n",
|
|
|
|
|
" <th>position_name</th>\n",
|
|
|
|
|
" <th>shot_body_part_name</th>\n",
|
|
|
|
|
" <th>shot_technique_name</th>\n",
|
|
|
|
|
" <th>shot_type_name</th>\n",
|
|
|
|
|
" <th>shot_first_time</th>\n",
|
|
|
|
|
" <th>shot_one_on_one</th>\n",
|
|
|
|
|
" <th>shot_aerial_won</th>\n",
|
|
|
|
|
" <th>shot_open_goal</th>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <th>shot_follows_dribble</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <th>...</th>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <th>y_player_teammate_5</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th>y_player_teammate_6</th>\n",
|
|
|
|
|
" <th>y_player_teammate_7</th>\n",
|
|
|
|
|
" <th>y_player_teammate_8</th>\n",
|
|
|
|
|
" <th>y_player_teammate_9</th>\n",
|
|
|
|
|
" <th>y_player_teammate_10</th>\n",
|
|
|
|
|
" <th>x_player_opponent_7</th>\n",
|
|
|
|
|
" <th>y_player_opponent_7</th>\n",
|
|
|
|
|
" <th>x_player_teammate_Goalkeeper</th>\n",
|
|
|
|
|
" <th>y_player_teammate_Goalkeeper</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>0</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>0</td>\n",
|
|
|
|
|
" <td>Right Center Forward</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Right Center Forward</td>\n",
|
|
|
|
|
" <td>Left Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>2</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Center Midfield</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Half Volley</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>True</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <td>48.9</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>3</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Left Center Midfield</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>4</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Right Center Back</td>\n",
|
|
|
|
|
" <td>Left Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>True</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"<p>5 rows × 64 columns</p>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"</div>"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" minute position_name shot_body_part_name shot_technique_name \\\n",
|
|
|
|
|
"0 0 Right Center Forward Right Foot Normal \n",
|
|
|
|
|
"1 5 Right Center Forward Left Foot Normal \n",
|
|
|
|
|
"2 5 Center Midfield Right Foot Half Volley \n",
|
|
|
|
|
"3 5 Left Center Midfield Right Foot Normal \n",
|
|
|
|
|
"4 5 Right Center Back Left Foot Normal \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
|
|
|
|
|
"0 Open Play False False False \n",
|
|
|
|
|
"1 Open Play False False False \n",
|
|
|
|
|
"2 Open Play True False False \n",
|
|
|
|
|
"3 Open Play False False False \n",
|
|
|
|
|
"4 Open Play True False False \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
|
|
|
|
|
"0 False False ... NaN \n",
|
|
|
|
|
"1 False False ... NaN \n",
|
|
|
|
|
"2 False False ... 48.9 \n",
|
|
|
|
|
"3 False False ... NaN \n",
|
|
|
|
|
"4 False False ... NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"0 NaN NaN NaN \n",
|
|
|
|
|
"1 NaN NaN NaN \n",
|
|
|
|
|
"2 NaN NaN NaN \n",
|
|
|
|
|
"3 NaN NaN NaN \n",
|
|
|
|
|
"4 NaN NaN NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
|
|
|
|
|
"0 NaN NaN NaN \n",
|
|
|
|
|
"1 NaN NaN NaN \n",
|
|
|
|
|
"2 NaN NaN NaN \n",
|
|
|
|
|
"3 NaN NaN NaN \n",
|
|
|
|
|
"4 NaN NaN NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
|
|
|
|
|
"0 NaN NaN \n",
|
|
|
|
|
"1 NaN NaN \n",
|
|
|
|
|
"2 NaN NaN \n",
|
|
|
|
|
"3 NaN NaN \n",
|
|
|
|
|
"4 NaN NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" y_player_teammate_Goalkeeper \n",
|
|
|
|
|
"0 NaN \n",
|
|
|
|
|
"1 NaN \n",
|
|
|
|
|
"2 NaN \n",
|
|
|
|
|
"3 NaN \n",
|
|
|
|
|
"4 NaN \n",
|
|
|
|
|
"\n",
|
|
|
|
|
"[5 rows x 64 columns]"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 16,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"df.head()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Data preparation"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 17,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"outputs": [],
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Change the type of categorical features to 'category' \n",
|
|
|
|
|
"df[['position_name', \n",
|
|
|
|
|
" 'shot_technique_name', \n",
|
|
|
|
|
" 'shot_type_name', \n",
|
|
|
|
|
" 'number_of_players_opponents', \n",
|
|
|
|
|
" 'number_of_players_teammates', \n",
|
|
|
|
|
" 'shot_body_part_name']] = df[['position_name', \n",
|
|
|
|
|
" 'shot_technique_name', \n",
|
|
|
|
|
" 'shot_type_name', \n",
|
|
|
|
|
" 'number_of_players_opponents', \n",
|
|
|
|
|
" 'number_of_players_teammates', \n",
|
|
|
|
|
" 'shot_body_part_name']].astype('category')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 18,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Splitting the dataset into features (X) and the target variable (y)\n",
|
|
|
|
|
"y = pd.DataFrame(df['is_goal'])\n",
|
|
|
|
|
"X = df.drop(['is_goal'], axis=1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Splitting the data into a training set and a test set\n",
|
|
|
|
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Create a stratified 5-fold cross-validation splitter\n",
|
|
|
|
|
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 19,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"name": "stdout",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"Shots attempted in the training set: 27085\n",
|
|
|
|
|
"Goals scored in the training set: 3588\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"count_class_0, count_class_1 = y_train.value_counts()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Display the count of non-goal shots (class 0, the majority class listed first by value_counts) in the training set\n",
|
|
|
|
|
"print('Shots attempted in the training set:', count_class_0)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Display the count of successful goals in the training set\n",
|
|
|
|
|
"print('Goals scored in the training set:', count_class_1)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 20,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" Class imbalance in training data: 7.549\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Class imbalance in training data\n",
|
|
|
|
|
"scale_pos_weight = count_class_0 / count_class_1\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Training XGBoost model"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 21,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Define the xgboost model\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 22,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Defining the hyper-parameter grid for XGBoost\n",
|
|
|
|
|
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
|
|
|
|
|
" 'max_depth': [3, 5, 7, 8, 9],\n",
|
|
|
|
|
" 'n_estimators': [100, 150, 200, 250, 300],\n",
|
|
|
|
|
" 'scale_pos_weight': [1, scale_pos_weight]}"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 23,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Starting the timer\n",
|
|
|
|
|
"start_time = time.time()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
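"# Set up grid search with cross-validation (note: for a classifier, 'neg_mean_squared_error' is computed on predicted labels, i.e. it equals the negative error rate)\n",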
|
|
|
|
|
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Run the cross-validated grid search; the best configuration is then refit on the entire training set\n",
|
|
|
|
|
"grid_xg.fit(X_train, y_train)\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Take the best estimator (refit with the best parameters) from the grid search\n",
|
|
|
|
|
"best_xgb_model = grid_xg.best_estimator_\n",
|
|
|
|
|
"\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"# Stopping the timer\n",
|
|
|
|
|
"stop_time = time.time()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Training Time\n",
|
|
|
|
|
"xgb_training_time = stop_time - start_time"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 24,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"Best parameters: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
|
|
|
|
|
"Model Training Time: 912.022 seconds\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Print the best parameters and training time\n",
|
|
|
|
|
"print(\"Best parameters: \", grid_xg.best_params_)\n",
|
|
|
|
|
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Model evaluation"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Training set"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 25,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABKbklEQVR4nO3deXwN9/7H8fdJyBEkISWJWINaUluLRiiqVKwVdFFdQtHSaEtQ1asEbXOp1r50FV20qi1VWppSclV0SZtaimutWhKxRCRIIpnfH37OdcSSOGfk4PXsY+51Zr5n5juThHc+3+/MsRiGYQgAAMCFuRV1BwAAAK6GwAIAAFwegQUAALg8AgsAAHB5BBYAAODyCCwAAMDlEVgAAIDLI7AAAACXR2ABAAAuj8CCG8aOHTvUvn17+fj4yGKxaMmSJU7d/969e2WxWBQbG+vU/d7I7r33Xt17771F3Q1T8PUGbiwEFhTKrl279Mwzz6h69eoqUaKEvL291aJFC02bNk2nT5829dgRERHatGmTXnvtNX300Udq0qSJqce7nvr06SOLxSJvb+9LXscdO3bIYrHIYrFo8uTJhd7/wYMHFR0draSkJCf01lzR0dG2c73S4opBKi8vTx9++KFCQkLk6+srLy8v1apVS08++aQ2bNhQ6P2dOnVK0dHRWrNmjfM7C9xgihV1B3DjWL58uR566CFZrVY9+eSTqlevnrKzs7Vu3TqNGDFCW7Zs0TvvvGPKsU+fPq2EhAT961//0uDBg005RtWqVXX69GkVL17clP1fTbFixXTq1Cl98803evjhh+22ffLJJypRooTOnDlzTfs+ePCgxo0bp2rVqqlRo0YFft/3339/TcdzRI8ePVSzZk3b64yMDA0aNEjdu3dXjx49bOv9/f0dOo4ZX+/nn39es2bNUrdu3fTYY4+pWLFi2r59u7777jtVr15dzZo1K9T+Tp06pXHjxkmSSwY04HoisKBA9uzZo169eqlq1apavXq1KlSoYNsWGRmpnTt3avny5aYdPzU1VZJUpkwZ045hsVhUokQJ0/Z/NVarVS1atNCnn36aL7AsWLBAnTt31pdffnld+nLq1CmVLFlSHh4e1+V4F2rQoIEaNGhge33kyBENGjRIDRo00OOPP37Z9505c0YeHh5ycytY4djZX++UlBTNnj1bAwYMyBfcp06davseBnBtGBJCgUyaNEkZGRl6//337cLKeTVr1tQLL7xge3327FlNmDBBNWrUkNVqVbVq1fTyyy8rKyvL7n3VqlVTly5dtG7dOt19990qUaKEqlevrg8//NDWJjo6WlWrVpUkjRgxQhaLRdWqVZN0bijl/J8vdH5Y4UJxcXG65557VKZMGZUuXVq1a9fWyy+/bNt+uTkNq1evVsuWLVWqVCmVKVNG3bp109atWy95vJ07d6pPnz4qU6aMfHx81LdvX506deryF/YivXv31nfffae0tDTbul9//VU7duxQ796987U/duyYhg8frvr166t06dLy9vZWx44d9eeff9rarFmzRk2bNpUk9e3b1zakcv487733XtWrV0+JiYlq1aqVSpYsabsuF89hiYiIUIkSJfKdf1hYmMqWLauDBw8W+FwdsWbNGlksFn322WcaPXq0KlasqJIlSyo9Pb1A10S69Ne7T58+Kl26tA4cOKDw8HCVLl1a5cuX1/Dhw5Wbm3vFPu3Zs0eGYahFixb5tlksFvn5+dmtS0tL05AhQ1S5cmVZrVbVrFlTEydOVF5enq1/5cuXlySNGzfO9nWLjo6+hisG3PiosKBAvvnmG1WvXl3NmzcvUPv+/ftr/vz5evDBBzVs2DD9/PPPiomJ0datW7V48WK7tjt37tSDDz6ofv36KSIiQh988IH69Omjxo0b64477lCPHj1UpkwZDR06VI8++qg6deqk0qVLF6r/W7ZsUZcuXdSgQQONHz9eVqtVO3fu1E8//XTF9/3www/q2L
GjqlevrujoaJ0+fVozZsxQixYt9Pvvv+cLSw8//LCCgoIUExOj33//Xe+99578/Pw0ceLEAvWzR48eGjhwoL766is99dRTks5VV+rUqaO77rorX/vdu3dryZIleuihhxQUFKSUlBS9/fbbat26tf766y8FBgaqbt26Gj9+vMaMGaOnn35aLVu2lCS7r+XRo0fVsWNH9erVS48//vhlh1umTZum1atXKyIiQgkJCXJ3d9fbb7+t77//Xh999JECAwMLdJ7OMmHCBHl4eGj48OHKysqSh4eH/vrrr6tekyvJzc1VWFiYQkJCNHnyZP3www968803VaNGDQ0aNOiy7zsfqhctWqSHHnpIJUuWvGzbU6dOqXXr1jpw4ICeeeYZValSRevXr9eoUaN06NAhTZ06VeXLl9ecOXPyDYddWH0CbikGcBUnTpwwJBndunUrUPukpCRDktG/f3+79cOHDzckGatXr7atq1q1qiHJiI+Pt607fPiwYbVajWHDhtnW7dmzx5BkvPHGG3b7jIiIMKpWrZqvD2PHjjUu/PaeMmWKIclITU29bL/PH2PevHm2dY0aNTL8/PyMo0eP2tb9+eefhpubm/Hkk0/mO95TTz1lt8/u3bsbt91222WPeeF5lCpVyjAMw3jwwQeNtm3bGoZhGLm5uUZAQIAxbty4S16DM2fOGLm5ufnOw2q1GuPHj7et+/XXX/Od23mtW7c2JBlz58695LbWrVvbrVu5cqUhyXj11VeN3bt3G6VLlzbCw8Oveo7XKjU11ZBkjB071rbuxx9/NCQZ1atXN06dOmXXvqDX5FJf74iICEOSXTvDMIw777zTaNy48VX7+uSTTxqSjLJlyxrdu3c3Jk+ebGzdujVfuwkTJhilSpUy/vvf/9qtf+mllwx3d3dj3759lz134FbFkBCuKj09XZLk5eVVoPbffvutJCkqKspu/bBhwyQp31yX4OBg22/9klS+fHnVrl1bu3fvvuY+X+z83Jevv/7aVnK/mkOHDikpKUl9+vSRr6+vbX2DBg10//33287zQgMHDrR73bJlSx09etR2DQuid+/eWrNmjZKTk7V69WolJydfcjhIOjfv5fycjdzcXB09etQ23PX7778X+JhWq1V9+/YtUNv27dvrmWee0fjx49WjRw+VKFFCb7/9doGP5UwRERHy9PS0W+eMa3Kpr2NBvh/nzZunmTNnKigoSIsXL9bw4cNVt25dtW3bVgcOHLC1W7RokVq2bKmyZcvqyJEjtqVdu3bKzc1VfHx8gfoJ3EoILLgqb29vSdLJkycL1P7vv/+Wm5ub3Z0ekhQQEKAyZcro77//tltfpUqVfPsoW7asjh8/fo09zu+RRx5RixYt1L9/f/n7+6tXr176/PPPrxhezvezdu3a+bbVrVtXR44cUWZmpt36i8+lbNmyklSoc+nUqZO8vLy0cOFCffLJJ2ratGm+a3leXl6epkyZottvv11Wq1XlypVT+fLltXHjRp04caLAx6xYsWKhJthOnjxZvr6+SkpK0vTp0/PNz7iU1NRUJScn25aMjIwCH+9ygoKC8q1z9JqUKFHCNnfkvIJ+P7q5uSkyMlKJiYk6cuSIvv76a3Xs2FGrV69Wr169bO127NihFStWqHz58nZLu3btJEmHDx++6rGAWw2BBVfl7e2twMBAbd68uVDvu3jS6+W4u7tfcr1hGNd8jIsnSHp6eio+Pl4//PCDnnjiCW3cuFGPPPKI7r///qtOpiwMR87lPKvVqh49emj+/PlavHjxZasrkvT6668rKipKrVq10scff6yVK1cqLi5Od9xxR4ErSZLyVSmu5o8//rD9o7pp06YCvadp06aqUKGCbbmW58lc7FL9dvSaXO5rWFi33XabHnjgAX377bdq3bq11q1bZwvBeXl5uv/++xUXF3fJpWfPnk7pA3AzYdItCqRLly565513lJCQoNDQ0Cu2rVq1qvLy8rRjxw7VrVvXtj4lJUVpaWm2yY
nOULZsWbs7as67uIojnfvtt23btmrbtq3eeustvf766/rXv/6lH3/80fab7cXnIUnbt2/Pt23btm0qV66cSpUq5fh
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 2 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Evaluate the model on training set\n",
|
|
|
|
|
"y_pred_train = best_xgb_model.predict(X_train)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Confusion Matrix for Training Data\n",
|
|
|
|
|
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
|
|
|
|
|
"ax.set_xlabel('Predicted')\n",
|
|
|
|
|
"ax.set_ylabel('Actual')\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"ax.set_title('Confusion Matrix - Train Set')\n",
|
|
|
|
|
"plt.show()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Test set"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 26,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAHHCAYAAACPy0PBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJnElEQVR4nO3deVxU1f/H8fegMCKrqICkIuZKbomm5L4kGpprZVniVukXLXelxbXEbDHNLVvUSkuzNNNy+WpqJZpRlLnlVlQKbiGJCij394c/5+sEOmBzHcXXs8d9PJx7z9x77uTy5nPOuWMxDMMQAACAC7m5ugMAAAAEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEgAA4HIEEty09u3bpzZt2sjPz08Wi0XLly936vl//fVXWSwWzZ8/36nnvZk1b95czZs3d3U3ABRCBBL8KwcOHNATTzyhihUrqlixYvL19VWjRo00bdo0nT171tRrx8TEaMeOHXrhhRf03nvvqV69eqZe73rq1auXLBaLfH198/wc9+3bJ4vFIovFopdffrnA5z98+LDGjRunpKQkJ/TWXOPGjbPd69U2ZwWlzz//XOPGjct3+5ycHL377rtq0KCBAgIC5OPjoypVqqhnz57aunVrga9/5swZjRs3Ths3bizwe4GbWVFXdwA3r1WrVun++++X1WpVz549VaNGDWVlZenrr7/WiBEjtHPnTs2dO9eUa589e1YJCQl65plnNHDgQFOuERoaqrNnz8rd3d2U8ztStGhRnTlzRp999pkeeOABu2MLFy5UsWLFdO7cuWs69+HDhzV+/HhVqFBBderUyff71q5de03X+ze6dOmiSpUq2V6fPn1aAwYMUOfOndWlSxfb/qCgIKdc7/PPP9fMmTPzHUqefPJJzZw5Ux07dlSPHj1UtGhR7d27V1988YUqVqyohg0bFuj6Z86c0fjx4yWJahRuKQQSXJNDhw6pe/fuCg0N1YYNG1SmTBnbsdjYWO3fv1+rVq0y7frHjh2TJPn7+5t2DYvFomLFipl2fkesVqsaNWqkDz74IFcgWbRokaKjo/Xxxx9fl76cOXNGxYsXl4eHx3W53uVq1aqlWrVq2V4fP35cAwYMUK1atfTII49c9/5cLjU1VbNmzdJjjz2WK3y/9tprtt+nABxjyAbXZMqUKTp9+rTefvttuzBySaVKlfTUU0/ZXp8/f14TJ07U7bffLqvVqgoVKujpp59WZmam3fsqVKig9u3b6+uvv9Zdd92lYsWKqWLFinr33XdtbcaNG6fQ0FBJ0ogRI2SxWFShQgVJF4c6Lv36cpfK/pdbt26dGjduLH9/f3l7e6tq1ap6+umnbcevNIdkw4YNatKkiby8vOTv76+OHTtq9+7deV5v//796tWrl/z9/eXn56fevXvrzJkzV/5g/+Hhhx/WF198obS0NNu+7du3a9++fXr44YdztT958qSGDx+umjVrytvbW76+vmrXrp1+/PFHW5uNGzeqfv36kqTevXvbhjwu3Wfz5s1Vo0YNJSYmqmnTpipevLjtc/nnHJKYmBgVK1Ys1/1HRUWpRIkSOnz4cL7v9d/as2ePunXrpoCAABUrVkz16tXTihUr7NpkZ2dr/Pjxqly5sooVK6aSJUuqcePGWrdunaSLv39mzpwpSXbDQVdy6NAhGYahRo0a5TpmsVgUGBhoty8tLU2DBw9WuXLlZLVaValSJb344ovKycmRdPH3XOnSpSVJ48ePt12/IENIwM2KCgmuyWeffaaKFSvq7rvvzlf7fv36acGCBerWrZuGDRumbdu2KT4+Xrt379ayZcvs2u7fv1/dunVT3759FRMTo3feeUe9evVSRESE7rjjDnXp0kX+/v4aMmSIHnroId17773y9vYuUP937typ9u3bq1atWpowYYKsVqv279+vb7755qrv++9//6t27dqpYsWKGjdunM
6ePavXX39djRo10vfff58rDD3wwAMKCwtTfHy8vv/+e7311lsKDAzUiy++mK9+dunSRf3799cnn3yiPn36SLpYHalWrZrq1q2bq/3Bgwe1fPly3X///QoLC1NqaqreeOMNNWvWTLt27VJISIiqV6+uCRMmaMyYMXr88cfVpEkTSbL7f3nixAm1a9dO3bt31yOPPHLF4ZBp06Zpw4YNiomJUUJCgooUKaI33nhDa9eu1XvvvaeQkJB83ee/tXPnTjVq1Ei33XabRo8eLS8vLy1ZskSdOnXSxx9/rM6dO0u6GBTj4+PVr18/3XXXXUpPT9d3332n77//Xvfcc4+eeOIJHT58WOvWrdN7773n8LqXgvFHH32k+++/X8WLF79i2zNnzqhZs2b6888/9cQTT6h8+fLasmWL4uLidOTIEb322msqXbq0Zs+enWtI6vIKEVBoGUABnTp1ypBkdOzYMV/tk5KSDElGv3797PYPHz7ckGRs2LDBti80NNSQZGzevNm27+jRo4bVajWGDRtm23fo0CFDkvHSSy/ZnTMmJsYIDQ3N1YexY8cal/92nzp1qiHJOHbs2BX7feka8+bNs+2rU6eOERgYaJw4ccK278cffzTc3NyMnj175rpenz597M7ZuXNno2TJkle85uX34eXlZRiGYXTr1s1o1aqVYRiGceHCBSM4ONgYP358np/BuXPnjAsXLuS6D6vVakyYMMG2b/v27bnu7ZJmzZoZkow5c+bkeaxZs2Z2+9asWWNIMp5//nnj4MGDhre3t9GpUyeH93itjh07Zkgyxo4da9vXqlUro2bNmsa5c+ds+3Jycoy7777bqFy5sm1f7dq1jejo6KuePzY21ijIX409e/Y0JBklSpQwOnfubLz88svG7t27c7WbOHGi4eXlZfzyyy92+0ePHm0UKVLESE5OvuL9AbcChmxQYOnp6ZIkHx+ffLX//PPPJUlDhw612z9s2DBJyjXXJDw83PZTuySVLl1aVatW1cGDB6+5z/90ae7Jp59+aiuXO3LkyBElJSWpV69eCggIsO2vVauW7rnnHtt9Xq5///52r5s0aaITJ07YPsP8ePjhh7Vx40alpKRow4YNSklJyXO4Rro478TN7eIf6wsXLujEiRO24ajvv/8+39e0Wq3q3bt3vtq2adNGTzzxhCZMmKAuXbqoWLFieuONN/J9rX/r5MmT2rBhgx544AH9/fffOn78uI4fP64TJ04oKipK+/bt059//inp4v/3nTt3at++fU67/rx58zRjxgyFhYVp2bJlGj58uKpXr65WrVrZritdrKI0adJEJUqUsPXx+PHjat26tS5cuKDNmzc7rU/AzYhAggLz9fWVJP3999/5av/bb7/Jzc3NbqWEJAUHB8vf31+//fab3f7y5cvnOkeJEiX0119/XWOPc3vwwQfVqFEj9evXT0FBQerevbuWLFly1XByqZ9Vq1bNdax69eo6fvy4MjIy7Pb/815KlCghSQW6l3vvvVc+Pj5avHixFi5cqPr16+f6LC/JycnR1KlTVblyZVmtVpUqVUqlS5fWTz/9pFOnTuX7mrfddluBJrC+/PLLCggIUFJSkqZPn55r7kRejh07ppSUFNt2+vTpfF/vcvv375dhGHruuedUunRpu23s2LGSpKNHj0qSJkyYoLS0NFWpUkU1a9bUiBEj9NNPP13TdS9xc3NTbGysEhMTdfz4cX366adq166dNmzYoO7du9va7du3T6tXr87Vx9atW9v1EbhVMYcEBebr66uQkBD9/PPPBXrf1SYHXq5IkSJ57jcM45qvceHCBbvXnp6e2rx5s7788kutWrVKq1ev1uLFi9WyZUutXbv2in0oqH9zL5dYrVZ16dJFCxYs0MGDB686wXHSpEl67rnn1KdPH02cOFEBAQFyc3PT4MGD810Jki5+PgXxww8/2P5B3bFjhx566CGH76lfv75dGB07duw1Td68dF/Dhw9XVFRUnm0uBbimTZ
vqwIED+vTTT7V27Vq99dZbmjp1qubMmaN+/foV+Nr/VLJkSd13332677771Lx5c23atEm//fabQkNDlZOTo3vuuUc
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 2 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Evaluate the model on test set\n",
|
|
|
|
|
"y_pred_test = best_xgb_model.predict(X_test)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Confusion Matrix for Testing Data\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
|
|
|
|
|
"ax.set_xlabel('Predicted')\n",
|
|
|
|
|
"ax.set_ylabel('Actual')\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"ax.set_title('Confusion Matrix - Test Set')\n",
|
|
|
|
|
"plt.show()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 27,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"The test dataset contains 7669 shots, with 914 of them being goals.\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Number of goals in test set\n",
|
|
|
|
|
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"cell_type": "markdown",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Feature importance"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 28,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAHHCAYAAADONqsSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxP2f8H8Nen/dOqUipaabOVrVTI0qYmlSWyZpIxdgZZ0mbLHrIvhWEsgzQjSyhLkiRLIkpESqRFpfrU5/z+6Nf9dn0+JTNUzHk+Hj3G55xzzz33fe987ufce+65HEIIAUVRFEVRFEVRFNViiDR3AyiKoiiKoiiKoig22lGjKIqiKIqiKIpqYWhHjaIoiqIoiqIoqoWhHTWKoiiKoiiKoqgWhnbUKIqiKIqiKIqiWhjaUaMoiqIoiqIoimphaEeNoiiKoiiKoiiqhaEdNYqiKIqiKIqiqBaGdtQoiqIoiqIoiqJaGNpRoyiKoqjvSHh4ODgcDp4/f97cTaEoiqK+IdpRoyiKolq02o6JsL+FCxd+k3XeuHEDAQEBKCws/Cb1/5eVlZUhICAAsbGxzd0UiqKoFk2suRtAURRFUY0RFBQEXV1dVlrnzp2/ybpu3LiBwMBAeHp6olWrVt9kHf/UuHHjMGrUKEhKSjZ3U/6RsrIyBAYGAgD69+/fvI2hKIpqwWhHjaIoivouDB48GD179mzuZvwrpaWlkJGR+Vd1iIqKQlRU9Cu1qOnw+XxUVlY2dzMoiqK+G3ToI0VRFPVDOHv2LPr27QsZGRnIycnByckJDx8+ZJW5f/8+PD09oaenBykpKaipqeHnn39Gfn4+UyYgIADz588HAOjq6jLDLJ8/f47nz5+Dw+EgPDxcYP0cDgcBAQGsejgcDlJTUzF69GgoKiqiT58+TP7vv/+OHj16gMvlQklJCaNGjcLLly8/u53CnlHT0dHBTz/9hNjYWPTs2RNcLhddunRhhheePHkSXbp0gZSUFHr06IHk5GRWnZ6enpCVlcWzZ89gb28PGRkZaGhoICgoCIQQVtnS0lL89ttv0NTUhKSkJAwNDbFu3TqBchwOB9OnT8ehQ4fQqVMnSEpKYseOHVBRUQEABAYGMrGtjVtj9k/d2KanpzN3PRUUFDBx4kSUlZUJxOz333+HmZkZpKWloaioiH79+uHChQusMo05fiiKopoSvaNGURRFfReKiorw7t07Vlrr1q0BAAcPHsSECRNgb2+P1atXo6ysDNu3b0efPn2QnJwMHR0dAEB0dDSePXuGiRMnQk1NDQ8fPsSuXbvw8OFD3Lx5ExwOB0OHDsWTJ0/wxx9/YOPGjcw6VFRU8Pbt2y9u94gRI6Cvr4+VK1cynZkVK1Zg6dKlcHd3x6RJk/D27Vts2bIF/fr1Q3Jy8j8abpmeno7Ro0fjl19+wdixY7Fu3To4Oztjx44dWLx4MaZOnQoAWLVqFdzd3ZGWlgYRkf9dr62uroaDgwN69+6NNWvW4Ny5c/D390dVVRWCgoIAAIQQDBkyBDExMfDy8oKpqSnOnz+P+fPnIzs7Gxs3bmS16fLlyzh27BimT5+O1q1bw8TEBNu3b8evv/4KNzc3DB06FADQtWtXAI3bP3W5u7tDV1cXq1atwp07d7Bnzx6oqqpi9erVTJnAwEAEBATA0tISQUFBkJCQQEJCAi5fvgw7OzsAjT9+KIqimhShKIqiqBYsLCyMABD6RwghHz58IK1atSLe3t6s5XJzc4mCggIrvaysTKD+P/74gwAgV69eZdLWrl1LAJDMzExW2czMTAKAhIWFCdQDgPj7+zOf/f39CQDi4eHBKvf8+XMiKipKVqxYwUp/8OABERMTE0ivLx5126atrU0AkBs3bjBp58+fJwAIl8slL168YNJ37txJAJCYmBgmbcKECQQAmTFjBpPG5/OJk5MTkZCQIG/fviWEEBIREUEAkOXLl7PaNHz4cMLhcEh6ejorHiIiIu
Thw4essm/fvhWIVa3G7p/a2P7888+ssm5ubkRZWZn5/PTpUyIiIkLc3NxIdXU1qyyfzyeEfNnxQ1EU1ZTo0EeKoijqu7B161ZER0ez/oCauzCFhYXw8PDAu3fvmD9RUVGYm5sjJiaGqYPL5TL/Li8vx7t379C7d28AwJ07d75Ju6dMmcL6fPLkSfD5fLi7u7Paq6amBn19fVZ7v0THjh1hYWHBfDY3NwcADBw4EFpaWgLpz549E6hj+vTpzL9rhy5WVlbi4sWLAICoqCiIiopi5syZrOV+++03EEJw9uxZVrq1tTU6duzY6G340v3zaWz79u2L/Px8FBcXAwAiIiLA5/Ph5+fHuntYu33Alx0/FEVRTYkOfaQoiqK+C2ZmZkInE3n69CmAmg6JMPLy8sy/379/j8DAQBw5cgR5eXmsckVFRV+xtf/z6UyVT58+BSEE+vr6QsuLi4v/o/XU7YwBgIKCAgBAU1NTaHpBQQErXUREBHp6eqw0AwMDAGCeh3vx4gU0NDQgJyfHKmdsbMzk1/Xptn/Ol+6fT7dZUVERQM22ycvLIyMjAyIiIg12Fr/k+KEoimpKtKNGURRFfdf4fD6AmueM1NTUBPLFxP53qnN3d8eNGzcwf/58mJqaQlZWFnw+Hw4ODkw9Dfn0Gala1dXV9S5T9y5RbXs5HA7Onj0rdPZGWVnZz7ZDmPpmgqwvnXwy+ce38Om2f86X7p+vsW1fcvxQFEU1JfrtQ1EURX3X2rdvDwBQVVWFjY1NveUKCgpw6dIlBAYGws/Pj0mvvaNSV30dsto7Np++CPvTO0mfay8hBLq6uswdq5aAz+fj2bNnrDY9efIEAJjJNLS1tXHx4kV8+PCBdVft8ePHTP7n1BfbL9k/jdW+fXvw+XykpqbC1NS03jLA548fiqKopkafUaMoiqK+a/b29pCXl8fKlSvB4/EE8mtnaqy9+/Lp3ZaQkBCBZWrfdfZph0xeXh6tW7fG1atXWenbtm1rdHuHDh0KUVFRBAYGCrSFECIwFX1TCg0NZbUlNDQU4uLiGDRoEADA0dER1dXVrHIAsHHjRnA4HAwePPiz65CWlgYgGNsv2T+N5erqChEREQQFBQnckatdT2OPH4qiqKZG76hRFEVR3zV5eXls374d48aNQ/fu3TFq1CioqKggKysLZ86cgZWVFUJDQyEvL49+/fphzZo14PF4aNu2LS5cuIDMzEyBOnv06AEAWLJkCUaNGgVxcXE4OztDRkYGkyZNQnBwMCZNmoSePXvi6tWrzJ2nxmjfvj2WL1+ORYsW4fnz53B1dYWcnBwyMzNx6tQpTJ48GfPmzftq8WksKSkpnDt3DhMmTIC5uTnOnj2LM2fOYPHixcy7z5ydnTFgwAAsWbIEz58/h4mJCS5cuIDTp09j9uzZzN2phnC5XHTs2BFHjx6FgYEBlJSU0LlzZ3Tu3LnR+6exOnTogCVLlmDZsmXo27cvhg4dCklJSSQmJkJDQwOrVq1q9PFDURTV1GhHjaIoivrujR49GhoaGggODsbatWtRUVGBtm3bom/fvpg4cSJT7vDhw5gxYwa2bt0KQgjs7Oxw9uxZaGhosOrr1asXli1bhh07duDcuXPg8/nIzMyEjIwM/Pz88PbtW/z55584duwYBg8ejLNnz0JVVbXR7V24cCEMDAywceNGBAYGAqiZ9MPOzg5Dhgz5OkH5QqKiojh37hx+/fVXzJ8/H3JycvD392cNQxQREUFkZCT8/Pxw9OhRhIWFQUdHB2vXrsVvv/3W6HXt2bMHM2bMwJw5c1BZWQl/f3907ty50fvnSwQFBUFXVxdbtmzBkiVLIC0tja5du2LcuHFMmcYePxRFUU2JQ5riaWKKoiiKolosT09P/PnnnygpKWnuplAURVH/jz6jRlEURVEURVEU1cLQjhpFURRFURRFUVQLQztqFEVRFEVRFEVRLQx9Ro2iKIqiKIqiKKqFoXfUKIqiKIqiKIqiWhjaUaMoiq
IoiqIoimph6HvUKOo/js/n4/Xr15CTkwOHw2nu5lAURVEU1QiEEHz48AEaGhoQEaH3Xn5EtKNGUf9xr1+/hqamZnM
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Plot feature importance with Gain\n",
|
|
|
|
|
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"plt.show()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 29,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvwAAAHHCAYAAADDDYx8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde1yP5//A8deng45KJSmLHFIhMYRyCFGa5lwOQ6bMIdYMMaSczxptbDaFMd99tzn8JiatHGKNHIZaiGQOc0qp6Hj//ujR/fXRwSdzKK7n49FjPvd93dd93e/PZ3V9rvu6r7dCkiQJQRAEQRAEQRDeSGqvuwGCIAiCIAiCILw8osMvCIIgCIIgCG8w0eEXBEEQBEEQhDeY6PALgiAIgiAIwhtMdPgFQRAEQRAE4Q0mOvyCIAiCIAiC8AYTHX5BEARBEARBeIOJDr8gCIIgCIIgvMFEh18QBEEQBEEQ3mCiwy8IgiAI1UhERAQKhYLU1NTX3RRBEKoJ0eEXBEEQqrSSDm5ZPzNmzHgp5zx69CjBwcE8ePDgpdT/NsvJySE4OJjY2NjX3RRBeGtovO4GCIIgCIIq5s2bR8OGDZW2tWjR4qWc6+jRo4SEhODj40OtWrVeyjme14gRIxgyZAhaWlqvuynPJScnh5CQEABcXFxeb2ME4S0hOvyCIAhCtdC7d2/atm37upvxr2RnZ6Onp/ev6lBXV0ddXf0FtejVKSoqIi8v73U3QxDeSmJKjyAIgvBG2Lt3L507d0ZPT4+aNWvy3nvvcf78eaUyf/75Jz4+PjRq1AhtbW3q1q3Lhx9+yL179+QywcHBTJs2DYCGDRvK04dSU1NJTU1FoVAQERFR6vwKhYLg4GClehQKBYmJiQwbNgwjIyM6deok7//uu+9o06YNOjo6GBsbM2TIEK5du/bM6yxrDr+VlRV9+vQhNjaWtm3boqOjg729vTxt5ueff8be3h5tbW3atGnDqVOnlOr08fFBX1+fy5cv4+bmhp6eHhYWFsybNw9JkpTKZmdn8+mnn2JpaYmWlhY2NjasWLGiVDmFQoG/vz9bt26lefPmaGlpsX79ekxNTQEICQmRY1sSN1Xenydje+nSJfkujKGhIaNHjyYnJ6dUzL777jscHR3R1dXFyMiILl26sH//fqUyqnx+BKG6EiP8giAIQrWQkZHB3bt3lbbVrl0bgC1btjBq1Cjc3NxYunQpOTk5rFu3jk6dOnHq1CmsrKwAiIqK4vLly4wePZq6dety/vx5vv76a86fP8/vv/+OQqFgwIABXLhwge+//57Vq1fL5zA1NeXOnTuVbvfgwYOxtrZm0aJFcqd44cKFzJkzBy8vL3x9fblz5w5r166lS5cunDp16rmmEV26dIlhw4bx0Ucf8cEHH7BixQo8PT1Zv349n332GRMmTABg8eLFeHl5kZycjJra/8b9CgsLcXd3p0OHDixbtox9+/Yxd+5cCgoKmDdvHgCSJPH+++8TExPDmDFjaNWqFb/++ivTpk3j+vXrrF69WqlNv/32Gz/88AP+/v7Url0bBwcH1q1bx/jx4+nfvz8DBgwAoGXLloBq78+TvLy8aNiwIYsXL+bkyZN888031KlTh6VLl8plQkJCCA4OxsnJiXnz5lGjRg3i4+P57bff6NWrF6D650cQqi1JEARBEKqw8PBwCSjzR5Ik6eHDh1KtWrUkPz8/peNu3bolGRoaKm3PyckpVf/3338vAdKhQ4fkbcuXL5cA6cqVK0plr1y5IgFSeHh4qXoAae7cufLruXPnSoA0dOhQpXKpqamSurq6tHDhQqXtZ8+elTQ0NEptLy8eT7atQYMGEiAdPXpU3vbrr79KgKSjoyNdvXpV3v7VV19JgBQTEyNvGzVqlARIkyZNkrcVFRVJ7733nlSjRg3pzp07kiRJ0s6dOyVAWrBggVKbBg0aJCkUCunSpUtK8VBTU5POnz+vVPbOnTulYlVC1fenJL
YffvihUtn+/ftLJiYm8uuLFy9KampqUv/+/aXCwkKlskVFRZIkVe7zIwjVlZjSIwiCIFQLX3zxBVFRUUo/UDwq/ODBA4YOHcrdu3flH3V1ddq3b09MTIxch46Ojvzvx48fc/fuXTp06ADAyZMnX0q7x40bp/T6559/pqioCC8vL6X21q1bF2tra6X2VkazZs3o2LGj/Lp9+/YAdO/enfr165fafvny5VJ1+Pv7y/8umZKTl5fHgQMHAIiMjERdXZ3JkycrHffpp58iSRJ79+5V2t61a1eaNWum8jVU9v15OradO3fm3r17ZGZmArBz506KiooICgpSuptRcn1Quc+PIFRXYkqPIAiCUC04OjqW+dDuxYsXgeKObVkMDAzkf9+/f5+QkBC2b9/O7du3lcplZGS8wNb+z9MrC128eBFJkrC2ti6zvKam5nOd58lOPYChoSEAlpaWZW5PT09X2q6mpkajRo2UtjVt2hRAfl7g6tWrWFhYULNmTaVydnZ28v4nPX3tz1LZ9+fpazYyMgKKr83AwICUlBTU1NQq/NJRmc+PIFRXosMvCIIgVGtFRUVA8TzsunXrltqvofG/P3VeXl4cPXqUadOm0apVK/T19SkqKsLd3V2upyJPzyEvUVhYWO4xT45al7RXoVCwd+/eMlfb0dfXf2Y7ylLeyj3lbZeeesj2ZXj62p+lsu/Pi7i2ynx+BKG6Ep9iQRAEoVpr3LgxAHXq1MHV1bXccunp6URHRxMSEkJQUJC8vWSE90nldexLRpCfTsj19Mj2s9orSRINGzaUR9CrgqKiIi5fvqzUpgsXLgDID602aNCAAwcO8PDhQ6VR/r/++kve/yzlxbYy74+qGjduTFFREYmJibRq1arcMvDsz48gVGdiDr8gCIJQrbm5uWFgYMCiRYvIz88vtb9kZZ2S0eCnR39DQ0NLHVOyVv7THXsDAwNq167NoUOHlLZ/+eWXKrd3wIABqKurExISUqotkiSVWoLyVQoLC1NqS1hYGJqamvTo0QMADw8PCgsLlcoBrF69GoVCQe/evZ95Dl1dXaB0bCvz/qiqX79+qKmpMW/evFJ3CErOo+rnRxCqMzHCLwiCIFRrBgYGrFu3jhEjRvDuu+8yZMgQTE1NSUtLY8+ePTg7OxMWFoaBgQFdunRh2bJl5OfnU69ePfbv38+VK1dK1dmmTRsAZs2axZAhQ9DU1MTT0xM9PT18fX1ZsmQJvr6+tG3blkOHDskj4apo3LgxCxYsYObMmaSmptKvXz9q1qzJlStX2LFjB2PHjmXq1KkvLD6q0tbWZt++fYwaNYr27duzd+9e9uzZw2effSavne/p6Um3bt2YNWsWqampODg4sH//fnbt2kVAQIA8Wl4RHR0dmjVrxn/+8x+aNm2KsbExLVq0oEWLFiq/P6pq0qQJs2bNYv78+XTu3JkBAwagpaXF8ePHsbCwYPHixSp/fgShWntNqwMJgiAIgkpKlqE8fvx4heViYmIkNzc3ydDQUNLW1pYaN24s+fj4SCdOnJDL/P3331L//v2lWrVqSYaGhtLgwYOlGzdulLlM5Pz586V69epJampqSstg5uTkSGPGjJEMDQ2lmjVrSl5eXtLt27fLXZazZEnLp/30009Sp06dJD09PUlPT0+ytbWVJk6cKCUnJ6sUj6eX5XzvvfdKlQWkiRMnKm0rWVp0+fLl8rZRo0ZJenp6UkpKitSrVy9JV1dXMjMzk+bOnVtqOcuHDx9Kn3zyiWRhYSFpampK1tbW0vLly+VlLis6d4mjR49Kbdq0kWrUqKEUN1Xfn/JiW1ZsJEmSNm7cKLVu3VrS0tKSjIyMpK5du0pRUVFKZVT5/AhCdaWQpFfw1I4gCIIgCFWWj48PP/74I1lZWa+7KYIgvARiDr8gCIIgCIIgvMFEh18QBEEQBEEQ3mCiwy8IgiAIgiAIbzAxh18QBEEQBEEQ3mBihF8QBEEQBEEQ3mCiwy8IgiAIgiAIbzCReE
sQ3nJFRUXcuHGDmjVrlpvyXhAEQRCEqkWSJB4+fIiFhQVqahWP4YsOvyC85W7cuIGlpeXrboYgCIIgCM/h2rVrvPP
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Plot feature importance with Weight\n",
|
|
|
|
|
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"plt.show()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Summary"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 30,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Calculating MAE, RMSE and R2 for training and test sets \n",
|
|
|
|
|
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
|
|
|
|
|
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
|
|
|
|
|
"r2_train = r2_score(y_train, y_pred_train)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
|
|
|
|
|
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
|
|
|
|
|
"r2_test = r2_score(y_test, y_pred_test)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 31,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<style type=\"text/css\">\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"#T_57229_row0_col0, #T_57229_row0_col1, #T_57229_row0_col2, #T_57229_row0_col3, #T_57229_row0_col4, #T_57229_row0_col5, #T_57229_row0_col6 {\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" font-weight: bold;\n",
|
|
|
|
|
" border: 2.0px solid grey;\n",
|
|
|
|
|
" color: white;\n",
|
|
|
|
|
"}\n",
|
|
|
|
|
"</style>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"<table id=\"T_57229\">\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th class=\"blank level0\" > </th>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <th id=\"T_57229_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
|
|
|
|
|
" <th id=\"T_57229_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th class=\"index_name level0\" >Model Name</th>\n",
|
|
|
|
|
" <th class=\"blank col0\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col1\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col2\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col3\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col4\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col5\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col6\" > </th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
" <th id=\"T_57229_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col0\" class=\"data row0 col0\" >0.10305</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col1\" class=\"data row0 col1\" >0.32102</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col2\" class=\"data row0 col2\" >0.00230</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col3\" class=\"data row0 col3\" >0.10497</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col4\" class=\"data row0 col4\" >0.32399</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col5\" class=\"data row0 col5\" >0.00009</td>\n",
|
|
|
|
|
" <td id=\"T_57229_row0_col6\" class=\"data row0 col6\" >15.20037</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"<pandas.io.formats.style.Styler at 0x28e464b10>"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 31,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Creating of dataframe of summary results\n",
|
|
|
|
|
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" 'Training MAE': mae_train, \n",
|
|
|
|
|
" 'Training RMSE': rmse_train,\n",
|
|
|
|
|
" 'Training R2':r2_train,\n",
|
|
|
|
|
" 'Testing MAE': mae_test, \n",
|
|
|
|
|
" 'Testing RMSE': rmse_test,\n",
|
|
|
|
|
" 'Testing R2':r2_test,\n",
|
|
|
|
|
" 'Training Time (mins)': xgb_training_time/60})\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"summary_df.set_index('Model Name', inplace=True)\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"# Displaying summary of results\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" 'border': '2.0px solid grey','color': 'white'})"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Keeping the xgboost model"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"execution_count": 32,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Save the model\n",
|
|
|
|
|
"dump(best_xgb_model, 'xgboost.joblib') \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Load the model\n",
|
|
|
|
|
"model = load('xgboost.joblib')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
2024-01-12 20:39:14 +01:00
|
|
|
|
"version": "3.11.6"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 2
|
|
|
|
|
}
|