2023-12-12 15:22:01 +01:00
|
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Imports"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 1,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import pandas as pd\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"import seaborn as sns\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
|
|
|
|
|
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
|
|
|
|
|
"import xgboost\n",
|
|
|
|
|
"import time\n",
|
|
|
|
|
"from joblib import dump, load"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Load the data"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 2,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"df = pd.read_csv('final_data.txt')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 3,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
|
|
|
|
|
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
|
|
|
|
|
" 'shot_aerial_won', 'shot_deflected', 'shot_open_goal',\n",
|
|
|
|
|
" 'shot_follows_dribble', 'shot_redirect', 'x1', 'y1',\n",
|
|
|
|
|
" 'number_of_players_opponents', 'number_of_players_teammates', 'is_goal',\n",
|
|
|
|
|
" 'angle', 'distance', 'x_player_opponent_Goalkeeper',\n",
|
|
|
|
|
" 'x_player_opponent_8', 'x_player_opponent_1', 'x_player_opponent_2',\n",
|
|
|
|
|
" 'x_player_opponent_3', 'x_player_teammate_1', 'x_player_opponent_4',\n",
|
|
|
|
|
" 'x_player_opponent_5', 'x_player_opponent_6', 'x_player_teammate_2',\n",
|
|
|
|
|
" 'x_player_opponent_9', 'x_player_opponent_10', 'x_player_opponent_11',\n",
|
|
|
|
|
" 'x_player_teammate_3', 'x_player_teammate_4', 'x_player_teammate_5',\n",
|
|
|
|
|
" 'x_player_teammate_6', 'x_player_teammate_7', 'x_player_teammate_8',\n",
|
|
|
|
|
" 'x_player_teammate_9', 'x_player_teammate_10',\n",
|
|
|
|
|
" 'y_player_opponent_Goalkeeper', 'y_player_opponent_8',\n",
|
|
|
|
|
" 'y_player_opponent_1', 'y_player_opponent_2', 'y_player_opponent_3',\n",
|
|
|
|
|
" 'y_player_teammate_1', 'y_player_opponent_4', 'y_player_opponent_5',\n",
|
|
|
|
|
" 'y_player_opponent_6', 'y_player_teammate_2', 'y_player_opponent_9',\n",
|
|
|
|
|
" 'y_player_opponent_10', 'y_player_opponent_11', 'y_player_teammate_3',\n",
|
|
|
|
|
" 'y_player_teammate_4', 'y_player_teammate_5', 'y_player_teammate_6',\n",
|
|
|
|
|
" 'y_player_teammate_7', 'y_player_teammate_8', 'y_player_teammate_9',\n",
|
|
|
|
|
" 'y_player_teammate_10', 'x_player_opponent_7', 'y_player_opponent_7',\n",
|
|
|
|
|
" 'x_player_teammate_Goalkeeper', 'y_player_teammate_Goalkeeper',\n",
|
|
|
|
|
" 'shot_kick_off'],\n",
|
|
|
|
|
" dtype='object')"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 3,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"df.columns"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 4,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<div>\n",
|
|
|
|
|
"<style scoped>\n",
|
|
|
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
|
|
|
" vertical-align: middle;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe tbody tr th {\n",
|
|
|
|
|
" vertical-align: top;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe thead th {\n",
|
|
|
|
|
" text-align: right;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"</style>\n",
|
|
|
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr style=\"text-align: right;\">\n",
|
|
|
|
|
" <th></th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th>minute</th>\n",
|
|
|
|
|
" <th>position_name</th>\n",
|
|
|
|
|
" <th>shot_body_part_name</th>\n",
|
|
|
|
|
" <th>shot_technique_name</th>\n",
|
|
|
|
|
" <th>shot_type_name</th>\n",
|
|
|
|
|
" <th>shot_first_time</th>\n",
|
|
|
|
|
" <th>shot_one_on_one</th>\n",
|
|
|
|
|
" <th>shot_aerial_won</th>\n",
|
|
|
|
|
" <th>shot_deflected</th>\n",
|
|
|
|
|
" <th>shot_open_goal</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <th>...</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th>y_player_teammate_6</th>\n",
|
|
|
|
|
" <th>y_player_teammate_7</th>\n",
|
|
|
|
|
" <th>y_player_teammate_8</th>\n",
|
|
|
|
|
" <th>y_player_teammate_9</th>\n",
|
|
|
|
|
" <th>y_player_teammate_10</th>\n",
|
|
|
|
|
" <th>x_player_opponent_7</th>\n",
|
|
|
|
|
" <th>y_player_opponent_7</th>\n",
|
|
|
|
|
" <th>x_player_teammate_Goalkeeper</th>\n",
|
|
|
|
|
" <th>y_player_teammate_Goalkeeper</th>\n",
|
|
|
|
|
" <th>shot_kick_off</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>0</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>0</td>\n",
|
|
|
|
|
" <td>Right Center Forward</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Right Center Forward</td>\n",
|
|
|
|
|
" <td>Left Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>2</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Center Midfield</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Half Volley</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>True</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>3</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Left Center Midfield</td>\n",
|
|
|
|
|
" <td>Right Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>4</th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>5</td>\n",
|
|
|
|
|
" <td>Right Center Back</td>\n",
|
|
|
|
|
" <td>Left Foot</td>\n",
|
|
|
|
|
" <td>Normal</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>Open Play</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>True</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <td>...</td>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>NaN</td>\n",
|
|
|
|
|
" <td>False</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"<p>5 rows × 66 columns</p>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"</div>"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" minute position_name shot_body_part_name shot_technique_name \\\n",
|
|
|
|
|
"0 0 Right Center Forward Right Foot Normal \n",
|
|
|
|
|
"1 5 Right Center Forward Left Foot Normal \n",
|
|
|
|
|
"2 5 Center Midfield Right Foot Half Volley \n",
|
|
|
|
|
"3 5 Left Center Midfield Right Foot Normal \n",
|
|
|
|
|
"4 5 Right Center Back Left Foot Normal \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
|
|
|
|
|
"0 Open Play False False False \n",
|
|
|
|
|
"1 Open Play False False False \n",
|
|
|
|
|
"2 Open Play True False False \n",
|
|
|
|
|
"3 Open Play False False False \n",
|
|
|
|
|
"4 Open Play True False False \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" shot_deflected shot_open_goal ... y_player_teammate_6 \\\n",
|
|
|
|
|
"0 False False ... NaN \n",
|
|
|
|
|
"1 False False ... NaN \n",
|
|
|
|
|
"2 False False ... NaN \n",
|
|
|
|
|
"3 False False ... NaN \n",
|
|
|
|
|
"4 False False ... NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" y_player_teammate_7 y_player_teammate_8 y_player_teammate_9 \\\n",
|
|
|
|
|
"0 NaN NaN NaN \n",
|
|
|
|
|
"1 NaN NaN NaN \n",
|
|
|
|
|
"2 NaN NaN NaN \n",
|
|
|
|
|
"3 NaN NaN NaN \n",
|
|
|
|
|
"4 NaN NaN NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" y_player_teammate_10 x_player_opponent_7 y_player_opponent_7 \\\n",
|
|
|
|
|
"0 NaN NaN NaN \n",
|
|
|
|
|
"1 NaN NaN NaN \n",
|
|
|
|
|
"2 NaN NaN NaN \n",
|
|
|
|
|
"3 NaN NaN NaN \n",
|
|
|
|
|
"4 NaN NaN NaN \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" x_player_teammate_Goalkeeper y_player_teammate_Goalkeeper shot_kick_off \n",
|
|
|
|
|
"0 NaN NaN False \n",
|
|
|
|
|
"1 NaN NaN False \n",
|
|
|
|
|
"2 NaN NaN False \n",
|
|
|
|
|
"3 NaN NaN False \n",
|
|
|
|
|
"4 NaN NaN False \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"[5 rows x 66 columns]"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 4,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"df.head()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Data preparation"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 5,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"outputs": [],
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Change the type of categorical features to 'category' \n",
|
|
|
|
|
"df[['position_name', \n",
|
|
|
|
|
" 'shot_technique_name', \n",
|
|
|
|
|
" 'shot_type_name', \n",
|
|
|
|
|
" 'number_of_players_opponents', \n",
|
|
|
|
|
" 'number_of_players_teammates', \n",
|
|
|
|
|
" 'shot_body_part_name']] = df[['position_name', \n",
|
|
|
|
|
" 'shot_technique_name', \n",
|
|
|
|
|
" 'shot_type_name', \n",
|
|
|
|
|
" 'number_of_players_opponents', \n",
|
|
|
|
|
" 'number_of_players_teammates', \n",
|
|
|
|
|
" 'shot_body_part_name']].astype('category')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 6,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Splitting the dataset into features (X) and the target variable (y)\n",
|
|
|
|
|
"y = pd.DataFrame(df['is_goal'])\n",
|
|
|
|
|
"X = df.drop(['is_goal'], axis=1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Splitting the data into a training set and a test set\n",
|
|
|
|
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Create cross-validation \n",
|
|
|
|
|
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 7,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"name": "stdout",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"Shots attempted in the training set: 30673\n",
|
|
|
|
|
"Goals scored in the training set: 3588\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"count_class_0, count_class_1 = y_train['is_goal'].astype(int).value_counts().reindex([0, 1], fill_value=0)  # ordered by label (0 = no goal, 1 = goal), not by frequency\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Display the count of shots attempted in the training set\n",
|
|
|
|
|
"print('Shots attempted in the training set:', count_class_0 + count_class_1)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Display the count of successful goals in the training set\n",
|
|
|
|
|
"print('Goals scored in the training set:', count_class_1)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 8,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" Class imbalance in training data: 7.549\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Class imbalance in training data\n",
|
|
|
|
|
"scale_pos_weight = count_class_0 / count_class_1\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Training XGBoost model "
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 9,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Define the xgboost model\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 10,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Defining the hyper-parameter grid for XGBoost\n",
|
|
|
|
|
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
|
|
|
|
|
" 'max_depth': [3, 5, 7, 8, 9],\n",
|
|
|
|
|
" 'n_estimators': [100, 150, 200, 250, 300],\n",
|
|
|
|
|
" 'scale_pos_weight': [1, scale_pos_weight]}"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 11,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Starting the timer\n",
|
|
|
|
|
"start_time = time.time()\n",
|
|
|
|
|
"\n",
|
|
|
|
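|
"# Perform grid search with cross-validation (NOTE(review): 'neg_mean_squared_error' is computed on the classifier's predicted labels, so for 0/1 labels it is the misclassification rate; consider 'neg_log_loss' or 'neg_brier_score' for an imbalanced probability model - confirm intent)\n",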
|
|
|
|
|
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Run the grid search on the training set (GridSearchCV then refits the best configuration on all of it by default)\n",
|
|
|
|
|
"grid_xg.fit(X_train, y_train)\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Take the refit XGBoost estimator that uses the best parameters found\n",
|
|
|
|
|
"best_xgb_model = grid_xg.best_estimator_\n",
|
|
|
|
|
"\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"# Stopping the timer\n",
|
|
|
|
|
"stop_time = time.time()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Training Time\n",
|
|
|
|
|
"xgb_training_time = stop_time - start_time"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 12,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"Best parameters: {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
|
|
|
|
|
"Model Training Time: 1916.225 seconds\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Print the best parameters and training time\n",
|
|
|
|
|
"print(\"Best parameters: \", grid_xg.best_params_)\n",
|
|
|
|
|
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Model evaluation"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Training set"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 13,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABKMElEQVR4nO3deVwV9f7H8fcB5YALoCEC7ksuuJYakqmZJK6JS7ndRFNLL5aKW1ZXUSvSLJfcskWstGtWWmlppCmZWKaRS2pqmmmCuwQqIMzvD3+e6wkX8JyRo76e9zGPe893vmfmO4M+7tvP9zuDxTAMQwAAAC7MraAHAAAAcD0EFgAA4PIILAAAwOURWAAAgMsjsAAAAJdHYAEAAC6PwAIAAFwegQUAALg8AgsAAHB5BBbcMvbs2aNWrVrJx8dHFotFy5Ytc+rxDxw4IIvFori4OKce91b24IMP6sEHHyzoYZiCnzdwayGwIF/27dunp556SpUrV5anp6e8vb3VpEkTTZ8+XefOnTP13JGRkdq2bZteeuklvf/++2rYsKGp57uZ+vTpI4vFIm9v7yvexz179shischisWjKlCn5Pv5ff/2lmJgYJSUlOWG05oqJibFd67U2VwxSOTk5eu+99xQSEqKSJUuqePHiqlatmnr37q2NGzfm+3hnz55VTEyM1q5d6/zBAreYQgU9ANw6VqxYoUcffVRWq1W9e/dW7dq1lZmZqfXr12vkyJHasWOH5s2bZ8q5z507p8TERD3//PMaPHiwKeeoUKGCzp07p8KFC5ty/OspVKiQzp49qy+++EKPPfaY3b6FCxfK09NT58+fv6Fj//XXXxo/frwqVqyo+vXr5/l7X3/99Q2dzxGdO3dW1apVbZ/T0tI0aNAgderUSZ07d7a1ly5d2qHzmPHzfuaZZzRr1ix17NhRvXr1UqFChbR792599dVXqly5sho3bpyv4509e1bjx4+XJJcMaMDNRGBBnuzfv1/du3dXhQoVtGbNGgUGBtr2RUVFae/evVqxYoVp5z927JgkydfX17RzWCwWeXp6mnb867FarWrSpIk+/PDDXIFl0aJFateunT755JObMpazZ8+qSJEi8vDwuCnnu1zdunVVt25d2+fjx49r0KBBqlu3rv71r39d9Xvnz5+Xh4eH3NzyVjh29s87JSVFs2fP1oABA3IF92nTptn+DAO4MUwJIU8mT56stLQ0vfPOO3Zh5ZKqVatqyJAhts8XLlzQxIkTVaVKFVmtVlWsWFHPPfecMjIy7L5XsWJFtW/fXuvXr9d9990nT09PVa5cWe+9956tT0xMjCpUqCBJGjlypCwWiypWrCjp4lTKpf99uUvTCpeLj4/XAw88IF9fXxUrVkzVq1fXc889Z9t/tTUNa9asUdOmTVW0aFH5+vqqY8eO2rlz5xXPt3fvXvXp00e+vr7y8fFR3759dfbs2avf2H/o2bOnvvrqK50+fdrWtmnTJu3Zs0c9e/bM1f/kyZMaMWKE6tSpo2LFisnb21tt2rTRL7/8Yuuzdu1aNWrUSJLUt29f25TKpet88MEHVbt2bW3evFnNmjVTkSJFbPfln2tYIiMj5enpmev6w8PDVaJECf311195vlZHrF27VhaLRf/973/1wgsvqEyZMipSpIhSU1PzdE+kK/+8+/Tpo2LFiunw4cOKiIhQsWLFVKpUKY0YMULZ2dnXHNP+/ftlGIaaNGmSa5/FYpG/v79d2+nTpzV06FCVK1dOVqtVVatW1aRJk5STk2MbX6lSpSRJ48ePt/3cYmJibuCOAbc+KizIky+++EKVK1fW/fffn6f+/fv314IFC9S1a1cNHz5cP/zwg2JjY7Vz504tXbrUru/evXvVtWtX9evXT5GRkXr33XfVp08fNWjQQLVq1VLnzp3l6+urYcOGqUePHmrbtq2KFSuWr/Hv2LFD7du3V926dTVhwgRZrVbt3btX33///TW/980336hNmz
aqXLmyYmJidO7cOb3xxhtq0qSJtmzZkissPfbYY6pUqZJiY2O1ZcsWvf322/L399ekSZPyNM7OnTtr4MCB+vTTT/XEE09IulhdqVGjhu69995c/X///XctW7ZMjz76qCpVqqSUlBS9+eabat68uX799VcFBQWpZs2amjBhgsaOHasnn3xSTZs2lSS7n+WJEyfUpk0bde/eXf/617+uOt0yffp0rVmzRpGRkUpMTJS7u7vefPNNff3113r//fcVFBSUp+t0lokTJ8rDw0MjRoxQRkaGPDw89Ouvv173nlxLdna2wsPDFRISoilTpuibb77Ra6+9pipVqmjQoEFX/d6lUL1kyRI9+uijKlKkyFX7nj17Vs2bN9fhw4f11FNPqXz58tqwYYPGjBmjI0eOaNq0aSpVqpTmzJmTazrs8uoTcEcxgOs4c+aMIcno2LFjnvonJSUZkoz+/fvbtY8YMcKQZKxZs8bWVqFCBUOSkZCQYGs7evSoYbVajeHDh9va9u/fb0gyXn31VbtjRkZGGhUqVMg1hnHjxhmX//GeOnWqIck4duzYVcd96Rzz58+3tdWvX9/w9/c3Tpw4YWv75ZdfDDc3N6N37965zvfEE0/YHbNTp07GXXfdddVzXn4dRYsWNQzDMLp27Wq0bNnSMAzDyM7ONgICAozx48df8R6cP3/eyM7OznUdVqvVmDBhgq1t06ZNua7tkubNmxuSjLlz515xX/Pmze3aVq1aZUgyXnzxReP33383ihUrZkRERFz3Gm/UsWPHDEnGuHHjbG3ffvutIcmoXLmycfbsWbv+eb0nV/p5R0ZGGpLs+hmGYdxzzz1GgwYNrjvW3r17G5KMEiVKGJ06dTKmTJli7Ny5M1e/iRMnGkWLFjV+++03u/Znn33WcHd3Nw4ePHjVawfuVEwJ4bpSU1MlScWLF89T/y+//FKSFB0dbdc+fPhwScq11iU4ONj2r35JKlWqlKpXr67ff//9hsf8T5fWvnz22We2kvv1HDlyRElJSerTp49Klixpa69bt64efvhh23VebuDAgXafmzZtqhMnTtjuYV707NlTa9euVXJystasWaPk5OQrTgdJF9e9XFqzkZ2drRMnTtimu7Zs2ZLnc1qtVvXt2zdPfVu1aqWnnnpKEyZMUOfOneXp6ak333wzz+dypsjISHl5edm1OeOeXOnnmJc/j/Pnz9fMmTNVqVIlLV26VCNGjFDNmjXVsmVLHT582NZvyZIlatq0qUqUKKHjx4/btrCwMGVnZyshISFP4wTuJAQWXJe3t7ck6e+//85T/z/++ENubm52T3pIUkBAgHx9ffXHH3/YtZcvXz7XMUqUKKFTp07d4Ihz69atm5o0aaL+/furdOnS6t69uz766KNrhpdL46xevXqufTVr1tTx48eVnp5u1/7PaylRooQk5eta2rZtq+LFi2vx4sVauHChGjVqlOteXpKTk6OpU6fq7rvvltVqlZ+fn0qVKqWtW7fqzJkzeT5nmTJl8rXAdsqUKSpZsqSSkpI0Y8aMXOszruTYsWNKTk62bWlpaXk+39VUqlQpV5uj98TT09O2duSSvP55dHNzU1RUlDZv3qzjx4/rs88+U5s2bbRmzRp1797d1m/Pnj1auXKlSpUqZbeFhYVJko4ePXrdcwF3GgILrsvb21tBQUHavn17vr73z0WvV+Pu7n7FdsMwbvgc/1wg6eXlpYSEBH3zzTd6/PHHtXXrVnXr1k0PP/zwdRdT5ocj13KJ1WpV586dtWDBAi1duvSq1RVJevnllxUdHa1mzZrpgw8+0KpVqxQfH69atWrluZIkKVeV4np+/vln2/+pbtu2LU/fadSokQIDA23bjbxP5p+uNG5H78nVfob5ddddd+mRRx7Rl19+qebNm2v9+vW2EJyTk6OHH35Y8fHxV9y6dOnilDEAtxMW3SJP2rdvr3nz5ikxMVGhoaHX7FuhQgXl5ORoz549qlmzpq09JSVFp0+fti1OdI
YSJUrYPVFzyT+rONLFf/22bNlSLVu21Ouvv66XX35Zzz//vL799lvbv2z/eR2StHv37lz7du3aJT8/PxUtWtTxi7i
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 2 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Evaluate the model on training set\n",
|
|
|
|
|
"y_pred_train = best_xgb_model.predict(X_train)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Confusion Matrix for Training Data\n",
|
|
|
|
|
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
|
|
|
|
|
"ax.set_xlabel('Predicted')\n",
|
|
|
|
|
"ax.set_ylabel('Actual')\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"ax.set_title('Confusion Matrix - Train Set')\n",
|
|
|
|
|
"plt.show()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Test set"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 14,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAHHCAYAAACPy0PBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABJt0lEQVR4nO3deVxU1f/H8feAMiAKiMiWG2Yu5FJiKbknSYbmmlmWuFUamuJOm6YlZYtpbu1YaWn11Uxz+2pqJW4UZW65FZqCW4grKNzfH/6cryPqgM11FF/PHvfxaO49c++54/bmc865YzEMwxAAAIALubm6AwAAAAQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQSAADgcgQS3LC2b9+uli1bytfXVxaLRXPnznXq+f/8809ZLBYlJSU59bw3smbNmqlZs2au7gaAIohAgn9l586deuqpp1S5cmV5enrKx8dHDRs21IQJE3Tq1ClTrx0bG6uNGzfqlVde0aeffqp69eqZer1rqXv37rJYLPLx8bnk57h9+3ZZLBZZLBa98cYbhT7/vn37NGrUKKWmpjqht+YaNWqU7V6vtDkrKH333XcaNWpUgdvn5eXpk08+Uf369eXv769SpUqpatWq6tatm9asWVPo6588eVKjRo3SihUrCv1e4EZWzNUdwI1rwYIFeuihh2S1WtWtWzfVrFlTOTk5+vHHHzV06FBt2rRJ7733ninXPnXqlJKTk/Xcc8+pX79+plyjYsWKOnXqlIoXL27K+R0pVqyYTp48qW+//VadO3e2OzZjxgx5enrq9OnTV3Xuffv26aWXXlKlSpV0xx13FPh9S5Ysuarr/RsdOnRQlSpVbK+PHz+uvn37qn379urQoYNtf1BQkFOu991332ny5MkFDiXPPPOMJk+erLZt26pr164qVqyYtm3bpoULF6py5cpq0KBBoa5/8uRJvfTSS5JENQo3FQIJrsru3bvVpUsXVaxYUcuXL1dISIjtWFxcnHbs2KEFCxaYdv2DBw9Kkvz8/Ey7hsVikaenp2nnd8Rqtaphw4b6/PPP8wWSmTNnKiYmRl9//fU16cvJkydVokQJeXh4XJPrXah27dqqXbu27fWhQ4fUt29f1a5dW4899tg178+FMjIyNGXKFD3xxBP5wvfbb79t+30KwDGGbHBVxo0bp+PHj+vDDz+0CyPnValSRQMGDLC9Pnv2rMaMGaNbb71VVqtVlSpV0rPPPqvs7Gy791WqVEmtW7fWjz/+qLvvvluenp6qXLmyPvnkE1ubUaNGqWLFipKkoUOHymKxqFKlSpLODXWc//8LnS/7X2jp0qVq1KiR/Pz8VLJkSVWrVk3PPvus7fjl5pAsX75cjRs3lre3t/z8/NS2bVtt2bLlktfbsWOHunfvLj8/P/n6+qpHjx46efLk5T/Yizz66KNauHChMjMzbfvWr1+v7du369FHH83X/siRIxoyZIhq1aqlkiVLysfHR61atdKvv/5qa7NixQrdddddkqQePXrYhjzO32ezZs1Us2ZNpaSkqEmTJipRooTtc7l4DklsbKw8PT3z3X90dLRKly6tffv2Ffhe/62tW7eqU6dO8vf3l6enp+rVq6d58+bZtTlz5oxeeukl3XbbbfL09FSZMmXUqFEjLV26VNK53z+TJ0+WJLvhoMvZvXu3DMNQw4YN8x2zWCwKDAy025eZmamBAweqfPnyslqtqlKlil577TXl5eVJOvd7rmzZspKkl156yXb9wgwhATcqKiS4Kt9++60qV66se+65p0Dte/furenTp6tTp04aPHiw1q5dq8TERG3ZskVz5syxa7tjxw516tRJvXr1UmxsrD766CN1795dERERuv3229WhQwf5+fkpPj5ejzzyiB544AGVLFmyUP3ftGmTWrdurdq1a2v06NGyWq3asWOHfvrppyu+77///a9atWqlypUra9SoUT
p16pTeeecdNWzYUD///HO+MNS5c2eFhYUpMTFRP//8sz744AMFBgbqtddeK1A/O3TooD59+ug///mPevbsKelcdaR69eqqW7duvva7du3S3Llz9dBDDyksLEwZGRl699131bRpU23evFmhoaGqUaOGRo8erRdffFFPPvmkGjduLEl2v5aHDx9Wq1at1KVLFz322GOXHQ6ZMGGCli9frtjYWCUnJ8vd3V3vvvuulixZok8//VShoaEFus9/a9OmTWrYsKFuueUWjRgxQt7e3po9e7batWunr7/+Wu3bt5d0LigmJiaqd+/euvvuu5WVlaUNGzbo559/1n333aennnpK+/bt09KlS/Xpp586vO75YPzll1/qoYceUokSJS7b9uTJk2ratKn+/vtvPfXUU6pQoYJWr16thIQE7d+/X2+//bbKli2rqVOn5huSurBCBBRZBlBIR48eNSQZbdu2LVD71NRUQ5LRu3dvu/1DhgwxJBnLly+37atYsaIhyVi1apVt34EDBwyr1WoMHjzYtm/37t2GJOP111+3O2dsbKxRsWLFfH0YOXKkceFv9/HjxxuSjIMHD1623+ev8fHHH9v23XHHHUZgYKBx+PBh275ff/3VcHNzM7p165bvej179rQ7Z/v27Y0yZcpc9poX3oe3t7dhGIbRqVMno0WLFoZhGEZubq4RHBxsvPTSS5f8DE6fPm3k5ubmuw+r1WqMHj3atm/9+vX57u28pk2bGpKMadOmXfJY06ZN7fYtXrzYkGS8/PLLxq5du4ySJUsa7dq1c3iPV+vgwYOGJGPkyJG2fS1atDBq1aplnD592rYvLy/PuOeee4zbbrvNtq9OnTpGTEzMFc8fFxdnFOavxm7duhmSjNKlSxvt27c33njjDWPLli352o0ZM8bw9vY2/vjjD7v9I0aMMNzd3Y20tLTL3h9wM2DIBoWWlZUlSSpVqlSB2n/33XeSpEGDBtntHzx4sCTlm2sSHh5u+6ldksqWLatq1app165dV93ni52fe/LNN9/YyuWO7N+/X6mpqerevbv8/f1t+2vXrq377rvPdp8X6tOnj93rxo0b6/Dhw7bPsCAeffRRrVixQunp6Vq+fLnS09MvOVwjnZt34uZ27o91bm6uDh8+bBuO+vnnnwt8TavVqh49ehSobcuWLfXUU09p9OjR6tChgzw9PfXuu+8W+Fr/1pEjR7R8+XJ17txZx44d06FDh3To0CEdPnxY0dHR2r59u/7++29J537dN23apO3btzvt+h9//LEmTZqksLAwzZkzR0OGDFGNGjXUokUL23Wlc1WUxo0bq3Tp0rY+Hjp0SFFRUcrNzdWqVauc1ifgRkQgQaH5+PhIko4dO1ag9n/99Zfc3NzsVkpIUnBwsPz8/PTXX3/Z7a9QoUK+c5QuXVr//PPPVfY4v4cfflgNGzZU7969FRQUpC5dumj27NlXDCfn+1mtWrV8x2rUqKFDhw7pxIkTdvsvvpfSpUtLUqHu5YEHHlCpUqU0a9YszZgxQ3fddVe+z/K8vLw8jR8/XrfddpusVqsCAgJUtmxZ/fbbbzp69GiBr3nLLbcUagLrG2+8IX9/f6WmpmrixIn55k5cysGDB5Wenm7bjh8/XuDrXWjHjh0yDEMvvPCCypYta7eNHDlSknTgwAFJ0ujRo5WZmamqVauqVq1aGjp0qH777beruu55bm5uiouLU0pKig4dOqRvvvlGrVq10vLly9WlSxdbu+3bt2vRokX5+hgVFWXXR+BmxRwSFJqPj49CQ0P1+++/F+p9V5oceCF3d/dL7jcM46qvkZuba/fay8tLq1at0vfff68FCxZo0aJFmjVrlu69914tWbLksn0orH9zL+dZrVZ16NBB06dP165du644wXHs2LF64YUX1LNnT40ZM0b+/v5yc3PTwIEDC1wJks59PoXxyy+/2P5B3bhxox555BGH77nrrrvswujIkSOvavLm+fsaMmSIoqOjL9nmfIBr0q
SJdu7cqW+++UZLlizRBx98oPHjx2vatGnq3bt3oa99sTJlyujBBx/Ugw8+qGbNmmnlypX666+/VLFiReXl5em+++7
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 2 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Evaluate the model on test set\n",
|
|
|
|
|
"y_pred_test = best_xgb_model.predict(X_test)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
|
|
|
|
"# Confusion Matrix for Testing Data\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
|
|
|
|
|
"ax.set_xlabel('Predicted')\n",
|
|
|
|
|
"ax.set_ylabel('Actual')\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"ax.set_title('Confusion Matrix - Test Set')\n",
|
|
|
|
|
"plt.show()"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 15,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"The test dataset contains 7669 shots, with 914 of them being goals.\n"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Number of goals in test set\n",
|
|
|
|
|
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"cell_type": "markdown",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Feature importance"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 16,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAHHCAYAAAAs6rBrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gUx//A8ffRuwIBBUVERcAG2MWCHUSxRbEr1p+9xV4oVuw9xo5GjSWxJSqKCsaKiMYuVqKxxC4CSrv5/eHDfr0cIKaJcV7Pw6M3Mzs7+9mDu9mZnVUJIQSSJEmSJEmSJElSvqTzsRsgSZIkSZIkSZIk5Ux22iRJkiRJkiRJkvIx2WmTJEmSJEmSJEnKx2SnTZIkSZIkSZIkKR+TnTZJkiRJkiRJkqR8THbaJEmSJEmSJEmS8jHZaZMkSZIkSZIkScrHZKdNkiRJkiRJkiQpH5OdNkmSJEmSJEmSpHxMdtokSZIk6RMSHh6OSqUiISHhYzdFkiRJ+pfITpskSZKUr2V1UrL7GTNmzD+yz+PHjxMSEsKLFy/+kfo/ZykpKYSEhBAdHf2xmyJJkvTJ0PvYDZAkSZKkvJg0aRJOTk4aaeXKlftH9nX8+HFCQ0MJDAykYMGC/8g+/qwuXbrQvn17DA0NP3ZT/pSUlBRCQ0MBqFu37sdtjCRJ0idCdtokSZKkT0KTJk2oXLnyx27GX5KcnIypqelfqkNXVxddXd2/qUX/HrVaTVpa2sduhiRJ0idJTo+UJEmS/hP27t1L7dq1MTU1xdzcnKZNm3Lp0iWNMufPnycwMJASJUpgZGRE4cKF6dGjB0+fPlXKhISEMHLkSACcnJyUqZgJCQkkJCSgUqkIDw/X2r9KpSIkJESjHpVKxeXLl+nYsSOWlpbUqlVLyV+/fj2VKlXC2NgYKysr2rdvz927d997nNnd01a8eHGaNWtGdHQ0lStXxtjYmPLlyytTELdt20b58uUxMjKiUqVKnD17VqPOwMBAzMzMuHXrFj4+PpiammJvb8+kSZMQQmiUTU5O5quvvsLBwQFDQ0NcXFyYPXu2VjmVSsXAgQPZsGEDZcuWxdDQkG+++QYbGxsAQkNDldhmxS0v5+fd2N64cUMZDS1QoADdu3cnJSVFK2br16+natWqmJiYYGlpSZ06ddi/f79Gmby8fyRJkj4WOdImSZIkfRJevnzJkydPNNK++OILAL799lu6deuGj48PM2bMICUlhaVLl1KrVi3Onj1L8eLFAYiMjOTWrVt0796dwoULc+nSJZYvX86lS5c4efIkKpWK1q1bc+3aNb777jvmzZun7MPGxobHjx9/cLvbtm2Ls7Mz06ZNUzo2U6dOZeLEiQQEBNCrVy8eP37MokWLqFOnDmfPnv1TUzJv3LhBx44d+b//+z86d+7M7Nmz8ff355tvvmHcuHH0798fgOnTpxMQEEB8fDw6Ov+7dpuZmYmvry/Vq1dn5syZREREEBwcTEZGBpMmTQJACEHz5s2JioqiZ8+eeHh4sG/fPkaOHMm9e/eYN2+eRpsOHTrEli1bGDhwIF988QXu7u4sXbqUfv360apVK1q3bg1AhQoVgLydn3cFBATg5OTE9OnTOXPmDCtXrsTW1pYZM2YoZUJDQwkJCcHLy4tJkyZhYGBATEwMhw4donHjxkDe3z+SJEkfjZAkSZKkfGzNmjUCyPZHCCFevXolChYsKHr37q2x3cOHD0WBAgU00lNSUrTq/+677wQgfv75ZyVt1qxZAhC3b9/WKHv79m0BiDVr1mjVA4jg4GDldXBwsABEhw4dNMolJCQIXV1dMXXqVI30CxcuCD09Pa30nOLxbtscHR0FII4fP66k7du3TwDC2NhY/Prrr0r6smXLBCCioqKUtG7duglADBo0SElTq9WiadOmwsDAQDx+/FgIIcSOHTsEIKZMmaLRpjZt2giVSiVu3LihEQ8dHR1x6dIljbKPHz/WilWWvJ6frN
j26NFDo2yrVq2EtbW18vr69etCR0dHtGrVSmRmZmqUVavVQogPe/9IkiR9LHJ6pCRJkvRJWLJkCZGRkRo/8HZ05sWLF3To0IEnT54oP7q6ulSrVo2oqCilDmNjY+X/b9684cmTJ1SvXh2AM2fO/CPt7tu3r8brbdu2oVarCQgI0Ghv4cKFcXZ21mjvhyhTpgw1atRQXlerVg2A+vXrU6xYMa30W7duadUxcOBA5f9Z0xvT0tI4cOAAAHv27EFXV5fBgwdrbPfVV18hhGDv3r0a6d7e3pQpUybPx/Ch5+ePsa1duzZPnz4lMTERgB07dqBWqwkKCtIYVcw6Pviw948kSdLHIqdHSpIkSZ+EqlWrZrsQyfXr14G3nZPsWFhYKP9/9uwZoaGhbNq0iUePHmmUe/ny5d/Y2v/544qX169fRwiBs7NztuX19fX/1H7e7ZgBFChQAAAHB4ds058/f66RrqOjQ4kSJTTSSpcuDaDcP/frr79ib2+Pubm5Rjk3Nzcl/11/PPb3+dDz88djtrS0BN4em4WFBTdv3kRHRyfXjuOHvH8kSZI+FtlpkyRJkj5parUaeHtfUuHChbXy9fT+91EXEBDA8ePHGTlyJB4eHpiZmaFWq/H19VXqyc0f76nKkpmZmeM2744eZbVXpVKxd+/ebFeBNDMze287spPTipI5pYs/LBzyT/jjsb/Ph56fv+PYPuT9I0mS9LHIv0SSJEnSJ61kyZIA2Nra0rBhwxzLPX/+nIMHDxIaGkpQUJCSnjXS8q6cOmdZIzl/fOj2H0eY3tdeIQROTk7KSFZ+oFaruXXrlkabrl27BqAsxOHo6MiBAwd49eqVxmjb1atXlfz3ySm2H3J+8qpkyZKo1WouX76Mh4dHjmXg/e8fSZKkj0ne0yZJkiR90nx8fLCwsGDatGmkp6dr5Wet+Jg1KvPHUZj58+drbZP1LLU/ds4sLCz44osv+PnnnzXSv/766zy3t3Xr1ujq6hIaGqrVFiGE1vL2/6bFixdrtGXx4sXo6+vToEEDAPz8/MjMzNQoBzBv3jxUKhVNmjR57z5MTEwA7dh+yPnJq5YtW6Kjo8OkSZO0Ruqy9pPX948kSdLHJEfaJEmSpE+ahYUFS5cupUuXLlSsWJH27dtjY2PDnTt32L17NzVr1mTx4sVYWFhQp04dZs6cSXp6OkWKFGH//v3cvn1bq85KlSoBMH78eNq3b4++vj7+/v6YmprSq1cvwsLC6NWrF5UrV+bnn39WRqTyomTJkkyZMoWxY8eSkJBAy5YtMTc35/bt22zfvp0+ffowYsSIvy0+eWVkZERERATdunWjWrVq7N27l927dzNu3Djl2Wr+/v7Uq1eP8ePHk5CQgLu7O/v372fnzp0MHTpUGbXKjbGxMWXKlGHz5s2ULl0aKysrypUrR7ly5fJ8fvKqVKlSjB8/nsmTJ1O7dm1at26NoaEhsbGx2NvbM3369Dy/fyRJkj4m2WmTJEmSPnkdO3bE3t6esLAwZs2aRWpqKkWKFKF27dp0795dKbdx40YGDRrEkiVLEELQuHFj9u7di729vUZ9VapUYfLkyXzzzTdERESgVqu5ffs2pqamBAUF8fjxY77//nu2bNlCkyZN2Lt3L7a2tnlu75gxYyhdujTz5s0jNDQUeLtgSOPGjWnevPnfE5QPpKurS0REBP369WPkyJGYm5sTHBysMVVRR0eHXbt2ERQUxObNm1mzZg3Fixdn1qxZfPXVV3ne18qVKxk0aBDDhg0jLS2N4OBgypUrl+fz8yEmTZqEk5MTixYtYvz48ZiYmFChQgW6dOmilMnr+0eSJOljUYl/405kSZIkSZLyrcDAQL7//nuSkpI+dlMkSZKkbMh72iRJkiRJkiRJkvIx2WmTJEmSJEmSJEnKx2SnTZIkSZIkSZIkKR+T97RJkiRJkiRJkiTlY3KkTZIkSZIkSZIkKR+TnTZJkiRJkiRJkqR8TD6nTZI+c2q1mvv372Nubo
5KpfrYzZEkSZIkKQ+EELx69Qp7e3t0dOQ4zH+d7LRJ0mfu/v37ODg4fOxmSJIkSZL0J9y9e5eiRYt+7GZI/zDZaZO
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Plot feature importance with Gain\n",
|
|
|
|
|
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"plt.show()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 17,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAHHCAYAAAAMBu+WAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxO6f/48dfdvq+KMlFIhSQ72UXJGHt2MmGMQox1bGUrex/LYCxlHbPZZqwhWwgZxpCsyaDJkki0nt8f/TpftxYh6309H4/7wX3Oda5zvc9d3ec65zrXWyFJkoQgCIIgCIIgCCpF7UM3QBAEQRAEQRCE9090BARBEARBEARBBYmOgCAIgiAIgiCoINEREARBEARBEAQVJDoCgiAIgiAIgqCCREdAEARBEARBEFSQ6AgIgiAIgiAIggoSHQFBEARBEARBUEGiIyAIgiAIgiAIKkh0BARBEAThExIeHo5CoSA+Pv5DN0UQhE+c6AgIgiAIH7W8E9+CXuPGjXsn+zx27BiBgYE8evTondSvytLS0ggMDOTgwYMfuimCoPI0PnQDBEEQBKE4pk6dip2dndKyatWqvZN9HTt2jKCgIHx8fDAxMXkn+3hTffr0oXv37mhra3/opryRtLQ0goKCAGjWrNmHbYwgqDjRERAEQRA+CW3atKF27dofuhlv5enTp+jr679VHerq6qirq5dQi96fnJwcMjIyPnQzBEF4gRgaJAiCIHwWdu3aRePGjdHX18fQ0JC2bdty4cIFpTJ///03Pj4+VKhQAR0dHcqUKcPXX3/NgwcP5DKBgYGMHj0aADs7O3kYUnx8PPHx8SgUCsLDw/PtX6FQEBgYqFSPQqHg4sWL9OzZE1NTUxo1aiSvX79+PbVq1UJXVxczMzO6d+/OrVu3XhlnQc8I2Nra8uWXX3Lw4EFq166Nrq4uzs7O8vCbzZs34+zsjI6ODrVq1eKvv/5SqtPHxwcDAwOuX7+Oh4cH+vr6WFtbM3XqVCRJUir79OlTvvvuO2xsbNDW1sbBwYG5c+fmK6dQKPD392fDhg1UrVoVbW1tli1bhoWFBQBBQUHysc07bsX5fF48tlevXpXv2hgbG9O/f3/S0tLyHbP169dTt25d9PT0MDU1pUmTJuzdu1epTHF+fgThcyPuCAiCIAifhJSUFO7fv6+0rFSpUgCsW7eOfv364eHhwaxZs0hLS2Pp0qU0atSIv/76C1tbWwAiIiK4fv06/fv3p0yZMly4cIEff/yRCxcucOLECRQKBZ06deLy5cv89NNPLFiwQN6HhYUF9+7de+12d+3aFXt7e2bOnCmfLM+YMYNJkybh7e3NgAEDuHfvHosWLaJJkyb89ddfbzQc6erVq/Ts2ZNvvvmG3r17M3fuXNq1a8eyZcv4/vvvGTJkCADBwcF4e3sTFxeHmtr/XQ/Mzs7G09OT+vXrM3v2bHbv3s2UKVPIyspi6tSpAEiSxFdffUVkZCS+vr7UqFGDPXv2MHr0aG7fvs2CBQuU2nTgwAF++eUX/P39KVWqFC4uLixdupRvv/2Wjh070qlTJwCqV68OFO/zeZG3tzd2dnYEBwdz5swZVq5ciaWlJbNmzZLLBAUFERgYSMOGDZk6dSpaWlpER0dz4MABWrduDRT/50cQPjuSIAiCIHzEwsLCJKDAlyRJ0pMnTyQTExNp4MCBStslJiZKxsbGSsvT0tLy1f/TTz9JgHT48GF52Zw5cyRAunHjhlLZGzduSIAUFhaWrx5AmjJlivx+ypQpEiD16NFDqVx8fLykrq4uzZgxQ2n5+fPnJQ0NjXzLCzseL7atfPnyEiAdO3ZMXrZnzx4JkHR1daWbN2/Ky5cvXy4BUmRkpLysX79+EiANHTpUXpaTkyO1bdtW0tLSku7duydJkiRt3bpVAqTp06crtalLly6SQqGQrl69qnQ81NTUpAsXLiiVvXfvXr5jlae4n0/esf3666+Vynbs2FEyNz
eX31+5ckVSU1OTOnbsKGVnZyuVzcnJkSTp9X5+BOFzI4YGCYIgCJ+EJUuWEBERofSC3KvIjx49okePHty/f19+qaurU69ePSIjI+U6dHV15f8/f/6c+/fvU79+fQDOnDnzTto9ePBgpfebN28mJycHb29vpfaWKVMGe3t7pfa+jipVqtCgQQP5fb169QBo0aIF5cqVy7f8+vXr+erw9/eX/583tCcjI4N9+/YBsHPnTtTV1Rk2bJjSdt999x2SJLFr1y6l5U2bNqVKlSrFjuF1P5+Xj23jxo158OABjx8/BmDr1q3k5OQwefJkpbsfefHB6/38CMLnRgwNEgRBED4JdevWLfBh4StXrgC5J7wFMTIykv//8OFDgoKC2LRpE0lJSUrlUlJSSrC1/+flmY6uXLmCJEnY29sXWF5TU/ON9vPiyT6AsbExADY2NgUuT05OVlqupqZGhQoVlJZVrlwZQH4e4ebNm1hbW2NoaKhUzsnJSV7/opdjf5XX/XxejtnU1BTIjc3IyIhr166hpqZWZGfkdX5+BOFzIzoCgiAIwictJycHyB3nXaZMmXzrNTT+76vO29ubY8eOMXr0aGrUqIGBgQE5OTl4enrK9RTl5THqebKzswvd5sWr3HntVSgU7Nq1q8DZfwwMDF7ZjoIUNpNQYcullx7ufRdejv1VXvfzKYnYXufnRxA+N+KnWxAEQfikVaxYEQBLS0vc3d0LLZecnMz+/fsJCgpi8uTJ8vK8K8IvKuyEP++K88uJxl6+Ev6q9kqShJ2dnXzF/WOQk5PD9evXldp0+fJlAPlh2fLly7Nv3z6ePHmidFfg0qVL8vpXKezYvs7nU1wVK1YkJyeHixcvUqNGjULLwKt/fgThcySeERAEQRA+aR4eHhgZGTFz5kwyMzPzrc+b6Sfv6vHLV4tDQ0PzbZM31//LJ/xGRkaUKlWKw4cPKy3/4Ycfit3eTp06oa6uTlBQUL62SJKUb6rM92nx4sVKbVm8eDGampq0bNkSAC8vL7Kzs5XKASxYsACFQkGbNm1euQ89PT0g/7F9nc+nuDp06ICamhpTp07Nd0chbz/F/fkRhM+RuCMgCIIgfNKMjIxYunQpffr0oWbNmnTv3h0LCwsSEhLYsWMHbm5uLF68GCMjI5o0acLs2bPJzMykbNmy7N27lxs3buSrs1atWgBMmDCB7t27o6mpSbt27dDX12fAgAGEhIQwYMAAateuzeHDh+Ur58VRsWJFpk+fzvjx44mPj6dDhw4YGhpy48YNtmzZwqBBgxg1alSJHZ/i0tHRYffu3fTr14969eqxa9cuduzYwffffy/P/d+uXTuaN2/OhAkTiI+Px8XFhb1797Jt2zYCAgLkq+tF0dXVpUqVKvz8889UrlwZMzMzqlWrRrVq1Yr9+RRXpUqVmDBhAtOmTaNx48Z06tQJbW1tTp06hbW1NcHBwcX++RGEz9IHmq1IEARBEIolb7rMU6dOFVkuMjJS8vDwkIyNjSUdHR2pYsWKko+Pj3T69Gm5zL///it17NhRMjExkYyNjaWuXbtKd+7cKXA6y2nTpklly5aV1NTUlKbrTEtLk3x9fSVjY2PJ0NBQ8vb2lpKSkgqdPjRv6s2X/f7771KjRo0kfX19SV9fX3J0dJT8/PykuLi4Yh2Pl6cPbdu2bb6ygOTn56e0LG8K1Dlz5sjL+vXrJ+nr60vXrl2TWrduLenp6UmlS5eWpkyZkm/azSdPnkgjRoyQrK2tJU1NTcne3l6aM2eOPB1nUfvOc+zYMalWrVqSlpaW0nEr7udT2LEt6NhIkiStXr1acnV1lbS1tSVTU1OpadOmUkREhFKZ4vz8CMLnRiFJ7+FpIUEQBEEQPlo+Pj789ttvpKamfuimCILwHolnBARBEARBEARBBYmOgCAIgiAIgiCoINEREARBEARBEAQVJJ4REARBEARBEAQVJO4ICIIgCIIgCIIKEh0BQRAEQRAEQVBBIqGYIKi4nJwc7t
y5g6GhIQqF4kM3RxAEQRCEYpAkiSdPnmBtbY2a2ptd2xcdAUFQcXfu3MHGxuZDN0MQBEEQhDdw69YtvvjiizfaVnQ
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Plot feature importance with Weight\n",
|
|
|
|
|
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"plt.show()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"## Summary"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 18,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Calculating MAE, RMSE and R2 for training and test sets \n",
|
|
|
|
|
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
|
|
|
|
|
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
|
|
|
|
|
"r2_train = r2_score(y_train, y_pred_train)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
|
|
|
|
|
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
|
|
|
|
|
"r2_test = r2_score(y_test, y_pred_test)"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 19,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<style type=\"text/css\">\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"#T_e5682_row0_col0, #T_e5682_row0_col1, #T_e5682_row0_col2, #T_e5682_row0_col3, #T_e5682_row0_col4, #T_e5682_row0_col5, #T_e5682_row0_col6 {\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" font-weight: bold;\n",
|
|
|
|
|
" border: 2.0px solid grey;\n",
|
|
|
|
|
" color: white;\n",
|
|
|
|
|
"}\n",
|
|
|
|
|
"</style>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"<table id=\"T_e5682\">\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th class=\"blank level0\" > </th>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th id=\"T_e5682_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
|
|
|
|
|
" <th id=\"T_e5682_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th class=\"index_name level0\" >Model Name</th>\n",
|
|
|
|
|
" <th class=\"blank col0\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col1\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col2\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col3\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col4\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col5\" > </th>\n",
|
|
|
|
|
" <th class=\"blank col6\" > </th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" <th id=\"T_e5682_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col0\" class=\"data row0 col0\" >0.09934</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col1\" class=\"data row0 col1\" >0.31518</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col2\" class=\"data row0 col2\" >0.03828</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col3\" class=\"data row0 col3\" >0.10366</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col4\" class=\"data row0 col4\" >0.32197</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col5\" class=\"data row0 col5\" >0.01251</td>\n",
|
|
|
|
|
" <td id=\"T_e5682_row0_col6\" class=\"data row0 col6\" >31.93709</td>\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"<pandas.io.formats.style.Styler at 0x20007e22020>"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 19,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Creating of dataframe of summary results\n",
|
|
|
|
|
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
" 'Training MAE': mae_train, \n",
|
|
|
|
|
" 'Training RMSE': rmse_train,\n",
|
|
|
|
|
" 'Training R2':r2_train,\n",
|
|
|
|
|
" 'Testing MAE': mae_test, \n",
|
|
|
|
|
" 'Testing RMSE': rmse_test,\n",
|
|
|
|
|
" 'Testing R2':r2_test,\n",
|
|
|
|
|
" 'Training Time (mins)': xgb_training_time/60})\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"summary_df.set_index('Model Name', inplace=True)\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"# Displaying summary of results\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
" 'border': '2.0px solid grey','color': 'white'})"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Keeping the xgboost model"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"execution_count": 20,
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Save the model\n",
|
|
|
|
|
"dump(best_xgb_model, 'xgboost.joblib') \n",
|
2023-12-12 15:22:01 +01:00
|
|
|
|
"\n",
|
2023-12-28 23:45:45 +01:00
|
|
|
|
"# Load the model\n",
|
|
|
|
|
"model = load('xgboost.joblib')"
|
2023-12-12 15:22:01 +01:00
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.10.11"
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 2
|
|
|
|
|
}
|