fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

1463 lines
389 KiB
Plaintext
Raw Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 1,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
2024-01-12 20:39:14 +01:00
"from joblib import dump, load\n",
"import os"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 2,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
2024-01-20 09:26:51 +01:00
"df = pd.read_csv('final_data_new.csv')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 3,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
2024-01-12 20:39:14 +01:00
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
" 'y_player_teammate_Goalkeeper'],\n",
" dtype='object')"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
2024-01-12 20:39:14 +01:00
" <th>shot_follows_dribble</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
2024-01-12 20:39:14 +01:00
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2024-01-20 09:26:51 +01:00
" <td>2</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Right Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.6</td>\n",
" <td>53.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Left Back</td>\n",
" <td>Left Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>20.6</td>\n",
" <td>32.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>23.8</td>\n",
" <td>31.2</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
2024-01-20 09:26:51 +01:00
" <td>15</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.6</td>\n",
" <td>55.3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
2024-01-20 09:26:51 +01:00
" <td>16</td>\n",
" <td>Center Forward</td>\n",
" <td>Head</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.7</td>\n",
" <td>60.4</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
2024-01-20 09:26:51 +01:00
" <td>18</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>27.9</td>\n",
" <td>31.4</td>\n",
" <td>33.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>16.9</td>\n",
" <td>40.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-12 20:39:14 +01:00
"<p>5 rows × 64 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
2024-01-20 09:26:51 +01:00
"0 2 Left Center Forward Right Foot Half Volley \n",
"1 5 Left Back Left Foot Volley \n",
"2 15 Left Center Forward Left Foot Normal \n",
"3 16 Center Forward Head Normal \n",
"4 18 Right Center Forward Right Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
2024-01-20 09:26:51 +01:00
"0 Open Play True False False \n",
"1 Open Play True False False \n",
"2 Open Play False False False \n",
"3 Open Play False False True \n",
"4 Open Play False False False \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"1 False False ... 20.6 \n",
"2 False False ... 29.0 \n",
2024-01-12 20:39:14 +01:00
"3 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"4 False False ... 27.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"1 32.8 NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"4 31.4 33.4 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
2024-01-20 09:26:51 +01:00
"0 NaN NaN 26.6 \n",
"1 NaN NaN 23.8 \n",
"2 NaN NaN 29.6 \n",
"3 NaN NaN 26.7 \n",
"4 NaN NaN 16.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
2024-01-20 09:26:51 +01:00
"0 53.1 NaN \n",
"1 31.2 NaN \n",
"2 55.3 NaN \n",
"3 60.4 NaN \n",
"4 40.1 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"\n",
"[5 rows x 64 columns]"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 43,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']] = df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']].astype(float)"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Replace the string-valued categorical columns with ordinal codes.\n",
"# The column order is kept fixed because the fitted encoder is saved and reused later.\n",
"encoded_cols = ['position_name', 'shot_technique_name', 'shot_type_name', 'shot_body_part_name']\n",
"\n",
"enc = OrdinalEncoder()\n",
"df[encoded_cols] = enc.fit_transform(df[encoded_cols])"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Cast the minute and the four encoded categorical columns to integer dtype\n",
"int_cols = ['minute', 'position_name', 'shot_technique_name', 'shot_type_name', 'shot_body_part_name']\n",
"df[int_cols] = df[int_cols].astype(int)"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
" <th>shot_follows_dribble</th>\n",
" <th>...</th>\n",
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2024-01-20 17:18:19 +01:00
" <td>2</td>\n",
" <td>9</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>2</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>26.6</td>\n",
" <td>53.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
2024-01-20 17:18:19 +01:00
" <td>7</td>\n",
" <td>1</td>\n",
2024-01-20 17:18:19 +01:00
" <td>6</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
2024-01-20 17:18:19 +01:00
" <td>20.6</td>\n",
" <td>32.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>23.8</td>\n",
" <td>31.2</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
2024-01-20 17:18:19 +01:00
" <td>15</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
2024-01-20 17:18:19 +01:00
" <td>29.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>29.6</td>\n",
" <td>55.3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
2024-01-20 17:18:19 +01:00
" <td>16</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>0</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2024-01-20 17:18:19 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>26.7</td>\n",
" <td>60.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
2024-01-20 17:18:19 +01:00
" <td>18</td>\n",
" <td>18</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
2024-01-20 17:18:19 +01:00
" <td>27.9</td>\n",
" <td>31.4</td>\n",
" <td>33.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>16.9</td>\n",
" <td>40.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th>82816</th>\n",
" <td>79</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>2</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
2024-01-20 17:18:19 +01:00
" <td>30.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>NaN</td>\n",
" <td>30.8</td>\n",
" <td>40.3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th>82817</th>\n",
" <td>80</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
2024-01-20 17:18:19 +01:00
" <td>60.2</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>NaN</td>\n",
" <td>31.9</td>\n",
" <td>47.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th>82818</th>\n",
" <td>82</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
2024-01-20 17:18:19 +01:00
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th>82819</th>\n",
" <td>84</td>\n",
" <td>21</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th>82820</th>\n",
" <td>88</td>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 17:18:19 +01:00
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>20.0</td>\n",
" <td>44.5</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-20 17:18:19 +01:00
"<p>82821 rows × 64 columns</p>\n",
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
2024-01-20 17:18:19 +01:00
"0 2 9 3 2 \n",
"1 5 7 1 6 \n",
"2 15 9 1 4 \n",
"3 16 3 0 4 \n",
"4 18 18 3 4 \n",
"... ... ... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 79 0 3 2 \n",
"82817 80 20 3 4 \n",
"82818 82 0 3 4 \n",
"82819 84 21 3 4 \n",
"82820 88 8 1 2 \n",
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
2024-01-20 17:18:19 +01:00
"0 3 True False False \n",
"1 3 True False False \n",
"2 3 False False False \n",
"3 3 False False True \n",
"4 3 False False False \n",
"... ... ... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 3 True False False \n",
"82817 3 False False False \n",
"82818 3 True False False \n",
"82819 3 False False False \n",
"82820 3 False False False \n",
"\n",
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
2024-01-20 17:18:19 +01:00
"1 False False ... 20.6 \n",
"2 False False ... 29.0 \n",
"3 False False ... NaN \n",
2024-01-20 17:18:19 +01:00
"4 False False ... 27.9 \n",
"... ... ... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 False False ... 30.9 \n",
"82817 False False ... 60.2 \n",
"82818 False False ... NaN \n",
"82819 False False ... NaN \n",
"82820 False False ... NaN \n",
"\n",
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
2024-01-20 17:18:19 +01:00
"1 32.8 NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
2024-01-20 17:18:19 +01:00
"4 31.4 33.4 NaN \n",
"... ... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 NaN NaN NaN \n",
"82817 NaN NaN NaN \n",
"82818 NaN NaN NaN \n",
"82819 NaN NaN NaN \n",
"82820 NaN NaN NaN \n",
"\n",
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
2024-01-20 17:18:19 +01:00
"0 NaN NaN 26.6 \n",
"1 NaN NaN 23.8 \n",
"2 NaN NaN 29.6 \n",
"3 NaN NaN 26.7 \n",
"4 NaN NaN 16.9 \n",
"... ... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 NaN NaN 30.8 \n",
"82817 NaN NaN 31.9 \n",
"82818 NaN NaN NaN \n",
"82819 NaN NaN NaN \n",
"82820 NaN NaN 20.0 \n",
"\n",
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
2024-01-20 17:18:19 +01:00
"0 53.1 NaN \n",
"1 31.2 NaN \n",
"2 55.3 NaN \n",
"3 60.4 NaN \n",
"4 40.1 NaN \n",
"... ... ... \n",
2024-01-20 17:18:19 +01:00
"82816 40.3 NaN \n",
"82817 47.7 NaN \n",
"82818 NaN NaN \n",
"82819 NaN NaN \n",
"82820 44.5 NaN \n",
"\n",
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"... ... \n",
2024-01-20 17:18:19 +01:00
"82816 NaN \n",
"82817 NaN \n",
"82818 NaN \n",
"82819 NaN \n",
"82820 NaN \n",
"\n",
2024-01-20 17:18:19 +01:00
"[82821 rows x 64 columns]"
]
},
2024-01-20 17:18:19 +01:00
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['labelEncoder.joblib']"
]
},
2024-01-20 17:18:19 +01:00
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dump(enc,'labelEncoder.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"enc2 = load('labelEncoder.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# enc.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# ############### NEW ################\n",
"# from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# le_posiotion_name = LabelEncoder()\n",
"# le_shot_technique_name = LabelEncoder()\n",
"# le_shot_type_name = LabelEncoder()\n",
"# le_shot_body_part_name = LabelEncoder()\n",
"\n",
"# df['position_name'] = le_posiotion_name.fit_transform(df['position_name'])\n",
"# df['shot_technique_name'] = le_shot_technique_name.fit_transform(df['shot_technique_name'])\n",
"# df['shot_type_name'] = le_shot_type_name.fit_transform(df['shot_type_name'])\n",
"# df['shot_body_part_name'] = le_shot_body_part_name.fit_transform(df['shot_body_part_name'])"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Change the type of categorical features to 'category' \n",
2024-01-20 17:18:19 +01:00
"df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
2024-01-20 17:18:19 +01:00
" 'shot_body_part_name']] = df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 11,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Separate the target column (y, kept as a one-column DataFrame) from the features (X)\n",
"y = df[['is_goal']]\n",
"X = df.drop(columns=['is_goal'])\n",
"\n",
"# Hold out 20% of the rows as the test set; the seed is fixed so the split is reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=1\n",
")\n",
"\n",
"# Stratified 5-fold splitter (shuffled, fixed seed) used for cross-validation\n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 12,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"Shots attempted in the training set: 58970\n",
"Goals scored in the training set: 7286\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Count each class by its label instead of unpacking value_counts(), whose order is\n",
"# by frequency and only matches (class 0, class 1) while non-goals are the majority.\n",
"# NOTE(review): assumes is_goal is coded 0/1 or False/True - confirm against the CSV.\n",
"count_class_1 = int(y_train['is_goal'].sum())\n",
"count_class_0 = len(y_train) - count_class_1\n",
"\n",
"# Display the count of shots that did not end in a goal (the negative class)\n",
"print('Shots without a goal in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set (the positive class)\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 13,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
" Class imbalance in training data: 8.094\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Class imbalance in training data\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 14,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 15,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XG Boost\n",
2024-01-20 12:55:53 +01:00
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
" 'max_depth': [3, 5, 7, 8, 9],\n",
" 'n_estimators': [100, 150, 200, 250, 300],\n",
2023-12-12 15:22:01 +01:00
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Perform grid search with cross-validation.\n",
"# xG is a probability, so candidates are ranked by the Brier score of predict_proba.\n",
"# 'neg_mean_squared_error' would score the hard 0/1 labels from predict(), i.e. accuracy,\n",
"# which ignores how well the predicted probabilities are calibrated.\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_brier_score', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Run the grid search on the training set (by default GridSearchCV then refits the best model on all of it)\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Take the best estimator found by the grid search\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 17,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 17:18:19 +01:00
"Best parameters: {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
"Model Training Time: 1393.345 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 18,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 17:18:19 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABLZUlEQVR4nO3deVxUZf//8feAMiAIiAtI7plb7huSqZkoGZqmllZ34VZpaClq6p2paUVZ9+2SW6tYd5ZaaS6pEaZmUipGuaelWSGoFaIoi3B+f/hjvo6ggjNHRn09v4/5PppzrrnOdQ7y/b75XNc5YzEMwxAAAIALcyvpAQAAAFwJgQUAALg8AgsAAHB5BBYAAODyCCwAAMDlEVgAAIDLI7AAAACXR2ABAAAuj8ACAABcHoEF140DBw6oS5cu8vPzk8Vi0fLly53a/+HDh2WxWBQbG+vUfq9nd911l+66666SHoYp+HkD1xcCC4rll19+0ZNPPqlatWrJ09NTvr6+atu2rWbOnKmzZ8+aeuzIyEjt3LlTL730kj744AO1bNnS1ONdS/3795fFYpGvr2+h1/HAgQOyWCyyWCx6/fXXi91/cnKyJk+erKSkJCeM1lyTJ0+2nevlXq4YpPLy8vT+++8rJCREAQEBKlu2rOrUqaPHHntM3333XbH7O3PmjCZPnqwNGzY4f7DAdaZUSQ8A14/Vq1frgQcekNVq1WOPPaaGDRsqOztbmzdv1pgxY7R792699dZbphz77NmzSkhI0HPPPadhw4aZcozq1avr7NmzKl26tCn9X0mpUqV05swZrVy5Ug8++KDdvg8//FCenp7KzMy8qr6Tk5P1wgsvqEaNGmratGmRP/fll19e1fEc0atXL9WuXdv2/vTp0xo6dKjuv/9+9erVy7Y9MDDQoeOY8fN++umnNWfOHPXo0UOPPPKISpUqpf3792vNmjWqVauW2rRpU6z+zpw5oxdeeEGSXDKgAdcSgQVFcujQIfXr10/Vq1fX+vXrVblyZdu+qKgoHTx4UKtXrzbt+MePH5ck+fv7m3YMi8UiT09P0/q/EqvVqrZt2+qjjz4qEFgWLVqkiIgIffrpp9dkLGfOnFGZMmXk4eFxTY53ocaNG6tx48a29ydOnNDQoUPVuHFj/etf/7rk5zIzM+Xh4SE3t6IVjp39805NTdXcuXP1+OOPFwjuM2bMsP0bBnB1mBJCkUybNk2nT5/Wu+++axdW8tWuXVvPPPOM7f25c+c0depU3XrrrbJarapRo4b+/e9/Kysry+5zNWrUULdu3bR582a1bt1anp6eqlWrlt5//31bm8mTJ6t69eqSpDFjxshisahGjRqSzk+l5P/3hfKnFS4UFxenO++8U/7+/vLx8VHdunX173//27b/Umsa1q9fr3bt2snb21v+/v7q0aOH9u7dW+jxDh48qP79+8vf319+fn4aMGCAzpw5c+kLe5GHH35Ya9asUVpamm3btm3bdODAAT388MMF2v/9998aPXq0GjVqJB8fH/n6+qpr16768ccfbW02bNigVq1aSZIGDBhgm1LJP8+77rpLDRs2VGJiotq3b68yZcrYrsvFa1giIyPl6elZ4PzDw8NVrlw5JScnF/lcHbFhwwZZLBZ9/PHHmjBhgm655RaVKVNG6enpRbomUuE/7/79+8vHx0d//vmnevbsKR8fH1WsWFGjR49Wbm7uZcd06NAhGYahtm3bFthnsVhUqVIlu21paWkaMWKEqlatKqvVqtq1a+vVV19VXl6ebXwVK1aUJL3wwgu2n9vkyZOv4ooB1z8qLCiSlStXqlatWrrjjjuK1H7w4MFauHCh+vTpo1GjRun7779XTEyM9u7dq2XLltm1PXjwoPr06aNBgwYpMjJS7733nvr3768WLVro9ttvV69eveTv76+RI0fqoYce0r333isfH59ijX/37t3q1q2bGjdurClTpshqtergwYP69ttvL/u5r776Sl
27dlWtWrU0efJknT17Vm+88Ybatm2rHTt2FAhLDz74oGrWrKmYmBjt2LFD77zzjipVqqRXX321SOPs1auXhgwZos8++0wDBw6UdL66Uq9ePTVv3rxA+19//VXLly/XAw88oJo1ayo1NVVvvvmmOnTooD179ig4OFj169fXlClTNHHiRD3xxBNq166dJNn9LP/66y917dpV/fr107/+9a9LTrfMnDlT69evV2RkpBISEuTu7q4333xTX375pT744AMFBwcX6TydZerUqfLw8NDo0aOVlZUlDw8P7dmz54rX5HJyc3MVHh6ukJAQvf766/rqq6/0n//8R7feequGDh16yc/lh+qlS5fqgQceUJkyZS7Z9syZM+rQoYP+/PNPPfnkk6pWrZq2bNmi8ePH6+jRo5oxY4YqVqyoefPmFZgOu7D6BNxUDOAKTp48aUgyevToUaT2SUlJhiRj8ODBdttHjx5tSDLWr19v21a9enVDkrFp0ybbtmPHjhlWq9UYNWqUbduhQ4cMScZrr71m12dkZKRRvXr1AmOYNGmSceE/7+nTpxuSjOPHj19y3PnHWLBggW1b06ZNjUqVKhl//fWXbduPP/5ouLm5GY899liB4w0cONCuz/vvv98oX778JY954Xl4e3sbhmEYffr0MTp16mQYhmHk5uYaQUFBxgsvvFDoNcjMzDRyc3MLnIfVajWmTJli27Zt27YC55avQ4cOhiRj/vz5he7r0KGD3bZ169YZkowXX3zR+PXXXw0fHx+jZ8+eVzzHq3X8+HFDkjFp0iTbtq+//tqQZNSqVcs4c+aMXfuiXpPCft6RkZGGJLt2hmEYzZo1M1q0aHHFsT722GOGJKNcuXLG/fffb7z++uvG3r17C7SbOnWq4e3tbfz8889228eNG2e4u7sbR44cueS5AzcrpoRwRenp6ZKksmXLFqn9F198IUmKjo622z5q1ChJKrDWpUGDBra/+iWpYsWKqlu3rn799derHvPF8te+fP7557aS+5UcPXpUSUlJ6t+/vwICAmzbGzdurM6dO9vO80JDhgyxe9+uXTv99ddftmtYFA8//LA2bNiglJQUrV+/XikpKYVOB0nn173kr9nIzc3VX3/9ZZvu2rFjR5GPabVaNWDAgCK17dKli5588klNmTJFvXr1kqenp958880iH8uZIiMj5eXlZbfNGdeksJ9jUf49LliwQLNnz1bNmjW1bNkyjR49WvXr11enTp30559/2totXbpU7dq1U7ly5XTixAnbKywsTLm5udq0aVORxgncTAgsuCJfX19J0qlTp4rU/rfffpObm5vdnR6SFBQUJH9/f/32229226tVq1agj3Llyumff/65yhEX1LdvX7Vt21aDBw9WYGCg+vXrpyVLllw2vOSPs27dugX21a9fXydOnFBGRobd9ovPpVy5cpJUrHO59957VbZsWS1evFgffvihWrVqVeBa5svLy9P06dN12223yWq1qkKFCqpYsaJ++uknnTx5ssjHvOWWW4q1wPb1119XQECAkpKSNGvWrALrMwpz/PhxpaSk2F6nT58u8vEupWbNmgW2OXpNPD09bWtH8hX136Obm5uioqKUmJioEydO6PPPP1fXrl21fv169evXz9buwIEDWrt2rSpWrGj3CgsLkyQdO3bsiscCbjYEFlyRr6+vgoODtWvXrmJ97uJFr5fi7u5e6HbDMK76GBcvkPTy8tKmTZv01Vdf6dFHH9VPP/2kvn37qnPnzldcTFkcjpxLPqvVql69emnhwoVatmzZJasrkvTyyy8rOjpa7du31//+9z+tW7dOcXFxuv3224tcSZJUoEpxJT/88IPt/6nu3LmzSJ9p1aqVKleubHtdzfNkLlbYuB29Jpf6GRZX+fLldd999+mLL75Qhw4dtHnzZlsIzsvLU+fOnRUXF1foq3fv3k4ZA3AjYdEtiqRbt2566623lJCQoNDQ0Mu2rV69uvLy8nTgwAHVr1/ftj01NV
VpaWm2xYnOUK5cObs7avJdXMWRzv/126lTJ3Xq1En//e9/9fLLL+u5557T119/bfvL9uLzkKT9+/cX2Ldv3z5VqFB
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on training set\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 17:18:19 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQAElEQVR4nO3de1yO9/8H8Nd9d7hLOkrlHpJTas6x5GyaHCeHWdjEwliNnEYz5NhkmzPNTmzDjE0zLJpGGyHR0GgOjTncJUmKztfvD7+ur2uFct1X3Xg9v4/r8fje1/W5Ptfnus326v25PlcqQRAEEBERERkwdVUPgIiIiOhxGFiIiIjI4DGwEBERkcFjYCEiIiKDx8BCREREBo+BhYiIiAweAwsREREZPAYWIiIiMngMLERERGTwGFjoqXXu3Dn06NED1tbWUKlUiIyM1Gv///zzD1QqFdavX6/Xfp9mXbt2RdeuXat6GET0HGJgIVkuXLiAt99+G/Xr14eZmRmsrKzQoUMHLF++HPfu3VP02v7+/jh16hQWLlyIb775Bm3atFH0epVp5MiRUKlUsLKyKvN7PHfuHFQqFVQqFT766KMK93/t2jWEhoYiMTFRD6NVVmhoqHivj9r0FaR2796N0NDQcrcvLi7G119/DU9PT9jZ2cHS0hKNGzfGiBEjcPjw4Qpf/+7duwgNDcX+/fsrfC7Rs8y4qgdAT69du3bhtddeg0ajwYgRI9C0aVPk5+fjjz/+wLRp05CUlIR169Ypcu179+4hLi4OM2fORFBQkCLXcHZ2xr1792BiYqJI/49jbGyMu3fv4ueff8aQIUMkxzZu3AgzMzPk5uY+Ud/Xrl3D3LlzUa9ePbRs2bLc5+3du/eJrifHwIED0bBhQ/FzdnY2xo8fjwEDBmDgwIHifkdHR71cb/fu3Vi9enW5Q8uECROwevVq9O/fH8OHD4exsTGSk5Pxyy+/oH79+mjXrl2Frn/37l3MnTsXAFjNInoAAws9kZSUFPj5+cHZ2RkxMTGoVauWeCwwMBDnz5/Hrl27FLv+jRs3AAA2NjaKXUOlUsHMzEyx/h9Ho9GgQ4cO2Lx5c6nAsmnTJvTp0wc//PBDpYzl7t27qFatGkxNTSvleg9q3rw5mjdvLn5OT0/H+PHj0bx5c7zxxhuVPp4HpaamYs2aNRgzZkypcL5s2TLxn1Miko9TQvREwsPDkZ2djS+++EISVko0bNgQEydOFD8XFhZi/vz5aNCgATQaDerVq4f3338feXl5kvPq1auHvn374o8//sBLL70EMzMz1K9fH19//bXYJjQ0FM7OzgCAadOmQaVSoV69egDuT6WU/P8HlUwrPCg6OhodO3aEjY0NqlevDldXV7z//vvi8Yc9wxITE4NOnTrBwsICNjY26N+/P86cOVPm9c6fP4+RI0fCxsYG1tbWGDVqFO7evfvwL/Y/hg0bhl9++QWZmZnivvj4eJw7dw7Dhg0r1T4jIwNTp05Fs2bNUL16dVhZWaFXr174888/xTb79+9H27ZtAQCjRo0Sp1RK7rNr165o2rQpEhIS0LlzZ1SrVk38Xv77DIu/vz/MzMxK3b+Pjw9sbW1x7dq1ct+rXGfPnsXgwYNhZ2cHMzMztGnTBjt27JC0KSgowNy5c9GoUSOYmZmhRo0a6NixI6KjowHc/+dn9erVACCZbnqYlJQUCIKADh06lDqmUqng4OAg2ZeZmYng4GDUqVMHGo0GDRs2xOLFi1FcXAzg/j9zNWvWBADMnTtXvH5FpqiInlWssNAT+fnnn1G/fn20b9++XO1Hjx6NDRs2YPDgwZgyZQqOHDmCsLAwnDlzBtu3b5e0PX/+PAYPHoyAgAD4+/vjyy+/xMiRI+Hh4YEXX3wRAwcOhI2NDSZNmoShQ4eid+/eqF69eoXGn5SUhL59+6J58+aYN28eNBoNzp8/j4MHDz7yvF9//RW9evVC/fr1ERoainv37mHlypXo0K
EDjh8/XiosDRkyBC4uLggLC8Px48fx+eefw8HBAYsXLy7XOAcOHIhx48bhxx9/xFtvvQXgfnWlSZMmaN26dan2Fy9eRGRkJF577TW4uLggNTUVn376Kbp06YK//voLWq0Wbm5umDdvHmbPno2xY8eiU6dOACD5s7x58yZ69eoFPz8/vPHGGw+dblm+fDliYmLg7++PuLg4GBkZ4dNPP8XevXvxzTffQKvVlus+5UpKSkKHDh3wwgsvYMaMGbCwsMD3338PX19f/PDDDxgwYACA+0EyLCwMo0ePxksvvYSsrCwcO3YMx48fxyuvvIK3334b165dQ3R0NL755pvHXrckOG/duhWvvfYaqlWr9tC2d+/eRZcuXXD16lW8/fbbqFu3Lg4dOoSQkBBcv34dy5YtQ82aNbF27dpSU14PVpiInlsCUQXdvn1bACD079+/XO0TExMFAMLo0aMl+6dOnSoAEGJiYsR9zs7OAgAhNjZW3JeWliZoNBphypQp4r6UlBQBgLBkyRJJn/7+/oKzs3OpMcyZM0d48B/3pUuXCgCEGzduPHTcJdf46quvxH0tW7YUHBwchJs3b4r7/vzzT0GtVgsjRowodb233npL0ueAAQOEGjVqPPSaD96HhYWFIAiCMHjwYKF79+6CIAhCUVGR4OTkJMydO7fM7yA3N1coKioqdR8ajUaYN2+euC8+Pr7UvZXo0qWLAECIiIgo81iXLl0k+/bs2SMAEBYsWCBcvHhRqF69uuDr6/vYe3xSN27cEAAIc+bMEfd1795daNasmZCbmyvuKy4uFtq3by80atRI3NeiRQuhT58+j+w/MDBQqMi/GkeMGCEAEGxtbYUBAwYIH330kXDmzJlS7ebPny9YWFgIf//9t2T/jBkzBCMjI+Hy5csPvT8iEgROCVGFZWVlAQAsLS3L1X737t0AgMmTJ0v2T5kyBQBKPevi7u4u/tQPADVr1oSrqysuXrz4xGP+r5JnX3766SexHP84169fR2JiIkaOHAk7Oztxf/PmzfHKK6+I9/mgcePGST536tQJN2/eFL/D8hg2bBj2798PnU6HmJgY6HS6MqeDgPvPvajV9/9aFxUV4ebNm+J01/Hjx8t9TY1Gg1GjRpWrbY8ePfD2229j3rx5GDhwIMzMzPDpp5+W+1pyZWRkICYmBkOGDMGdO3eQnp6O9PR03Lx5Ez4+Pjh37hyuXr0K4P6fe1JSEs6dO6e363/11VdYtWoVXFxcsH37dkydOhVubm7o3r27eF3gfhWmU6dOsLW1FceYnp4Ob29vFBUVITY2Vm9jInoWMbBQhVlZWQEA7ty5U672ly5dglqtlqz0AAAnJyfY2Njg0qVLkv1169Yt1YetrS1u3br1hCMu7fXXX0eHDh0wevRoODo6ws/PD99///0jw0vJOF1dXUsdc3NzQ3p6OnJyciT7/3svtra2AFChe+nduzcsLS2xZcsWbNy4EW3bti31XZYoLi7G0qVL0ahRI2g0Gtjb26NmzZo4efIkbt++Xe5rvvDCCxV6wPajjz6CnZ0dEhMTsWLFilLPbpTlxo0b0Ol04padnV3u6z3o/PnzEAQBs2bNQs2aNSXbnDlzAABpaWkAgHnz5iEzMxONGzdGs2bNMG3aNJw8efKJrltCrVYjMDAQCQkJSE9Px08//YRevXohJiYGfn5+Yrtz584hKiqq1Bi9vb0lYySisvEZFqowKysraLVanD59ukLnPerhxQcZGRmVuV8QhCe+RlFRkeSzubk5YmNj8dtvv2HXrl2IiorCli1b8PLLL2Pv3r0PHUNFybmXEhqNBgMHDsSGDRtw8eLFRz6AuWjRIsyaNQtvvfUW5s+fDzs7O6jVagQHB5e7kgTc/34q4sSJE+J/cE+dOoWhQ4c+9py2bdtKwuqcOXOe6OHSkvuaOnUqfHx8ymxTEvA6d+6MCxcu4KeffsLevXvx+eefY+nSpYiIiMDo0aMrfO3/qlGjBl599VW8+uqr6N
q1Kw4cOIBLly7B2dkZxcXFeOWVV/Dee++VeW7jxo1lX5/oWcbAQk+kb9++WLduHeLi4uDl5fXItiX/sj537hzc3Nz
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testig Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 20,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"The test dataset contains 16565 shots, with 1891 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 21,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 17:18:19 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1kAAAHHCAYAAABN8pIpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVyO2f/48dfdvicpKolIssVYE2NNyZI1E4OsYzB2slfWyJJ9jN3Yd2ZkCWUfTHaRfc/Ypaju6vz+6Nf1davIfMyQOc/H435wn+tc53pf59zd93Wu61znUgkhBJIkSZIkSZIkSdInofW5A5AkSZIkSZIkSfqayE6WJEmSJEmSJEnSJyQ7WZIkSZIkSZIkSZ+Q7GRJkiRJkiRJkiR9QrKTJUmSJEmSJEmS9AnJTpYkSZIkSZIkSdInJDtZkiRJkiRJkiRJn5DsZEmSJEmSJEmSJH1CspMlSZIkSZIkSZL0CclOliRJkiTlIcuWLUOlUnHr1q3PHYokSZKUA9nJkiRJkr5omZ2K7F7Dhg37R7Z59OhRgoKCePHixT9S/n/Z69evCQoKIioq6nOHIkmS9I/R+dwBSJIkSVJujB07lmLFimmklS1b9h/Z1tGjRwkODsbf3598+fL9I9v4uzp06MB3332Hvr7+5w7lb3n9+jXBwcEA1KlT5/MGI0mS9A+RnSxJkiQpT2jUqBGVK1f+3GH8TxITEzE2Nv6fytDW1kZbW/sTRfTvSU9PJyUl5XOHIUmS9K+QwwUlSZKkr8LOnTupVasWxsbGmJqa0rhxYy5evKiR59y5c/j7++Po6IiBgQGFChWiS5cuPH36VMkTFBTEkCFDAChWrJgyNPHWrVvcunULlUrFsmXLsmxfpVIRFBSkUY5KpSImJoZ27dphYWFBzZo1leUrV66kUqVKGBoakj9/fr777jvu3r37wf3M7p6sokWL0qRJE6KioqhcuTKGhoaUK1dOGZK3efNmypUrh4GBAZUqVeL06dMaZfr7+2NiYsKNGzfw9PTE2NgYW1tbxo4dixBCI29iYiKDBg3C3t4efX19nJ2dmTp1apZ8KpWKPn36sGrVKsqUKYO+vj4///wzVlZWAAQHByt1m1lvuWmft+v22rVrytVGc3NzOnfuzOvXr7PU2cqVK6latSpGRkZYWFjw7bffsmfPHo08ufn8SJIk5Za8kiVJkiTlCS9fvuTJkycaaQUKFADg119/pVOnTnh6ejJ58mRev37N/PnzqVmzJqdPn6Zo0aIAREREcOPGDTp37kyhQoW4ePEiv/zyCxcvXuSPP/5ApVLRsmVLrly5wpo1a5gxY4ayDSsrKx4/fvzRcbdp0wYnJycmTpyodEQmTJjA6NGj8fX1pVu3bjx+/JjZs2fz7bffcvr06b81RPHatWu0a9eOH374ge+//56pU6fStGlTfv75Z0aMGEGvXr0AmDRpEr6+vsTGxqKl9X/nWtPS0vDy8qJ69epMmTKFXbt2ERgYSGpqKmPHjgVACEGzZs2IjIyka9euVKhQgd27dzNkyBDu37/PjBkzNGLav38/69evp0+fPhQoUABXV1fmz5/Pjz/+SIsWLWjZsiUA5cuXB3LXPm/z9fWlWLFiTJo0iVOnTrFo0SKsra2ZPHmykic4OJigoCBq1KjB2LFj0dPT4/jx4+zfv5+GDRsCuf/8SJIk5ZqQJEmSpC/Y0qVLBZDtSwghXr16JfLlyye6d++usd7Dhw+Fubm5Rvrr16+zlL9mzRoBiIMHDyppoaGhAhA3b97UyHvz5k0BiKVLl2YpBxCBgYHK+8DAQAEIPz8/jXy3bt0S2traYsKECRrp58+fFzo6OlnSc6qPt2NzcHAQgDh69KiStnv3bgEIQ0NDcfv2bSV9wYIFAhCRkZFKWqdOnQQgfvrpJyUtPT1dNG7cWOjp6YnHjx8LIYTYunWrAMT48eM1YmrdurVQqVTi2rVrGvWhpaUlLl68qJH38ePHWe
oqU27bJ7Nuu3TpopG3RYsWwtLSUnl/9epVoaWlJVq0aCHS0tI08qanpwshPu7zI0mSlFtyuKAkSZKUJ8ydO5eIiAiNF2Rc/Xjx4gV+fn48efJEeWlra1OtWjUiIyOVMgwNDZX/JyUl8eTJE6pXrw7AqVOn/pG4e/bsqfF+8+bNpKen4+vrqxFvoUKFcHJy0oj3Y5QuXRo3NzflfbVq1QCoV68eRYoUyZJ+48aNLGX06dNH+X/mcL+UlBT27t0LQHh4ONra2vTt21djvUGDBiGEYOfOnRrptWvXpnTp0rneh49tn3frtlatWjx9+pT4+HgAtm7dSnp6OmPGjNG4ape5f/Bxnx9JkqTcksMFJUmSpDyhatWq2U58cfXqVSCjM5EdMzMz5f/Pnj0jODiYtWvX8ujRI418L1++/ITR/p93Z0S8evUqQgicnJyyza+rq/u3tvN2RwrA3NwcAHt7+2zTnz9/rpGupaWFo6OjRlrJkiUBlPu/bt++ja2tLaamphr5XFxclOVve3ffP+Rj2+fdfbawsAAy9s3MzIzr16+jpaX13o7ex3x+JEmSckt2siRJkqQ8LT09Hci4r6ZQoUJZluvo/N9Pna+vL0ePHmXIkCFUqFABExMT0tPT8fLyUsp5n3fvCcqUlpaW4zpvX53JjFelUrFz585sZwk0MTH5YBzZyWnGwZzSxTsTVfwT3t33D/nY9vkU+/Yxnx9JkqTckt8ckiRJUp5WvHhxAKytrWnQoEGO+Z4/f86+ffsIDg5mzJgxSnrmlYy35dSZyrxS8u5Dit+9gvOheIUQFCtWTLlS9CVIT0/nxo0bGjFduXIFQJn4wcHBgb179/Lq1SuNq1mXL19Wln9ITnX7Me2TW8WLFyc9PZ2YmBgqVKiQYx748OdHkiTpY8h7siRJkqQ8zdPTEzMzMyZOnIharc6yPHNGwMyrHu9e5QgLC8uyTuazrN7tTJmZmVGgQAEOHjyokT5v3rxcx9uyZUu0tbUJDg7OEosQIst05f+mOXPmaMQyZ84cdHV1qV+/PgDe3t6kpaVp5AOYMWMGKpWKRo0afXAbRkZGQNa6/Zj2ya3mzZujpaXF2LFjs1wJy9xObj8/kiRJH0NeyZIkSZLyNDMzM+bPn0+HDh345ptv+O6777CysuLOnTvs2LEDd3d35syZg5mZGd9++y1TpkxBrVZjZ2fHnj17uHnzZpYyK1WqBMDIkSP57rvv0NXVpWnTphgbG9OtWzdCQkLo1q0blStX5uDBg8oVn9woXrw448ePZ/jw4dy6dYvmzZtjamrKzZs32bJlCz169GDw4MGfrH5yy8DAgF27dtGpUyeqVavGzp072bFjByNGjFCebdW0aVPq1q3LyJEjuXXrFq6uruzZs4dt27bRv39/5arQ+xgaGlK6dGnWrVtHyZIlyZ8/P2XLlqVs2bK5bp/cKlGiBCNHjmTcuHHUqlWLli1boq+vz8mTJ7G1tWXSpEm5/vxIkiR9DNnJkiRJkvK8du3aYWtrS0hICKGhoSQnJ2NnZ0etWrXo3Lmzkm/16tX89NNPzJ07FyEEDRs2ZOfOndja2mqUV6VKFcaNG8fPP//Mrl27SE9P5+bNmxgbGzNmzBgeP37Mxo0bWb9+PY0aNWLnzp1YW1vnOt5hw4ZRsmRJZsyYQXBwMJAxQUXDhg1p1qzZp6mUj6Strc2uXbv48ccfGTJkCKampgQGBmoM3dPS0mL79u2MGTOGdevWsXTpUooWLUpoaCiDBg3K9bYWLVrETz/9xIABA0hJSSEwMJCyZcvmun0+xtixYylWrBizZ89m5MiRGBkZUb58eTp06KDkye3nR5IkKbdU4t+481WSJEmSpC+Wv78/GzduJCEh4XOHIkmS9FWQ92RJkiRJkiRJkiR9QrKTJUmSJEmSJEmS9AnJTpYkSZIkSZIkSdInJO/JkiRJkiRJkiRJ+oTklSxJkiRJkiRJkqRPSHayJEmSJEmSJEmSPiH5nCxJ+o9LT0
/nwYMHmJqaolKpPnc4kiRJkiTlghCCV69eYWtri5aWvG7ypZGdLEn6j3vw4AH29vafOwxJkiRJkv6Gu3fvUrhw4c8
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 22,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 17:18:19 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvwAAAHHCAYAAADDDYx8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxO6f/48dfdvqeSFG1IhSRblCVEy2js+yBjGUuIsQ5abNk1YZhN9jGLZXzsITvZDSJrGttkSaloPb8/+nW+bi1qxhKu5+PRg3POda5znXN1d1/nOte53gpJkiQEQRAEQRAEQfgoqbzvAgiCIAiCIAiC8PaIBr8gCIIgCIIgfMREg18QBEEQBEEQPmKiwS8IgiAIgiAIHzHR4BcEQRAEQRCEj5ho8AuCIAiCIAjCR0w0+AVBEARBEAThIyYa/IIgCIIgCILwERMNfkEQBEEQBEH4iIkGvyAIgiB8QFasWIFCoSA+Pv59F0UQhA+EaPALgiAIZVp+A7ewnwkTJryVYx49epSQkBCePn36VvL/lKWnpxMSEsL+/fvfd1EE4ZOh9r4LIAiCIAglMXXqVGxtbZXW1apV660c6+jRo4SGhuLv70+5cuXeyjH+rd69e9O9e3c0NTXfd1H+lfT0dEJDQwHw8PB4v4URhE+EaPALgiAIHwQfHx/q16//vovxn6SlpaGrq/uf8lBVVUVVVfUNlejdyc3NJTMz830XQxA+SWJIjyAIgvBR2LFjB02bNkVXVxd9fX0+++wzLl26pJTmr7/+wt/fnypVqqClpUXFihX58ssvefz4sZwmJCSEsWPHAmBraysPH4qPjyc+Ph6FQsGKFSsKHF+hUBASEqKUj0KhIDY2lp49e2JkZESTJk3k7WvWrKFevXpoa2tjbGxM9+7d+fvvv197noWN4bexsaFt27bs37+f+vXro62tjZOTkzxsZuPGjTg5OaGlpUW9evU4e/asUp7+/v7o6elx8+ZNvLy80NXVxcLCgqlTpyJJklLatLQ0vv76aywtLdHU1MTe3p558+YVSKdQKAgICGDt2rXUrFkTTU1Nli1bhqmpKQChoaHytc2/biWpn5ev7fXr1+WnMIaGhvTr14/09PQC12zNmjU0bNgQHR0djIyMaNasGbt371ZKU5LfH0H4UIkefkEQBOGDkJyczKNHj5TWlS9fHoDVq1fTt29fvLy8mD17Nunp6SxdupQmTZpw9uxZbGxsAIiKiuLmzZv069ePihUrcunSJX744QcuXbrE8ePHUSgUdOzYkatXr/LLL7+wcOFC+RimpqY8fPiw1OXu0qULdnZ2zJw5U24Uz5gxgylTptC1a1cGDBjAw4cPWbRoEc2aNePs2bP/ahjR9evX6dmzJ1999RVffPEF8+bNw8/Pj2XLlvHNN98wdOhQAMLCwujatStxcXGoqPxfv19OTg7e3t40atSIOXPmsHPnToKDg8nOzmbq1KkASJLE559/TnR0NP3796dOnTrs2rWLsWPHcvfuXRYuXKhUpn379vHbb78REBBA+fLlcXZ2ZunSpQwZMoQOHTrQsWNHAGrXrg2UrH5e1rVrV2xtbQkLC+PMmTP89NNPVKhQgdmzZ8tpQkNDCQkJwc3NjalTp6KhoUFMTAz79u2jTZs2QMl/fwThgyUJgiAIQhkWGRkpAYX+SJIkPXv2TCpXrpw0cOBApf0ePHggGRoaKq1PT08vkP8vv/wiAdLBgwfldXPnzpUA6datW0ppb926JQFSZGRkgXwAKTg4WF4ODg6WAKlHjx5K6eLj4yVVVVVpxowZSusvXLggqampFVhf1PV4uWzW1tYSIB09elRet2vXLgmQtLW1pdu3b8vrv//+ewmQoqOj5XV9+/aVAGn48OHyutzcXOmzzz6TNDQ0pIcPH0qSJEmbN2+WAGn69OlKZercubOkUCik69evK10PFRUV6dKlS0ppHz58WOBa5Stp/eRf2y+//FIpbYcOHS
QTExN5+dq1a5KKiorUoUMHKScnRyltbm6uJEml+/0RhA+VGNIjCIIgfBCWLFlCVFSU0g/k9Qo/ffqUHj168OjRI/lHVVUVV1dXoqOj5Ty0tbXl/7948YJHjx7RqFEjAM6cOfNWyj148GCl5Y0bN5Kbm0vXrl2VyluxYkXs7OyUylsaNWrUoHHjxvKyq6srAC1btsTKyqrA+ps3bxbIIyAgQP5//pCczMxM9uzZA8D27dtRVVVlxIgRSvt9/fXXSJLEjh07lNY3b96cGjVqlPgcSls/r17bpk2b8vjxY1JSUgDYvHkzubm5BAUFKT3NyD8/KN3vjyB8qMSQHkEQBOGD0LBhw0Jf2r127RqQ17AtjIGBgfz/J0+eEBoayvr160lMTFRKl5yc/AZL+39enVno2rVrSJKEnZ1doenV1dX/1XFebtQDGBoaAmBpaVno+qSkJKX1KioqVKlSRWld9erVAeT3BW7fvo2FhQX6+vpK6RwdHeXtL3v13F+ntPXz6jkbGRkBeedmYGDAjRs3UFFRKfamozS/P4LwoRINfkEQBOGDlpubC+SNw65YsWKB7Wpq//dV17VrV44ePcrYsWOpU6cOenp65Obm4u3tLedTnFfHkOfLyckpcp+Xe63zy6tQKNixY0ehs+3o6em9thyFKWrmnqLWS6+8ZPs2vHrur1Pa+nkT51aa3x9B+FCJ32JBEAThg1a1alUAKlSogKenZ5HpkpKS2Lt3L6GhoQQFBcnr83t4X1ZUwz6/B/nVgFyv9my/rrySJGFrayv3oJcFubm53Lx5U6lMV69eBZBfWrW2tmbPnj08e/ZMqZf/ypUr8vbXKeralqZ+Sqpq1ark5uYSGxtLnTp1ikwDr//9EYQPmRjDLwiCIHzQvLy8MDAwYObMmWRlZRXYnj+zTn5v8Ku9v+Hh4QX2yZ8r/9WGvYGBAeXLl+fgwYNK67/77rsSl7djx46oqqoSGhpaoCySJBWYgvJdWrx4sVJZFi9ejLq6Oq1atQLA19eXnJwcpXQACxcuRKFQ4OPj89pj6OjoAAWvbWnqp6Tat2+PiooKU6dOLfCEIP84Jf39EYQPmejhFwRBED5oBgYGLF26lN69e1O3bl26d++OqakpCQkJbNu2DXd3dxYvXoyBgQHNmjVjzpw5ZGVlUalSJXbv3s2tW7cK5FmvXj0AJk2aRPfu3VFXV8fPzw9dXV0GDBjArFmzGDBgAPXr1+fgwYNyT3hJVK1alenTpzNx4kTi4+Np3749+vr63Lp1i02bNjFo0CDGjBnzxq5PSWlpabFz50769u2Lq6srO3bsYNu2bXzzzTfy3Pl+fn60aNGCSZMmER8fj7OzM7t37+bPP/8kMDBQ7i0vjra2NjVq1ODXX3+levXqGBsbU6tWLWrVqlXi+impatWqMWnSJKZNm0bTpk3p2LEjmpqanDx5EgsLC8LCwkr8+yMIH7T3NDuQIAiCIJRI/jSUJ0+eLDZddHS05OXlJRkaGkpaWlpS1apVJX9/f+nUqVNymjt37kgdOnSQypUrJxkaGkpdunSR7t27V+g0kdOmTZMqVaokqaioKE2DmZ6eLvXv318yNDSU9PX1pa5du0qJiYlFTsuZP6XlqzZs2CA1adJE0tXVlXR1dSUHBwdp2LBhUlxcXImux6vTcn722WcF0gLSsGHDlNblTy06d+5ceV3fvn0lXV1d6caNG1KbNm0kHR0dyczMTAoODi4wneWzZ8+kUaNGSRYWFpK6urpkZ2cnzZ07V57msrhj5zt69KhUr149SUNDQ+m6lbR+irq2hV0bSZKk5cuXSy4uLpKmpqZkZGQkNW/eXIqKilJKU5LfH0H4UCkk6R28tSMIgiAIQpnl7+/PH3/8QWpq6vsuiiAIb4EYwy8IgiAIgiAIHzHR4BcEQRAEQRCEj5ho8AuCIAiCIAjCR0yM4RcEQRAEQRCEj5jo4RcEQRAEQRCEj5ho8AuCIAiCIAjCR0wE3hKET1xubi737t
1DX1+/yJD3giAIgiCULZIk8ezZMywsLFBRKb4PXzT4BeETd+/ePSwtLd93MQRBEARB+Bf+/vtvKleuXGwa0eAXhE+
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Calculating MAE, RMSE and R2 for training and test sets \n",
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
"r2_train = r2_score(y_train, y_pred_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
"r2_test = r2_score(y_test, y_pred_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 24,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-01-20 17:18:19 +01:00
"#T_c2f8d_row0_col0, #T_c2f8d_row0_col1, #T_c2f8d_row0_col2, #T_c2f8d_row0_col3, #T_c2f8d_row0_col4, #T_c2f8d_row0_col5, #T_c2f8d_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
2024-01-20 17:18:19 +01:00
"<table id=\"T_c2f8d\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-01-20 17:18:19 +01:00
" <th id=\"T_c2f8d_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th id=\"T_c2f8d_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th id=\"T_c2f8d_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th id=\"T_c2f8d_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th id=\"T_c2f8d_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th id=\"T_c2f8d_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th id=\"T_c2f8d_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-01-20 17:18:19 +01:00
" <th id=\"T_c2f8d_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_c2f8d_row0_col0\" class=\"data row0 col0\" >0.09646</td>\n",
" <td id=\"T_c2f8d_row0_col1\" class=\"data row0 col1\" >0.31058</td>\n",
" <td id=\"T_c2f8d_row0_col2\" class=\"data row0 col2\" >0.01446</td>\n",
" <td id=\"T_c2f8d_row0_col3\" class=\"data row0 col3\" >0.09919</td>\n",
" <td id=\"T_c2f8d_row0_col4\" class=\"data row0 col4\" >0.31494</td>\n",
" <td id=\"T_c2f8d_row0_col5\" class=\"data row0 col5\" >0.01918</td>\n",
" <td id=\"T_c2f8d_row0_col6\" class=\"data row0 col6\" >23.22242</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-01-20 17:18:19 +01:00
"<pandas.io.formats.style.Styler at 0x15e925550>"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 17:18:19 +01:00
"execution_count": 24,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating of dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'objective': 'binary:logistic',\n",
" 'base_score': None,\n",
" 'booster': None,\n",
" 'callbacks': None,\n",
" 'colsample_bylevel': None,\n",
" 'colsample_bynode': None,\n",
" 'colsample_bytree': None,\n",
" 'device': None,\n",
" 'early_stopping_rounds': None,\n",
" 'enable_categorical': True,\n",
" 'eval_metric': None,\n",
" 'feature_types': None,\n",
" 'gamma': None,\n",
" 'grow_policy': None,\n",
" 'importance_type': None,\n",
" 'interaction_constraints': None,\n",
" 'learning_rate': 0.01,\n",
" 'max_bin': None,\n",
" 'max_cat_threshold': None,\n",
" 'max_cat_to_onehot': None,\n",
" 'max_delta_step': None,\n",
2024-01-20 12:55:53 +01:00
" 'max_depth': 5,\n",
" 'max_leaves': None,\n",
" 'min_child_weight': None,\n",
" 'missing': nan,\n",
" 'monotone_constraints': None,\n",
" 'multi_strategy': None,\n",
2024-01-20 17:18:19 +01:00
" 'n_estimators': 300,\n",
" 'n_jobs': None,\n",
" 'num_parallel_tree': None,\n",
" 'random_state': None,\n",
" 'reg_alpha': None,\n",
" 'reg_lambda': None,\n",
" 'sampling_method': None,\n",
" 'scale_pos_weight': 1,\n",
" 'subsample': None,\n",
" 'tree_method': 'hist',\n",
" 'validate_parameters': None,\n",
" 'verbosity': None}"
]
},
2024-01-20 17:18:19 +01:00
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the model\n",
"best_xgb_model.save_model('xgboost.json')\n",
"\n",
"best_xgb_model.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 69,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"# dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# # Load the model\n",
"# model = load('xgboost.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 17:18:19 +01:00
"array([[0.9021414 , 0.09785861],\n",
" [0.9434396 , 0.05656038],\n",
" [0.9602713 , 0.0397287 ],\n",
" ...,\n",
2024-01-20 17:18:19 +01:00
" [0.8207403 , 0.17925973],\n",
" [0.88015527, 0.11984473],\n",
" [0.9733864 , 0.02661358]], dtype=float32)"
]
},
2024-01-20 12:55:53 +01:00
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_xgb_model.predict_proba(X)"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"new_xgb_model = xgboost.XGBClassifier()\n",
"new_xgb_model.load_model('xgboost.json')"
]
},
{
"cell_type": "code",
2024-01-20 17:18:19 +01:00
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 17:18:19 +01:00
"array([[0.9021414 , 0.09785861],\n",
" [0.9434396 , 0.05656038],\n",
" [0.9602713 , 0.0397287 ],\n",
" ...,\n",
2024-01-20 17:18:19 +01:00
" [0.8207403 , 0.17925973],\n",
" [0.88015527, 0.11984473],\n",
" [0.9733864 , 0.02661358]], dtype=float32)"
]
},
2024-01-20 17:18:19 +01:00
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_xgb_model.predict_proba(X)"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-01-20 09:26:51 +01:00
"version": "3.11.6"
2023-12-12 15:22:01 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}