fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

1465 lines
350 KiB
Plaintext
Raw Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 1,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
2024-01-12 20:39:14 +01:00
"from joblib import dump, load\n",
"import os"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 2,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
2024-01-20 09:26:51 +01:00
"df = pd.read_csv('final_data_new.csv')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 3,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
2024-01-12 20:39:14 +01:00
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
" 'y_player_teammate_Goalkeeper'],\n",
" dtype='object')"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
2024-01-12 20:39:14 +01:00
" <th>shot_follows_dribble</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
2024-01-12 20:39:14 +01:00
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2024-01-20 09:26:51 +01:00
" <td>2</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Right Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.6</td>\n",
" <td>53.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Left Back</td>\n",
" <td>Left Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>20.6</td>\n",
" <td>32.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>23.8</td>\n",
" <td>31.2</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
2024-01-20 09:26:51 +01:00
" <td>15</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.6</td>\n",
" <td>55.3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
2024-01-20 09:26:51 +01:00
" <td>16</td>\n",
" <td>Center Forward</td>\n",
" <td>Head</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.7</td>\n",
" <td>60.4</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
2024-01-20 09:26:51 +01:00
" <td>18</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>27.9</td>\n",
" <td>31.4</td>\n",
" <td>33.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>16.9</td>\n",
" <td>40.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-12 20:39:14 +01:00
"<p>5 rows × 64 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
2024-01-20 09:26:51 +01:00
"0 2 Left Center Forward Right Foot Half Volley \n",
"1 5 Left Back Left Foot Volley \n",
"2 15 Left Center Forward Left Foot Normal \n",
"3 16 Center Forward Head Normal \n",
"4 18 Right Center Forward Right Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
2024-01-20 09:26:51 +01:00
"0 Open Play True False False \n",
"1 Open Play True False False \n",
"2 Open Play False False False \n",
"3 Open Play False False True \n",
"4 Open Play False False False \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"1 False False ... 20.6 \n",
"2 False False ... 29.0 \n",
2024-01-12 20:39:14 +01:00
"3 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"4 False False ... 27.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"1 32.8 NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"4 31.4 33.4 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
2024-01-20 09:26:51 +01:00
"0 NaN NaN 26.6 \n",
"1 NaN NaN 23.8 \n",
"2 NaN NaN 29.6 \n",
"3 NaN NaN 26.7 \n",
"4 NaN NaN 16.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
2024-01-20 09:26:51 +01:00
"0 53.1 NaN \n",
"1 31.2 NaN \n",
"2 55.3 NaN \n",
"3 60.4 NaN \n",
"4 40.1 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"\n",
"[5 rows x 64 columns]"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 43,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']] = df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']].astype(float)"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Text-valued shot descriptors that are replaced by their ordinal codes.\n",
"ordinal_cols = ['position_name',\n",
"                'shot_technique_name',\n",
"                'shot_type_name',\n",
"                'shot_body_part_name']\n",
"\n",
"# Fit one encoder on all four columns and overwrite them with the encoded values.\n",
"enc = OrdinalEncoder()\n",
"df[ordinal_cols] = enc.fit_transform(df[ordinal_cols])"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Store the minute and the encoded descriptor columns as integers.\n",
"int_cols = ['minute',\n",
"            'position_name',\n",
"            'shot_technique_name',\n",
"            'shot_type_name',\n",
"            'shot_body_part_name']\n",
"\n",
"df[int_cols] = df[int_cols].astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
" <th>shot_follows_dribble</th>\n",
" <th>...</th>\n",
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>18</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>17</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38337</th>\n",
" <td>61</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.3</td>\n",
" <td>42.6</td>\n",
" <td>54.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.3</td>\n",
" <td>50.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38338</th>\n",
" <td>66</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.6</td>\n",
" <td>39.4</td>\n",
" <td>43.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>19.0</td>\n",
" <td>45.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38339</th>\n",
" <td>73</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>48.1</td>\n",
" <td>41.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.7</td>\n",
" <td>29.6</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38340</th>\n",
" <td>75</td>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>29.1</td>\n",
" <td>33.6</td>\n",
" <td>40.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.2</td>\n",
" <td>32.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38341</th>\n",
" <td>90</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>62.6</td>\n",
" <td>51.0</td>\n",
" <td>66.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>23.0</td>\n",
" <td>45.5</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>38342 rows × 64 columns</p>\n",
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 18 3 4 \n",
"1 5 18 1 4 \n",
"2 5 4 3 2 \n",
"3 5 10 3 4 \n",
"4 5 17 1 4 \n",
"... ... ... ... ... \n",
"38337 61 3 3 4 \n",
"38338 66 0 3 4 \n",
"38339 73 3 1 6 \n",
"38340 75 13 1 4 \n",
"38341 90 3 3 4 \n",
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 3 False False False \n",
"1 3 False False False \n",
"2 3 True False False \n",
"3 3 False False False \n",
"4 3 True False False \n",
"... ... ... ... ... \n",
"38337 3 True False False \n",
"38338 3 True False False \n",
"38339 3 True False False \n",
"38340 3 False False False \n",
"38341 3 False False False \n",
"\n",
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... 48.9 \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
"... ... ... ... ... \n",
"38337 False False ... 47.3 \n",
"38338 False False ... 47.6 \n",
"38339 False False ... 48.9 \n",
"38340 False False ... 29.1 \n",
"38341 False False ... 62.6 \n",
"\n",
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 42.6 54.7 NaN \n",
"38338 39.4 43.1 NaN \n",
"38339 48.1 41.1 NaN \n",
"38340 33.6 40.9 NaN \n",
"38341 51.0 66.7 NaN \n",
"\n",
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 NaN NaN 21.3 \n",
"38338 NaN NaN 19.0 \n",
"38339 NaN NaN 21.7 \n",
"38340 NaN NaN 21.2 \n",
"38341 NaN NaN 23.0 \n",
"\n",
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"38337 50.9 NaN \n",
"38338 45.8 NaN \n",
"38339 29.6 NaN \n",
"38340 32.4 NaN \n",
"38341 45.5 NaN \n",
"\n",
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"... ... \n",
"38337 NaN \n",
"38338 NaN \n",
"38339 NaN \n",
"38340 NaN \n",
"38341 NaN \n",
"\n",
"[38342 rows x 64 columns]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['labelEncoder.joblib']"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dump(enc,'labelEncoder.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"enc2 = load('labelEncoder.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# enc.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# ############### NEW ################\n",
"# from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# le_posiotion_name = LabelEncoder()\n",
"# le_shot_technique_name = LabelEncoder()\n",
"# le_shot_type_name = LabelEncoder()\n",
"# le_shot_body_part_name = LabelEncoder()\n",
"\n",
"# df['position_name'] = le_posiotion_name.fit_transform(df['position_name'])\n",
"# df['shot_technique_name'] = le_shot_technique_name.fit_transform(df['shot_technique_name'])\n",
"# df['shot_type_name'] = le_shot_type_name.fit_transform(df['shot_type_name'])\n",
"# df['shot_body_part_name'] = le_shot_body_part_name.fit_transform(df['shot_body_part_name'])"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Give the discrete features the pandas 'category' dtype\n",
"# (the XGBoost model below is created with enable_categorical=True).\n",
"# NOTE(review): 'minute' and the two player counts are numeric; as categories they are\n",
"# presumably treated as unordered values by the model - confirm this is intended.\n",
"categorical_cols = ['minute',\n",
"                    'position_name',\n",
"                    'shot_technique_name',\n",
"                    'shot_type_name',\n",
"                    'number_of_players_opponents',\n",
"                    'number_of_players_teammates',\n",
"                    'shot_body_part_name']\n",
"\n",
"df[categorical_cols] = df[categorical_cols].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 10,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Splitting the dataset into features (X) and the target variable (y)\n",
"y = pd.DataFrame(df['is_goal'])\n",
"X = df.drop(['is_goal'], axis=1)\n",
"\n",
"# Splitting the data into a training set and a test set.\n",
"# Goals are the minority class, so the split is stratified on the target to keep the\n",
"# same goal rate in both parts (consistent with the StratifiedKFold splitter below).\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=1, stratify=y['is_goal'])\n",
"\n",
"# Stratified 5-fold splitter used for cross-validation during the grid search\n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 11,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"Shots attempted in the training set: 58970\n",
"Goals scored in the training set: 7286\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Count each class explicitly instead of unpacking value_counts(), whose order\n",
"# follows class frequency rather than the class label.\n",
"count_class_1 = int(y_train['is_goal'].sum())\n",
"count_class_0 = len(y_train) - count_class_1\n",
"\n",
"# Display the count of shots that did not end in a goal in the training set\n",
"print('Shots without a goal in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 12,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
" Class imbalance in training data: 8.094\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Class imbalance in training data\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 13,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 14,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XG Boost\n",
"param_grid_xgb = {'learning_rate': [0.01],\n",
" 'max_depth': [3],\n",
" 'n_estimators': [300],\n",
2023-12-12 15:22:01 +01:00
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 15,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Perform grid search with cross-validation.\n",
"# Candidates are scored on the predicted goal probabilities (Brier score, i.e. the mean\n",
"# squared error of predict_proba). 'neg_mean_squared_error' would score the hard 0/1 labels\n",
"# from predict(), which is just the misclassification rate and, on this imbalanced target,\n",
"# rewards never predicting a goal.\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_brier_score', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Run the cross-validated grid search on the training set (the best configuration is then refit on all of it)\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Take the estimator that was refit with the best parameters\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-12 20:39:14 +01:00
"Best parameters: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
2024-01-20 09:26:51 +01:00
"Model Training Time: 10.600 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 17,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 09:26:51 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABL80lEQVR4nO3de3zO9f/H8ee1sWtjthm2Wc7JKaeMZskpy9LIsUiHOVU0yobQAVFWOjg0kk5T35RUfLEiEb6yomnlEF9KrWKz0YxhY/v8/vDb9XW1YXNdH7vwuH9v1+3m+nze1/vz/nxG3+de7/fnc1kMwzAEAADgwtzKegAAAAAXQ2ABAAAuj8ACAABcHoEFAAC4PAILAABweQQWAADg8ggsAADA5RFYAACAyyOwAAAAl0dgwRVj79696tq1q3x9fWWxWLRs2TKn9v/bb7/JYrEoISHBqf1eyTp16qROnTqV9TBMwc8buLIQWFAqv/zyix555BHVq1dPnp6e8vHxUbt27TR79mydPHnS1GNHRUVp+/btev755/X++++rdevWph7vcho0aJAsFot8fHyKvY579+6VxWKRxWLRyy+/XOr+Dxw4oClTpiglJcUJozXXlClTbOd6oZcrBqmCggK99957Cg0Nlb+/vypVqqQGDRrowQcf1Lffflvq/k6cOKEpU6Zo/fr1zh8scIUpV9YDwJUjMTFRd999t6xWqx588EE1bdpUeXl52rRpk8aNG6edO3dqwYIFphz75MmTSkpK0lNPPaWRI0eacozatWvr5MmTKl++vCn9X0y5cuV04sQJrVixQvfcc4/dvg8++ECenp46derUJfV94MABPfvss6pTp45atmxZ4s99+eWXl3Q8R/Tp00f169e3vT9+/LhGjBih3r17q0+fPrbtgYGBDh3HjJ/3Y489prlz56pnz5667777VK5cOe3Zs0dffPGF6tWrp7Zt25aqvxMnTujZZ5+VJJcMaMDlRGBBiezfv18DBgxQ7dq1tW7dOlWvXt22Lzo6Wvv27VNiYqJpx8/IyJAk+fn5mXYMi8UiT09P0/q/GKvVqnbt2unDDz8sElgWLVqkyMhIffrpp5dlLCdOnFCFChXk4eFxWY53rubNm6t58+a295mZmRoxYoSaN2+u+++//7yfO3XqlDw8POTmVrLCsbN/3unp6Zo3b54eeuihIsF91qxZtr/DAC4NU0IokRkzZuj48eN6++237cJKofr16+vxxx+3vT9z5oymTZum66+/XlarVXXq1NGTTz6p3Nxcu8/VqVNH3bt316ZNm3TzzTfL09NT9erV03vvvWdrM2XKFNWuXVuSNG7cOFksFtWpU0fS2amUwj+fq3Ba4Vxr1qzRrbfeKj8/P3l7e6thw4Z68sknbfvPt6Zh3bp1at++vSpWrCg/Pz/17NlTP//8c7HH27dvnwYNGiQ/Pz/5+vpq8ODBOnHixPkv7D8MHDhQX3zxhbKysmzbtm7dqr1792rgwIFF2h85ckRjx45Vs2bN5O3tLR8fH3Xr1k0//vijrc369evVpk0bSdLgwYNtUyqF59mpUyc1bdpUycnJ6tChgypUqGC7Lv9cwxIVFSVPT88i5x8REaHKlSvrwIEDJT5XR6xfv14Wi0UfffSRnn76aV133XWqUKGCsrOzS3RNpOJ/3oMGDZK3t7f++usv9erVS97e3qpWrZrGjh2r/Pz8C45p//79MgxD7dq1K7LPYrEoICDAbltWVpZGjx6tmjVrymq1qn79+nrxxRdVUFBgG1+1atUkSc8++6zt5zZlypRLuGLAlY8KC0pkxYoVqlevnm655ZYStR82bJgWLlyofv36acyYMfruu+8UFxenn3/+WUuXLrVru2/fPvXr109Dhw5VVFSU3nnnHQ0aNEghISG68cYb1adPH/n5+SkmJkb33nuv7rzzTnl7e5dq/Dt37lT37t3VvHlzTZ06VVarVfv27dM333xzwc999dVX6t
atm+rVq6cpU6bo5MmTeu2119SuXTtt27atSFi65557VLduXcXFxWnbtm166623FBAQoBdffLFE4+zTp4+GDx+uzz77TEOGDJF0trrSqFEjtWrVqkj7X3/9VcuWLdPdd9+tunXrKj09XW+88YY6duyoXbt2KTg4WI0bN9bUqVM1adIkPfzww2rfvr0k2f0sDx8+rG7dumnAgAG6//77zzvdMnv2bK1bt05RUVFKSkqSu7u73njjDX355Zd6//33FRwcXKLzdJZp06bJw8NDY8eOVW5urjw8PLRr166LXpMLyc/PV0REhEJDQ/Xyyy/rq6++0iuvvKLrr79eI0aMOO/nCkP1kiVLdPfdd6tChQrnbXvixAl17NhRf/31lx555BHVqlVLmzdv1sSJE3Xw4EHNmjVL1apV0+uvv15kOuzc6hNwTTGAizh69KghyejZs2eJ2qekpBiSjGHDhtltHzt2rCHJWLdunW1b7dq1DUnGxo0bbdsOHTpkWK1WY8yYMbZt+/fvNyQZL730kl2fUVFRRu3atYuMYfLkyca5f71nzpxpSDIyMjLOO+7CY7z77ru2bS1btjQCAgKMw4cP27b9+OOPhpubm/Hggw8WOd6QIUPs+uzdu7dRpUqV8x7z3POoWLGiYRiG0a9fP6NLly6GYRhGfn6+ERQUZDz77LPFXoNTp04Z+fn5Rc7DarUaU6dOtW3bunVrkXMr1LFjR0OSMX/+/GL3dezY0W7b6tWrDUnGc889Z/z666+Gt7e30atXr4ue46XKyMgwJBmTJ0+2bfv6668NSUa9evWMEydO2LUv6TUp7ucdFRVlSLJrZxiGcdNNNxkhISEXHeuDDz5oSDIqV65s9O7d23j55ZeNn3/+uUi7adOmGRUrVjT++9//2m2fMGGC4e7ubqSmpp733IFrFVNCuKjs7GxJUqVKlUrU/vPPP5ckxcbG2m0fM2aMJBVZ69KkSRPbb/2SVK1aNTVs2FC//vrrJY/5nwrXvvz73/+2ldwv5uDBg0pJSdGgQYPk7+9v2968eXPdfvvttvM81/Dhw+3et2/fXocPH7Zdw5IYOHCg1q9fr7S0NK1bt05paWnFTgdJZ9e9FK7ZyM/P1+HDh23TXdu2bSvxMa1WqwYPHlyitl27dtUjjzyiqVOnqk+fPvL09NQbb7xR4mM5U1RUlLy8vOy2OeOaFPdzLMnfx3fffVfx8fGqW7euli5dqrFjx6px48bq0qWL/vrrL1u7JUuWqH379qpcubIyMzNtr/DwcOXn52vjxo0lGidwLSGw4KJ8fHwkSceOHStR+99//11ubm52d3pIUlBQkPz8/PT777/bba9Vq1aRPipXrqy///77EkdcVP/+/dWuXTsNGzZMgYGBGjBggD7++OMLhpfCcTZs2LDIvsaNGyszM1M5OTl22/95LpUrV5akUp3LnXfeqUqVKmnx4sX64IMP1KZNmyLXslBBQYFmzpypG264QVarVVWrVlW1atX0008/6ejRoyU+5nXXXVeqBbYvv/yy/P39lZKSojlz5hRZn1GcjIwMpaWl2V7Hjx8v8fHOp27dukW2OXpNPD09bWtHCpX076Obm5uio6OVnJyszMxM/fvf/1a3bt20bt06DRgwwNZu7969WrVqlapVq2b3Cg8PlyQdOnTooscCrjUEFlyUj4+PgoODtWPHjlJ97p+LXs/H3d292O2GYVzyMf65QNLLy0sbN27UV199pQceeEA//fST+vfvr9tvv/2iiylLw5FzKWS1WtWnTx8tXLhQS5cuPW91RZKmT5+u2NhYdejQQf/617+0evVqrVmzRjfeeGOJK0mSilQpLuaHH36w/Z/q9u3bS/SZNm3aqHr16rbXpTxP5p+KG7ej1+R8P8PSqlKliu666y59/vnn6tixozZt2mQLwQUFBbr99tu1Zs2aYl99+/Z1yhiAqwmLblEi3bt314IFC5SUlKSwsLALtq1du7YKCgq0d+9eNW7c2LY9PT1dWV
lZtsWJzlC5cmW7O2oK/bOKI5397bdLly7q0qWLXn31VU2fPl1PPfWUvv76a9tvtv88D0nas2dPkX27d+9W1apVVbF
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on training set\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 18,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 09:26:51 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQ20lEQVR4nO3de1yP9/8/8Mf73eFdohMq7yE5peacLclxmhwnh1nYxMJYjZw1Qw7TZJszzWbL9sGMTTMsmkZDTtHQaA5hxrsklaLz9fvDr+vrWqFc76veeNw/t+t2+7xf1/N6Xa/r7bCn5+v1ulIJgiCAiIiIyICpq3oARERERE/ChIWIiIgMHhMWIiIiMnhMWIiIiMjgMWEhIiIig8eEhYiIiAweExYiIiIyeExYiIiIyOAxYSEiIiKDx4SFnlkXLlxAjx49YGVlBZVKhcjISL32f+XKFahUKkREROi132dZ165d0bVr16oeBhG9gJiwkCyXLl3Ce++9h4YNG8LMzAyWlpbw9PTE8uXLcf/+fUXv7efnhzNnzuDjjz/Gd999h3bt2il6v8o0cuRIqFQqWFpalvk9XrhwASqVCiqVCp9++mmF+79x4wZCQkKQkJCgh9EqKyQkRHzWxx36SqR2796NkJCQcscXFxfj22+/hbu7O2xtbVGjRg00bdoUI0aMwJEjRyp8/3v37iEkJAT79++v8LVEzzPjqh4APbt27dqFN998ExqNBiNGjEDz5s2Rn5+PgwcPYtq0aUhMTMS6desUuff9+/cRFxeHWbNmITAwUJF7ODo64v79+zAxMVGk/ycxNjbGvXv38Msvv2DIkCGScxs3boSZmRlyc3Ofqu8bN25g3rx5aNCgAVq3bl3u6/bu3ftU95Nj4MCBaNy4sfg5Ozsb48ePx4ABAzBw4ECx3d7eXi/32717N1avXl3upGXChAlYvXo1+vfvj+HDh8PY2BhJSUn49ddf0bBhQ7Rv375C97937x7mzZsHAKxmET2ECQs9leTkZPj6+sLR0RExMTGoU6eOeC4gIAAXL17Erl27FLv/rVu3AADW1taK3UOlUsHMzEyx/p9Eo9HA09MTmzdvLpWwbNq0CX369MGPP/5YKWO5d+8eqlWrBlNT00q538NatmyJli1bip/T0tIwfvx4tGzZEm+//Xalj+dhKSkpWLNmDcaMGVMqOV+2bJn4+5SI5OOUED2VsLAwZGdnY/369ZJkpUTjxo0xceJE8XNhYSEWLFiARo0aQaPRoEGDBvjwww+Rl5cnua5Bgwbo27cvDh48iFdffRVmZmZo2LAhvv32WzEmJCQEjo6OAIBp06ZBpVKhQYMGAB5MpZT8/4eVTCs8LDo6Gh07doS1tTWqV68OZ2dnfPjhh+L5R61hiYmJQadOnWBhYQFra2v0798f586dK/N+Fy9exMiRI2FtbQ0rKyuMGjUK9+7de/QX+x/Dhg3Dr7/+ioyMDLHt+PHjuHDhAoYNG1YqPj09HVOnTkWLFi1QvXp1WFpaolevXvjzzz/FmP379+OVV14BAIwaNUqcUil5zq5du6J58+aIj49H586dUa1aNfF7+e8aFj8/P5iZmZV6fm9vb9jY2ODGjRvlfla5zp8/j8GDB8PW1hZmZmZo164dduzYIYkpKCjAvHnz0KRJE5iZmaFmzZro2LEjoqOjATz4/bN69WoAkEw3PUpycjIEQYCnp2epcyqVCnZ2dpK2jIwMBAUFoV69etBoNGjcuDEWL16M4uJiAA9+z9WuXRsAMG/ePPH+FZmiInpescJCT+WXX35Bw4YN0aFDh3LFjx49Ghs2bMDgwYMxZcoUHD16FKGhoTh37hy2b98uib148SIGDx4Mf39/+Pn54euvv8bIkSPh5uaGl19+GQMHDoS1tTUmTZqEoUOHonfv3qhevXqFxp+YmIi+ffuiZcuWmD9/PjQaDS5evIhDhw499rrffvsNvXr1QsOGDRESEoL79+9j5cqV8P
T0xMmTJ0slS0OGDIGTkxNCQ0Nx8uRJfPXVV7Czs8PixYvLNc6BAwdi3Lhx+Omnn/Duu+8CeFBdadasGdq2bVsq/vLly4iMjMSbb74JJycnpKSk4IsvvkCXLl3w119/QavVwsXFBfPnz8ecOXMwduxYdOrUCQAkv5a3b99Gr1694Ovri7fffvuR0y3Lly9HTEwM/Pz8EBcXByMjI3zxxRfYu3cvvvvuO2i12nI9p1yJiYnw9PTESy+9hJkzZ8LCwgI//PADfHx88OOPP2LAgAEAHiSSoaGhGD16NF599VVkZWXhxIkTOHnyJF5//XW89957uHHjBqKjo/Hdd9898b4lifPWrVvx5ptvolq1ao+MvXfvHrp06YJ///0X7733HurXr4/Dhw8jODgYN2/exLJly1C7dm2sXbu21JTXwxUmoheWQFRBmZmZAgChf//+5YpPSEgQAAijR4+WtE+dOlUAIMTExIhtjo6OAgAhNjZWbEtNTRU0Go0wZcoUsS05OVkAICxZskTSp5+fn+Do6FhqDHPnzhUe/u2+dOlSAYBw69atR4675B7ffPON2Na6dWvBzs5OuH37ttj2559/Cmq1WhgxYkSp+7377ruSPgcMGCDUrFnzkfd8+DksLCwEQRCEwYMHC927dxcEQRCKiooEBwcHYd68eWV+B7m5uUJRUVGp59BoNML8+fPFtuPHj5d6thJdunQRAAjh4eFlnuvSpYukbc+ePQIAYeHChcLly5eF6tWrCz4+Pk98xqd169YtAYAwd+5csa179+5CixYthNzcXLGtuLhY6NChg9CkSROxrVWrVkKfPn0e239AQIBQkb8aR4wYIQAQbGxshAEDBgiffvqpcO7cuVJxCxYsECwsLIS///5b0j5z5kzByMhIuHbt2iOfj4gEgVNCVGFZWVkAgBo1apQrfvfu3QCAyZMnS9qnTJkCAKXWuri6uor/6geA2rVrw9nZGZcvX37qMf9XydqXn3/+WSzHP8nNmzeRkJCAkSNHwtbWVmxv2bIlXn/9dfE5HzZu3DjJ506dOuH27dvid1gew4YNw/79+6HT6RATEwOdTlfmdBDwYN2LWv3gj3VRURFu374tTnedPHmy3PfUaDQYNWpUuWJ79OiB9957D/Pnz8fAgQNhZmaGL774otz3kis9PR0xMTEYMmQI7t69i7S0NKSlpeH27dvw9vbGhQsX8O+//wJ48OuemJiICxcu6O3+33zzDVatWgUnJyds374dU6dOhYuLC7p37y7eF3hQhenUqRNsbGzEMaalpcHLywtFRUWIjY3V25iInkdMWKjCLC0tAQB3794tV/zVq1ehVqslOz0AwMHBAdbW1rh69aqkvX79+qX6sLGxwZ07d55yxKW99dZb8PT0xOjRo2Fvbw9fX1/88MMPj01eSsbp7Oxc6pyLiwvS0tKQk5Mjaf/vs9jY2ABAhZ6ld+/eqFGjBrZs2YKNGzfilVdeKfVdliguLsbSpUvRpEkTaDQa1KpVC7Vr18bp06eRmZlZ7nu+9NJLFVpg++mnn8LW1hYJCQlYsWJFqbUbZbl16xZ0Op14ZGdnl/t+D7t48SIEQcDs2bNRu3ZtyTF37lwAQGpqKgBg/vz5yMjIQNOmTdGiRQtMmzYNp0+ffqr7llCr1QgICEB8fDzS0tLw888/o1evXoiJiYGvr68Yd+HCBURFRZUao5eXl2SMRFQ2rmGhCrO0tIRWq8XZs2crdN3jFi8+zMjIqMx2QRCe+h5FRUWSz+bm5oiNjcXvv/+OXbt2ISoqClu2bMFrr72GvXv3PnIMFSXnWUpoNBoMHDgQGzZswOXLlx+7AHPRokWYPXs23n33XSxYsAC2trZQq9UICgoqdyUJePD9VMSpU6fE/+CeOXMGQ4cOfeI1r7zyiiRZnTt37lMtLi15rqlTp8Lb27vMmJIEr3Pnzrh06RJ+/vln7N27F1999RWWLl2K8PBwjB49usL3/q+aNWvijTfewBtvvIGuXb
viwIEDuHr1KhwdHVFcXIzXX38d06dPL/Papk2byr4/0fOMCQs9lb59+2LdunWIi4uDh4fHY2NL/rK+cOECXFxcxPa
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on the test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testing Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"The test dataset contains 16565 shots, with 1891 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 20,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 09:26:51 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAHHCAYAAADONqsSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwO6/8/8Nd9t293KkmRlCVZCpEWZK3kJGtkK9v5cOxr1hb7ErIfa9m3gzhHZCuchGQXUSSSQ1GpTnV339fvj37Nt3HfLRznVLyfj8f90D1zzTXvmet2z/2eueYaAWOMgRBCCCGEEEJIlSGs7AAIIYQQQgghhPBRokYIIYQQQgghVQwlaoQQQgghhBBSxVCiRgghhBBCCCFVDCVqhBBCCCGEEFLFUKJGCCGEEEIIIVUMJWqEEEIIIYQQUsVQokYIIYQQQgghVQwlaoQQQgghhBBSxVCiRgghhFQjISEhEAgESEpKquxQCCGE/IsoUSOEEFKlFScm8l6zZ8/+V9Z57do1+Pv7IyMj41+p/0eWm5sLf39/REZGVnYohBBSpSlWdgCEEEJIRSxcuBCmpqa8ac2bN/9X1nXt2jUEBATA29sbNWrU+FfW8bWGDRuGQYMGQUVFpbJD+Sq5ubkICAgAAHTq1KlygyGEkCqMEjVCCCHVQo8ePdCmTZvKDuMfycnJgYaGxj+qQ0FBAQoKCt8oov+OVCpFQUFBZYdBCCHVBnV9JIQQ8l04c+YMOnToAA0NDWhpaaFnz5549OgRr8z9+/fh7e0NMzMzqKqqonbt2hg5ciTS09O5Mv7+/pg5cyYAwNTUlOtmmZSUhKSkJAgEAoSEhMisXyAQwN/fn1ePQCBAXFwcBg8eDB0dHbRv356bv2/fPlhbW0NNTQ26uroYNGgQXr16Ve52yrtHrX79+vjpp58QGRmJNm3aQE1NDS1atOC6Fx4/fhwtWrSAqqoqrK2tcefOHV6d3t7e0NTUxPPnz+Hs7AwNDQ0YGRlh4cKFYIzxyubk5GD69OkwNjaGiooKzM3NERgYKFNOIBBgwoQJ2L9/P5o1awYVFRX8+uuv0NfXBwAEBARw+7Z4v1WkfUru24SEBO6qp7a2NkaMGIHc3FyZfbZv3z7Y2NhAXV0dOjo66NixI86dO8crU5HPDyGE/JfoihohhJBqITMzE2lpabxpNWvWBADs3bsXXl5ecHZ2xooVK5Cbm4stW7agffv2uHPnDurXrw8AOH/+PJ4/f44RI0agdu3aePToEbZt24ZHjx7h+vXrEAgE6Nu3L54+fYqDBw9i7dq13Dr09fXx/v37L457wIABaNSoEZYuXcolM0uWLMGCBQvg4eGB0aNH4/3799iwYQM6duyIO3fufFV3y4SEBAwePBj/+9//MHToUAQGBsLNzQ2//vor5s6di19++QUAsGzZMnh4eCA+Ph5C4f+dr5VIJHBxcYGtrS1WrlyJs2fPws/PD4WFhVi4cCEAgDGGXr16ISIiAqNGjULLli0RHh6OmTNnIiUlBWvXruXFdOnSJRw5cgQTJkxAzZo1YWVlhS1btmDcuHHo06cP+vbtCwCwtLQEULH2KcnDwwOmpqZYtmwZbt++jR07dqBWrVpYsWIFVyYgIAD+/v6wt7fHwoULoaysjBs3buDSpUtwcnICUPHPDyGE/KcYIYQQUoUFBwczAHJfjDH26dMnVqNGDTZmzBjecm/fvmXa2tq86bm5uTL1Hzx4kAFgV65c4aatWrWKAWAvXrzglX3x4gUDwIKDg2XqAcD8/Py4935+fgwA8/T05JVLSkpiCgoKbMmSJbzpDx48YIqKijLTS9sfJWMzMTFhANi1a9e4aeHh4QwAU1NTYy9fvuSmb926lQFgERER3DQvLy8GgE2cOJGbJpVKWc+ePZmysjJ7//49Y4yx0NBQBoAtXryYF1P//v2ZQCBgCQkJvP0hFArZo0ePeGXfv38vs6+KVbR9ivftyJEjeW
X79OnD9PT0uPfPnj1jQqGQ9enTh0kkEl5ZqVTKGPuyzw8hhPyXqOsjIYSQamHTpk04f/487wUUXYXJyMiAp6cn0tLSuJeCggLatWuHiIgIrg41NTXu77y8PKSlpcHW1hYAcPv27X8l7rFjx/LeHz9+HFKpFB4eHrx4a9eujUaNGvHi/RJNmzaFnZ0d975du3YAgC5duqBevXoy058/fy5Tx4QJE7i/i7suFhQU4MKFCwCAsLAwKCgoYNKkSbzlpk+fDsYYzpw5w5vu6OiIpk2bVngbvrR9Pt+3HTp0QHp6OrKysgAAoaGhkEql8PX15V09LN4+4Ms+P4QQ8l+iro+EEEKqBRsbG7mDiTx79gxAUUIij0gk4v7+8OEDAgICcOjQIbx7945XLjMz8xtG+38+H6ny2bNnYIyhUaNGcssrKSl91XpKJmMAoK2tDQAwNjaWO/3jx4+86UKhEGZmZrxpjRs3BgDufriXL1/CyMgIWlpavHIWFhbc/JI+3/byfGn7fL7NOjo6AIq2TSQSITExEUKhsMxk8Us+P4QQ8l+iRI0QQki1JpVKARTdZ1S7dm2Z+YqK/3eo8/DwwLVr1zBz5ky0bNkSmpqakEqlcHFx4eopy+f3SBWTSCSlLlPyKlFxvAKBAGfOnJE7eqOmpma5cchT2kiQpU1nnw3+8W/4fNvL86Xt8y227Us+P4QQ8l+ibx9CCCHVWoMGDQAAtWrVQrdu3Uot9/HjR1y8eBEBAQHw9fXlphdfUSmptISs+IrN5w/C/vxKUnnxMsZgamrKXbGqCqRSKZ4/f86L6enTpwDADaZhYmKCCxcu4NOnT7yrak+ePOHml6e0ffsl7VNRDRo0gFQqRVxcHFq2bFlqGaD8zw8hhPzX6B41Qggh1ZqzszNEIhGWLl0KsVgsM794pMbiqy+fX20JCgqSWab4WWefJ2QikQg1a9bElStXeNM3b95c4Xj79u0LBQUFBAQEyMTCGJMZiv6/tHHjRl4sGzduhJKSErp27QoAcHV1hUQi4ZUDgLVr10IgEKBHjx7lrkNdXR2A7L79kvapqN69e0MoFGLhwoUyV+SK11PRzw8hhPzX6IoaIYSQak0kEmHLli0YNmwYWrdujUGDBkFfXx/Jyck4ffo0HBwcsHHjRohEInTs2BErV66EWCxGnTp1cO7cObx48UKmTmtrawDAvHnzMGjQICgpKcHNzQ0aGhoYPXo0li9fjtGjR6NNmza4cuUKd+WpIho0aIDFixdjzpw5SEpKQu/evaGlpYUXL17gxIkT+PnnnzFjxoxvtn8qSlVVFWfPnoWXlxfatWuHM2fO4PTp05g7dy737DM3Nzd07twZ8+bNQ1JSEqysrHDu3DmcPHkSU6ZM4a5OlUVNTQ1NmzbF4cOH0bhxY+jq6qJ58+Zo3rx5hdunoho2bIh58+Zh0aJF6NChA/r27QsVFRXExMTAyMgIy5Ytq/DnhxBC/muUqBFCCKn2Bg8eDCMjIyxfvhyrVq1Cfn4+6tSpgw4dOmDEiBFcuQMHDmDixInYtGkTGGNwcnLCmTNnYGRkxKuvbdu2WLRoEX799VecPXsWUqkUL168gIaGBnx9ffH+/Xv89ttvOHLkCHr06IEzZ86gVq1aFY539uzZaNy4MdauXYuAgAAARYN+ODk5oVevXt9mp3whBQUFnD17FuPGjcPMmTOhpaUFPz8/XjdEoVCIU6dOwdfXF4cPH0ZwcDDq16+PVatWYfr06RVe144dOzBx4kRMnToVBQUF8PPzQ/PmzSvcPl9i4cKFMDU1xYYNGzBv3jyoq6vD0tISw4YN48pU9PNDCCH/JQH7L+4mJoQQQkiV5e3tjd9++w3Z2dmVHQohhJD/j+5RI4QQQgghhJAqhhI1QgghhBBCCKliKFEjhBBCCCGEkCqG7lEjhBBCCCGEkCqGrqgRQgghhBBCSBVDiRohhBBCCCGEVDH0HDVCfnBSqRRv3ryBlpYWBAJBZYdDCCGEkApgjO
HTp08wMjKCUEjXXr5HlKgR8oN78+YNjI2NKzsMQgghhHyFV69eoW7dupUdBvkXUKJGyA9OS0sLAPDixQvo6upWcjS
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 21,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 09:26:51 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAv8AAAHHCAYAAAAoOjd/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxO6f/48dfdqkVREpnIErIlOxkyUjRj7NnJkjE0JjuDlK3sjWWYjaxjNozPyBJjX0KWMZYQyaCxRRKt9++Pfp2vW4sQqvv9fDx66FznnOtc73Pfut/3Ode5LpVarVYjhBBCCCGEKPJ03ncDhBBCCCGEEO+GJP9CCCGEEEJoCUn+hRBCCCGE0BKS/AshhBBCCKElJPkXQgghhBBCS0jyL4QQQgghhJaQ5F8IIYQQQggtIcm/EEIIIYQQWkKSfyGEEEIIIbSEJP9CCCFEIRISEoJKpSI6Ovp9N0UIUQhJ8i+EEKJAy0x2s/uZMGHCWznm4cOH8ff35+HDh2+lfm2WmJiIv78/e/fufd9NEUIr6b3vBgghhBB5MW3aNCpWrKhRVqtWrbdyrMOHDxMQEICXlxclSpR4K8d4XX379qVHjx4YGhq+76a8lsTERAICAgBwcXF5v40RQgtJ8i+EEKJQaNeuHQ0aNHjfzXgjT548wcTE5I3q0NXVRVdXN59a9O6kp6eTnJz8vpshhNaTbj9CCCGKhG3btvHhhx9iYmJC8eLF+fjjjzl37pzGNn///TdeXl5UqlSJYsWKUaZMGQYOHMj9+/eVbfz9/Rk7diwAFStWVLoYRUdHEx0djUqlIiQkJMvxVSoV/v7+GvWoVCrOnz9Pr169KFmyJM2bN1fWr127lvr162NkZISFhQU9evTgxo0bL40zuz7/dnZ2fPLJJ+zdu5cGDRpgZGRE7dq1la41GzdupHbt2hQrVoz69etz6tQpjTq9vLwwNTXl6tWruLu7Y2Jigo2NDdOmTUOtVmts++TJE0aPHo2trS2GhoZUq1aNefPmZdlOpVLh4+PDunXrqFmzJoaGhixfvhwrKysAAgIClHObed7y8vo8f26vXLmi3J0xNzdnwIABJCYmZjlna9eupVGjRhgbG1OyZElatGjBzp07NbbJy/tHiKJArvwLIYQoFB49esS9e/c0ykqVKgXAmjVr6N+/P+7u7syePZvExESWLVtG8+bNOXXqFHZ2dgCEhYVx9epVBgwYQJkyZTh37hzfffcd586d4+jRo6hUKjp37sylS5f46aefWLhwoXIMKysr7t69+8rt7tatG/b29syaNUtJkGfOnMmUKVPw9PRk8ODB3L17l8WLF9OiRQtOnTr1Wl2Nrly5Qq9evfjss8/o06cP8+bNo3379ixfvpyvvvqKYcOGARAYGIinpyeRkZHo6PzfNcC0tDTatm1LkyZNmDNnDtu3b2fq1KmkpqYybdo0ANRqNZ9++il79uxh0KBB1K1blx07djB27Fhu3rzJwoULNdr0119/8csvv+Dj40OpUqVwdHRk2bJlfP7553Tq1InOnTsDUKdOHSBvr8/zPD09qVixIoGBgZw8eZIffviB0qVLM3v2bGWbgIAA/P39adasGdOmTcPAwIDw8HD++usv3NzcgLy/f4QoEtRCCCFEAbZy5Uo1kO2PWq1WP378WF2iRAm1t7e3xn6xsbFqc3NzjfLExMQs9f/0009qQL1//36lbO7cuWpAfe3aNY1tr127pgbUK1euzFIPoJ46daqyPHXqVDWg7tmzp8Z20dHRal1dXfXMmTM1ys+ePavW09PLUp7T+Xi+bRUqVFAD6sOHDytlO3bsUANqIyMj9fXr15Xyb7/9Vg2o9+zZo5T1799fDai/+OILpSw9PV398ccfqw0MDNR3795Vq9Vq9ebNm9WAesaMGRpt6tq1q1qlUqmvXLmicT50dHTU586d09j27t27Wc5Vpry+PpnnduDAgRrbdurUSW1paaksX758Wa
2jo6Pu1KmTOi0tTWPb9PR0tVr9au8fIYoC6fYjhBCiUFi6dClhYWEaP5Bxtfjhw4f07NmTe/fuKT+6uro0btyYPXv2KHUYGRkpvz979ox79+7RpEkTAE6ePPlW2j106FCN5Y0bN5Keno6np6dGe8uUKYO9vb1Ge19FjRo1aNq0qbLcuHFjAD766CPKly+fpfzq1atZ6vDx8VF+z+y2k5yczK5duwAIDQ1FV1eXESNGaOw3evRo1Go127Zt0yhv2bIlNWrUyHMMr/r6vHhuP/zwQ+7fv098fDwAmzdvJj09HT8/P427HJnxwau9f4QoCqTbjxBCiEKhUaNG2T7we/nyZSAjyc2OmZmZ8vuDBw8ICAhgw4YN3LlzR2O7R48e5WNr/8+LIxRdvnwZtVqNvb19ttvr6+u/1nGeT/ABzM3NAbC1tc22PC4uTqNcR0eHSpUqaZRVrVoVQHm+4Pr169jY2FC8eHGN7RwcHJT1z3sx9pd51dfnxZhLliwJZMRmZmZGVFQUOjo6uX4BeZX3jxBFgST/QgghCrX09HQgo992mTJlsqzX0/u/jzpPT08OHz7M2LFjqVu3LqampqSnp9O2bVulnty82Oc8U1paWo77PH81O7O9KpWKbdu2ZTtqj6mp6UvbkZ2cRgDKqVz9wgO6b8OLsb/Mq74++RHbq7x/hCgK5B0thBCiUKtcuTIApUuXxtXVNcft4uLi2L17NwEBAfj5+SnlmVd+n5dTkp95ZfnFyb9evOL9svaq1WoqVqyoXFkvCNLT07l69apGmy5dugSgPPBaoUIFdu3axePHjzWu/l+8eFFZ/zI5ndtXeX3yqnLlyqSnp3P+/Hnq1q2b4zbw8vePEEWF9PkXQghRqLm7u2NmZsasWbNISUnJsj5zhJ7Mq8QvXhUODg7Osk/mWPwvJvlmZmaUKlWK/fv3a5R/8803eW5v586d0dXVJSAgIEtb1Gp1lmEt36UlS5ZotGXJkiXo6+vTunVrADw8PEhLS9PYDmDhwoWoVCratWv30mMYGxsDWc/tq7w+edWxY0d0dHSYNm1aljsHmcfJ6/tHiKJCrvwLIYQo1MzMzFi2bBl9+/alXr169OjRAysrK2JiYti6dSvOzs4sWbIEMzMzWrRowZw5c0hJSaFcuXLs3LmTa9euZamzfv36AEyaNIkePXqgr69P+/btMTExYfDgwQQFBTF48GAaNGjA/v37lSvkeVG5cmVmzJjBxIkTiY6OpmPHjhQvXpxr166xadMmhgwZwpgxY/Lt/ORVsWLF2L59O/3796dx48Zs27aNrVu38tVXXylj87dv355WrVoxadIkoqOjcXR0ZOfOnfzxxx/4+voqV9FzY2RkRI0aNfj555+pWrUqFhYW1KpVi1q1auX59cmrKlWqMGnSJKZPn86HH35I586dMTQ05Pjx49jY2BAYGJjn948QRcZ7GmVICCGEyJPMoS2PHz+e63Z79uxRu7u7q83NzdXFihVTV65cWe3l5aU+ceKEss2///6r7tSpk7pEiRJqc3Nzdbdu3dS3bt3KdujJ6dOnq8uVK6fW0dHRGFozMTFRPWjQILW5ubm6ePHiak9PT/WdO3dyHOozc5jMF/3+++/q5s2bq01MTNQmJibq6tWrq4cPH66OjIzM0/l4cajPjz/+OMu2gHr48OEaZZnDlc6dO1cp69+/v9rExEQdFRWldnNzUxsbG6utra3VU6dOzTJE5uPHj9UjR45U29jYqPX19dX29vbquXPnKkNn5nbsTIcPH1bXr19fbWBgoHHe8vr65HRuszs3arVavWLFCrWTk5Pa0NBQXbJkSXXLli3VYWFhGtvk5f0jRFGgUqvfwRM/QgghhCiwvLy8+O2330hISHjfTRFCvGXS518IIYQQQggtIcm/EEIIIYQQWkKSfyGEEEIIIbSE9PkXQgghhBBCS8iVfyGEEEIIIbSEJP9CCCGEEEJoCZnkSwgtl56ezq1btyhevDgqlep9N0cIIY
QQeaBWq3n8+DE2Njbo6OT9er4k/0JouVu3bmFra/u+myGEEEKI13Djxg0++OCDPG8vyb8QWq548eIAXLt2DQsLi/f
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 22,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Calculating MAE, RMSE and R2 for training and test sets\n",
"# NOTE(review): y_pred_train / y_pred_test appear to be hard class labels from .predict(),\n",
"# so these are label-error metrics rather than probability-quality metrics - confirm intent.\n",
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
"# RMSE is taken as sqrt(MSE): the `squared=False` argument is deprecated since scikit-learn 1.4\n",
"# and is no longer accepted by newer releases; the value is numerically the same.\n",
"rmse_train = mean_squared_error(y_train, y_pred_train) ** 0.5\n",
"r2_train = r2_score(y_train, y_pred_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
"rmse_test = mean_squared_error(y_test, y_pred_test) ** 0.5\n",
"r2_test = r2_score(y_test, y_pred_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-01-20 09:26:51 +01:00
"#T_64bf4_row0_col0, #T_64bf4_row0_col1, #T_64bf4_row0_col2, #T_64bf4_row0_col3, #T_64bf4_row0_col4, #T_64bf4_row0_col5, #T_64bf4_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
2024-01-20 09:26:51 +01:00
"<table id=\"T_64bf4\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-01-20 09:26:51 +01:00
" <th id=\"T_64bf4_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th id=\"T_64bf4_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th id=\"T_64bf4_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th id=\"T_64bf4_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th id=\"T_64bf4_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th id=\"T_64bf4_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th id=\"T_64bf4_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-01-20 09:26:51 +01:00
" <th id=\"T_64bf4_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_64bf4_row0_col0\" class=\"data row0 col0\" >0.09726</td>\n",
" <td id=\"T_64bf4_row0_col1\" class=\"data row0 col1\" >0.31186</td>\n",
" <td id=\"T_64bf4_row0_col2\" class=\"data row0 col2\" >0.00629</td>\n",
" <td id=\"T_64bf4_row0_col3\" class=\"data row0 col3\" >0.09955</td>\n",
" <td id=\"T_64bf4_row0_col4\" class=\"data row0 col4\" >0.31551</td>\n",
" <td id=\"T_64bf4_row0_col5\" class=\"data row0 col5\" >0.01560</td>\n",
" <td id=\"T_64bf4_row0_col6\" class=\"data row0 col6\" >0.17666</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-01-20 09:26:51 +01:00
"<pandas.io.formats.style.Styler at 0x138923a50>"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating a dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'objective': 'binary:logistic',\n",
" 'base_score': None,\n",
" 'booster': None,\n",
" 'callbacks': None,\n",
" 'colsample_bylevel': None,\n",
" 'colsample_bynode': None,\n",
" 'colsample_bytree': None,\n",
" 'device': None,\n",
" 'early_stopping_rounds': None,\n",
" 'enable_categorical': True,\n",
" 'eval_metric': None,\n",
" 'feature_types': None,\n",
" 'gamma': None,\n",
" 'grow_policy': None,\n",
" 'importance_type': None,\n",
" 'interaction_constraints': None,\n",
" 'learning_rate': 0.01,\n",
" 'max_bin': None,\n",
" 'max_cat_threshold': None,\n",
" 'max_cat_to_onehot': None,\n",
" 'max_delta_step': None,\n",
" 'max_depth': 3,\n",
" 'max_leaves': None,\n",
" 'min_child_weight': None,\n",
" 'missing': nan,\n",
" 'monotone_constraints': None,\n",
" 'multi_strategy': None,\n",
" 'n_estimators': 300,\n",
" 'n_jobs': None,\n",
" 'num_parallel_tree': None,\n",
" 'random_state': None,\n",
" 'reg_alpha': None,\n",
" 'reg_lambda': None,\n",
" 'sampling_method': None,\n",
" 'scale_pos_weight': 1,\n",
" 'subsample': None,\n",
" 'tree_method': 'hist',\n",
" 'validate_parameters': None,\n",
" 'verbosity': None}"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the model\n",
"best_xgb_model.save_model('xgboost.json')\n",
"\n",
"best_xgb_model.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 69,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Alternative (disabled): save the model with joblib instead of the JSON export\n",
"# dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# # Load the model\n",
"# model = load('xgboost.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 09:26:51 +01:00
"array([[0.88809854, 0.11190146],\n",
" [0.94282126, 0.05717874],\n",
" [0.9659079 , 0.03409215],\n",
" ...,\n",
2024-01-20 09:26:51 +01:00
" [0.8352859 , 0.16471413],\n",
" [0.87101626, 0.12898372],\n",
" [0.95800334, 0.04199664]], dtype=float32)"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_xgb_model.predict_proba(X)"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"new_xgb_model = xgboost.XGBClassifier()\n",
"new_xgb_model.load_model('xgboost.json')"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 09:26:51 +01:00
"array([[0.88809854, 0.11190146],\n",
" [0.94282126, 0.05717874],\n",
" [0.9659079 , 0.03409215],\n",
" ...,\n",
2024-01-20 09:26:51 +01:00
" [0.8352859 , 0.16471413],\n",
" [0.87101626, 0.12898372],\n",
" [0.95800334, 0.04199664]], dtype=float32)"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_xgb_model.predict_proba(X)"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-01-20 09:26:51 +01:00
"version": "3.11.6"
2023-12-12 15:22:01 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}