fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

1465 lines
384 KiB
Plaintext
Raw Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 1,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
2024-01-12 20:39:14 +01:00
"from joblib import dump, load\n",
"import os"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 2,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
2024-01-20 09:26:51 +01:00
"df = pd.read_csv('final_data_new.csv')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 3,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
2024-01-12 20:39:14 +01:00
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
" 'y_player_teammate_Goalkeeper'],\n",
" dtype='object')"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
2024-01-12 20:39:14 +01:00
" <th>shot_follows_dribble</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
2024-01-12 20:39:14 +01:00
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2024-01-20 09:26:51 +01:00
" <td>2</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Right Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.6</td>\n",
" <td>53.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Left Back</td>\n",
" <td>Left Foot</td>\n",
2024-01-20 09:26:51 +01:00
" <td>Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>20.6</td>\n",
" <td>32.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>23.8</td>\n",
" <td>31.2</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
2024-01-20 09:26:51 +01:00
" <td>15</td>\n",
" <td>Left Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>29.6</td>\n",
" <td>55.3</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
2024-01-20 09:26:51 +01:00
" <td>16</td>\n",
" <td>Center Forward</td>\n",
" <td>Head</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2024-01-20 09:26:51 +01:00
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>26.7</td>\n",
" <td>60.4</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
2024-01-20 09:26:51 +01:00
" <td>18</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
2024-01-20 09:26:51 +01:00
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-20 09:26:51 +01:00
" <td>27.9</td>\n",
" <td>31.4</td>\n",
" <td>33.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-20 09:26:51 +01:00
" <td>16.9</td>\n",
" <td>40.1</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-12 20:39:14 +01:00
"<p>5 rows × 64 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
2024-01-20 09:26:51 +01:00
"0 2 Left Center Forward Right Foot Half Volley \n",
"1 5 Left Back Left Foot Volley \n",
"2 15 Left Center Forward Left Foot Normal \n",
"3 16 Center Forward Head Normal \n",
"4 18 Right Center Forward Right Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
2024-01-20 09:26:51 +01:00
"0 Open Play True False False \n",
"1 Open Play True False False \n",
"2 Open Play False False False \n",
"3 Open Play False False True \n",
"4 Open Play False False False \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"1 False False ... 20.6 \n",
"2 False False ... 29.0 \n",
2024-01-12 20:39:14 +01:00
"3 False False ... NaN \n",
2024-01-20 09:26:51 +01:00
"4 False False ... 27.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"1 32.8 NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
2024-01-20 09:26:51 +01:00
"4 31.4 33.4 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
2024-01-20 09:26:51 +01:00
"0 NaN NaN 26.6 \n",
"1 NaN NaN 23.8 \n",
"2 NaN NaN 29.6 \n",
"3 NaN NaN 26.7 \n",
"4 NaN NaN 16.9 \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
2024-01-20 09:26:51 +01:00
"0 53.1 NaN \n",
"1 31.2 NaN \n",
"2 55.3 NaN \n",
"3 60.4 NaN \n",
"4 40.1 NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"\n",
"[5 rows x 64 columns]"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 4,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 43,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']] = df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']].astype(float)"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"enc = OrdinalEncoder()\n",
"\n",
"df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'shot_body_part_name']] = enc.fit_transform(df[['position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name',\n",
" 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df[['minute', \n",
" 'position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'shot_body_part_name']] = df[['minute', \n",
" 'position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'shot_body_part_name']].astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
" <th>shot_follows_dribble</th>\n",
" <th>...</th>\n",
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>18</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>17</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38337</th>\n",
" <td>61</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.3</td>\n",
" <td>42.6</td>\n",
" <td>54.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.3</td>\n",
" <td>50.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38338</th>\n",
" <td>66</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.6</td>\n",
" <td>39.4</td>\n",
" <td>43.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>19.0</td>\n",
" <td>45.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38339</th>\n",
" <td>73</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>48.1</td>\n",
" <td>41.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.7</td>\n",
" <td>29.6</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38340</th>\n",
" <td>75</td>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>29.1</td>\n",
" <td>33.6</td>\n",
" <td>40.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.2</td>\n",
" <td>32.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38341</th>\n",
" <td>90</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>62.6</td>\n",
" <td>51.0</td>\n",
" <td>66.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>23.0</td>\n",
" <td>45.5</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>38342 rows × 64 columns</p>\n",
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 18 3 4 \n",
"1 5 18 1 4 \n",
"2 5 4 3 2 \n",
"3 5 10 3 4 \n",
"4 5 17 1 4 \n",
"... ... ... ... ... \n",
"38337 61 3 3 4 \n",
"38338 66 0 3 4 \n",
"38339 73 3 1 6 \n",
"38340 75 13 1 4 \n",
"38341 90 3 3 4 \n",
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 3 False False False \n",
"1 3 False False False \n",
"2 3 True False False \n",
"3 3 False False False \n",
"4 3 True False False \n",
"... ... ... ... ... \n",
"38337 3 True False False \n",
"38338 3 True False False \n",
"38339 3 True False False \n",
"38340 3 False False False \n",
"38341 3 False False False \n",
"\n",
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... 48.9 \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
"... ... ... ... ... \n",
"38337 False False ... 47.3 \n",
"38338 False False ... 47.6 \n",
"38339 False False ... 48.9 \n",
"38340 False False ... 29.1 \n",
"38341 False False ... 62.6 \n",
"\n",
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 42.6 54.7 NaN \n",
"38338 39.4 43.1 NaN \n",
"38339 48.1 41.1 NaN \n",
"38340 33.6 40.9 NaN \n",
"38341 51.0 66.7 NaN \n",
"\n",
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 NaN NaN 21.3 \n",
"38338 NaN NaN 19.0 \n",
"38339 NaN NaN 21.7 \n",
"38340 NaN NaN 21.2 \n",
"38341 NaN NaN 23.0 \n",
"\n",
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"38337 50.9 NaN \n",
"38338 45.8 NaN \n",
"38339 29.6 NaN \n",
"38340 32.4 NaN \n",
"38341 45.5 NaN \n",
"\n",
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"... ... \n",
"38337 NaN \n",
"38338 NaN \n",
"38339 NaN \n",
"38340 NaN \n",
"38341 NaN \n",
"\n",
"[38342 rows x 64 columns]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['labelEncoder.joblib']"
]
},
2024-01-20 12:55:53 +01:00
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dump(enc,'labelEncoder.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"enc2 = load('labelEncoder.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# enc.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# ############### NEW ################\n",
"# from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# le_posiotion_name = LabelEncoder()\n",
"# le_shot_technique_name = LabelEncoder()\n",
"# le_shot_type_name = LabelEncoder()\n",
"# le_shot_body_part_name = LabelEncoder()\n",
"\n",
"# df['position_name'] = le_posiotion_name.fit_transform(df['position_name'])\n",
"# df['shot_technique_name'] = le_shot_technique_name.fit_transform(df['shot_technique_name'])\n",
"# df['shot_type_name'] = le_shot_type_name.fit_transform(df['shot_type_name'])\n",
"# df['shot_body_part_name'] = le_shot_body_part_name.fit_transform(df['shot_body_part_name'])"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Change the type of categorical features to 'category' \n",
"df[['minute',\n",
" 'position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']] = df[['minute',\n",
" 'position_name', \n",
" 'shot_technique_name', \n",
" 'shot_type_name', \n",
" 'number_of_players_opponents', \n",
" 'number_of_players_teammates', \n",
" 'shot_body_part_name']].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 9,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Splitting the dataset into features (X) and the target variable (y)\n",
"y = pd.DataFrame(df['is_goal'])\n",
"X = df.drop(['is_goal'], axis=1)\n",
"\n",
"# Splitting the data into a training set and a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)\n",
"\n",
"# Create cross-validation \n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 10,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"Shots attempted in the training set: 58970\n",
"Goals scored in the training set: 7286\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"count_class_0, count_class_1 = y_train.value_counts()\n",
"\n",
"# Display the count of shots attempted in the training set\n",
"print('Shots attempted in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 11,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
" Class imbalance in training data: 8.094\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Class imbalance in training data\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"print(f' Class imbalance in training data: {scale_pos_weight:.3f}')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 12,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Define the xgboost model\n",
"xgb_model = xgboost.XGBClassifier(enable_categorical=True, tree_method='hist', objective='binary:logistic')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 14,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Defining the hyper-parameter grid for XG Boost\n",
2024-01-20 12:55:53 +01:00
"param_grid_xgb = {'learning_rate': [0.01, 0.001, 0.0001],\n",
" 'max_depth': [3, 5, 7, 8, 9],\n",
" 'n_estimators': [100, 150, 200, 250, 300],\n",
2023-12-12 15:22:01 +01:00
" 'scale_pos_weight': [1, scale_pos_weight]}"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 15,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Perform grid search with cross-validation\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Fit the best model on the entire training set\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Take the best parameters for xgboost model\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training Time\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 16,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 12:55:53 +01:00
"Best parameters: {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 250, 'scale_pos_weight': 1}\n",
"Model Training Time: 1550.563 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Print the best parameters and training time\n",
"print(\"Best parameters: \", grid_xg.best_params_)\n",
"print (f\"Model Training Time: {xgb_training_time:.3f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 17,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 12:55:53 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABLhklEQVR4nO3de3zP9f//8ft7Y+/NjuawWc7klFMOzVJJlqWRY9FxThWNYggdEGWSCiHpNPVNoaKiSBMqK5pWjotyiNlMmrWxje31+8Nv7493Gzbv98veuF2/l/f34v16Pd/P1/P1Gn3uezyfr9fbYhiGIQAAABfmVtYDAAAAuBACCwAAcHkEFgAA4PIILAAAwOURWAAAgMsjsAAAAJdHYAEAAC6PwAIAAFwegQUAALg8AgsuG7t371bnzp3l7+8vi8Wi5cuXO7X/ffv2yWKxKC4uzqn9Xs5uvfVW3XrrrWU9DFPw8wYuLwQWlMoff/yhRx99VHXr1pWnp6f8/PzUvn17zZo1SydPnjT12FFRUdq6dateeOEFvf/++2rTpo2px7uU+vfvL4vFIj8/v2Kv4+7du2WxWGSxWDRjxoxS95+SkqJJkyYpKSnJCaM116RJk2zner6XKwapgoICvffeewoNDVVgYKB8fX3VoEEDPfTQQ/rxxx9L3d+JEyc0adIkrVu3zvmDBS4z5cp6ALh8rFy5UnfffbesVqseeughNW3aVHl5efr+++81ZswYbd++XQsWLDDl2CdPnlRCQoKefvppDRs2zJRj1KpVSydPnlT58uVN6f9CypUrpxMnTuiLL77QPffcY7fvgw8+kKenp3Jyci6q75SUFD333HOqXbu2WrZsWeLPff311xd1PEf06tVL9evXt73PysrS0KFD1bNnT/Xq1cu2PSgoyKHjmPHzfvzxxzV37lx1795d999/v8qVK6fk5GR99dVXqlu3rtq1a1eq/k6cOKHnnntOklwyoAGXEoEFJbJ3717169dPtWrV0tq1a1WtWjXbvujoaO3Zs0crV6407fjp6emSpICAANOOYbFY5OnpaVr/F2K1WtW+fXt9+OGHRQLLokWLFBkZqU8++eSSjOXEiROqUKGCPDw8Lsnxzta8eXM1b97c9v7o0aMaOnSomjdvrgceeOCcn8vJyZGHh4fc3EpWOHb2zzstLU3z5s3Tww8/XCS4z5w50/Z3GMDFYUoIJTJ9+nRlZWXp7bfftgsrherXr68nnnjC9v706dOaMmWK6tWrJ6vVqtq1a+upp55Sbm6u3edq166trl276vvvv9cNN9wgT09P1a1bV++9956tzaRJk1SrVi1J0pgxY2SxWFS7dm1JZ6ZSCv98tsJphbOtWbNGN910kwICAuTj46OGDRvqqaeesu0/15qGtWvX6uabb5a3t7cCAgLUvXt37dy5s9jj7dmzR/3791dAQID8/f01YMAAnThx4twX9j/uu+8+ffXVV8rIyLBt27x5s3bv3q377ruvSPtjx45p9OjRatasmXx8fOTn56cuXbro119/tbVZt26d2rZtK0kaMGCAbUql8DxvvfVWNW3aVImJibrllltUoUIF23X57xqWqKgoeXp6Fjn/iIgIVaxYUSkpKSU+V0esW7dOFotFH330kZ555hldc801qlChgjIzM0t0TaTif979+/eXj4+PDh06pB49esjHx0dVqlTR6NGjlZ+ff94x7d27V4ZhqH379kX2WSwWVa1a1W5bRkaGRowYoRo1ashqtap+/fp68cUXVVBQYBtflSpVJEnPPfec7ec2adKki7hiwOWPCgtK5IsvvlDdunV14403lqj94MGDtXDhQvXp00ejRo3STz/9pNjYWO3cuVPLli2za7tnzx716dNHgwYNUlRUlN555x31799frVu31nXXXadevXopICBAI0eO1L333qs777xTPj4+pRr/9u3b1bVrVzVv3lyTJ0+W1WrVnj179MMPP5z3c9988426dOmiunXratKkSTp58qRee+01tW/fXlu2bCkSlu655x7VqVNHsbGx2rJli9566y1VrVpVL774YonG2atXLw0ZMkSffvqpBg4cKOlMdaVRo0Zq1apVkfZ//vmnli9frrvvvlt16tRRWlqa3njjDXXo0EE7duxQSEiIGjdurMmTJ2vChAl65JFHdPPNN0uS3c/y77//VpcuXdSvXz898MAD55xumTVrltauXauoqCglJCTI3d1db7zxhr7++mu9//77CgkJKdF5OsuUKVPk4eGh0aNHKzc3Vx4eHtqxY8cFr8n55OfnKyIiQqGhoZoxY4a++eYbvfzyy6pXr56GDh16zs8VhuqlS5fq7rvvVoUKFc7Z9sSJE+rQoYMOHTqkRx99VDVr1tTGjRs1fvx4HT58WDNnzlSVKlX0+uuvF5kOO7v6BFxVDOACjh8/bkgyunfvXqL2SUlJhiRj8ODBdttHjx5tSDLWrl1r21arVi1DkrFhwwbbtiNHjhhWq9UYNWqUbdvevXsNScZLL71k12dUVJRRq1atImOYOHGicfZf71dffdWQZKSnp59z3IXHePfdd23bWrZsaVStWtX4+++/bdt+/fVXw83NzXjooYeKHG/gwIF2ffbs2dOoVKnSOY959nl4e3sbhmEYffr0MTp16mQYhmHk5+cbwcHBxnPPPVfsNcjJyTHy8/OLnIfVajUmT55s27Z58+Yi51aoQ4cOhiRj/vz5xe7r0KGD3bbVq1cbkoznn3/e+PPPPw0fHx+jR48eFzzHi5Wenm5IMiZOnGjb9u233xqSjLp16xonTpywa1/Sa1LczzsqKsqQZNfOMAzj+uuvN1q3bn3BsT700EOGJKNixYpGz549jRkzZhg7d+4s0m7KlCmGt7e38fvvv9ttHzdunOHu7m4cOHDgnOcOXK2YEsIFZWZmSpJ8fX1L1P7LL7+UJMXExNhtHzVqlCQVWevSpEkT22/9klSlShU1bNhQf/7550WP+b8K17589tlntpL7hRw+fFhJSUnq37+/AgMDbdubN2+u22+/3XaeZxsyZIjd+5tvvll///237RqWxH333ad169YpNTVVa9euVWpqarHTQdKZdS+Fazby8/P1999/26a7tmzZUuJjWq1WDRgwoERtO3furEcffVSTJ09Wr1695OnpqTfeeKPEx3KmqKgoeXl52W1zxjUp7udYkr+P7777rubMmaM6depo2bJlGj16tBo3bqxOnTrp0KFDtnZLly7VzTffrIoVK+ro0aO2V3h4uPLz87Vhw4YSjRO4mhBYcEF+fn6SpH///bdE7ffv3y83Nze7Oz0kKTg4WAEBAdq/f7/d9po1axbpo2LFivrnn38ucsRF9e3bV+3bt9fgwYMVFBSkfv36acmSJecNL4XjbNiwYZF9jRs31tGjR5WdnW23/b/nUrFiRUkq1bnceeed8vX11eLFi/XBBx+obdu2Ra5loYKCAr366qu69tprZbVaVblyZVWpUkW//fabjh8/XuJjXnPNNaVaYDtjxgwFBgYqKSlJs2fPLrI+ozjp6elKTU21vbKyskp8vHOpU6dOkW2OXhNPT0/b2pFCJf376ObmpujoaCUmJuro0aP67LPP1KVLF61du1b9+vWztdu9e7dWrVqlKlWq2L3Cw8MlSUeOHLngsYCrDYEFF+Tn56eQkBBt27atVJ/776LXc3F3dy92u2EYF32M/y6Q9PLy0oYNG/TNN9/owQcf1G+//aa+ffvq9ttvv+BiytJw5FwKWa1W9erVSwsXLtSyZcvOWV2RpKlTpyomJka33HKL/u///k+rV6/WmjVrdN1115W4kiSpSJXiQn755Rfb/6hu3bq1RJ9p27atqlWrZntdzPNk/qu4cTt6Tc71MyytSpUq6a677tKXX36pDh066Pvvv7eF4IKCAt1+++1as2ZNsa/evXs7ZQzAlYRFtyiRrl27asGCBUpISFBYWNh529aqVUsFBQXavXu3GjdubNuelpamjIwM2+JEZ6hYsaLdHTWF/lvFkc789tupUyd16tRJr7zyiqZOnaqnn35a3377re032/+ehyQlJycX2bdr1y5VrlxZ3t7
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on training set\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Training Data\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 18,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 12:55:53 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAHHCAYAAACcHAM1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQcklEQVR4nO3de1yO9/8H8Nfd6S7RCR3uITmlJqeQ5DhNjpPjwiaEsUJOoxlymCabOYuxZRtmbJph0TTayClyaDSHMOMuSaXofP3+8Ov6uhXKdV914/X8Pq7H43t/rvf1uT7XbW1v78/nc6UQBEEAERERkQ7Tq+wBEBERET0PExYiIiLSeUxYiIiISOcxYSEiIiKdx4SFiIiIdB4TFiIiItJ5TFiIiIhI5zFhISIiIp3HhIWIiIh0HhMWemldunQJ3bp1g7m5ORQKBSIiIrTa/7Vr16BQKBAeHq7Vfl9mnTt3RufOnSt7GET0GmLCQpJcuXIFH3zwAerVqwdjY2OYmZnBw8MDy5cvx8OHD2W9t6+vL86dO4dPP/0U3333HVq1aiXr/SrSiBEjoFAoYGZmVur3eOnSJSgUCigUCnz++efl7v/WrVsIDg5GfHy8FkYrr+DgYPFZn3VoK5Hau3cvgoODyxxfVFSEb7/9Fm5ubrCyskK1atXQqFEjDB8+HEePHi33/R88eIDg4GAcPHiw3NcSvcoMKnsA9PLas2cPBg0aBKVSieHDh6NJkybIy8vDX3/9henTpyMhIQHr16+X5d4PHz5EbGwsZs2ahYCAAFnuYW9vj4cPH8LQ0FCW/p/HwMAADx48wK+//orBgwdrnNu8eTOMjY2Rk5PzQn3funUL8+bNQ926ddG8efMyX7d///4Xup8U/fv3R4MGDcTPWVlZGD9+PPr164f+/fuL7TY2Nlq53969e7F69eoyJy0TJ07E6tWr0bdvXwwbNgwGBgZITEzEb7/9hnr16qFt27bluv+DBw8wb948AGA1i+gxTFjohSQlJcHHxwf29vaIjo6GnZ2deM7f3x+XL1/Gnj17ZLv/nTt3AAAWFhay3UOhUMDY2Fi2/p9HqVTCw8MDW7duLZGwbNmyBb169cJPP/1UIWN58OABqlSpAiMjowq53+OaNm2Kpk2bip9TU1Mxfvx4NG3aFO+9916Fj+dxycnJWLNmDcaMGVMiOV+2bJn4zykRSccpIXohoaGhyMrKwsaNGzWSlWINGjTApEmTxM8FBQVYsGAB6tevD6VSibp16+Ljjz9Gbm6uxnV169ZF79698ddff6FNmzYwNjZGvXr18O2334oxwcHBsLe3BwBMnz4dCoUCdevWBfBoKqX4/z+ueFrhcVFRUWjfvj0sLCxQtWpVODo64uOPPxbPP20NS3R0NDp06ABTU1NYWFigb9++uHDhQqn3u3z5MkaMGAELCwuYm5tj5MiRePDgwdO/2CcMHToUv/32G9LT08W2EydO4NKlSxg6dGiJ+LS0NEybNg0uLi6oWrUqzMzM0KNHD5w5c0aMOXjwIFq3bg0AGDlypDilUvycnTt3RpMmTRAXF4eOHTuiSpUq4vfy5BoWX19fGBsbl3h+Ly8vWFpa4tatW2V+VqkuXryIgQMHwsrKCsbGxmjVqhV27dqlEZOfn4958+ahYcOGMDY2RvXq1dG+fXtERUUBePTPz+rVqwFAY7rpaZKSkiAIAjw8PEqcUygUsLa21mhLT09HYGAgateuDaVSiQYNGmDx4sUoKioC8OifuZo1awIA5s2bJ96/PFNURK8qVljohfz666+oV68e2rVrV6b40aNHY9OmTRg4cCCmTp2KY8eOISQkBBcuXMDOnTs1Yi9fvoyBAwfCz88Pvr6++PrrrzFixAi4urrizTffRP/+/WFhYYHJkydjyJAh6NmzJ6pWrVqu8SckJKB3795o2rQp5s+fD6VSicuXL+Pw4cPPvO73339Hjx49UK9ePQQHB+Phw4dYuXIlPDw8cOrUqRLJ0uDBg+Hg4ICQkBCcOnUKGzZsgLW1NRYvXlymcfbv3x/jxo3Dzz//jFGjRgF4VF1p3LgxWrZsWSL+6tWriIiIwKBBg+Dg4IDk5GSsW7cOnTp1wt9//w2VSgUnJyfMnz8fc+bMwdixY9GhQwcA0PizvHv3Lnr06AEfHx+89957T51uWb58OaKjo+Hr64vY2Fjo6+tj3bp12L9/P7777juoVKoyPadUCQkJ8PDwwBtvvIGZM2fC1NQUP/74I7y9vfHTTz+hX79+AB4lkiEhIRg9ejTatGmDzMxMnDx5EqdOncLbb7+NDz74ALdu3UJUVBS+++675963OHHevn07Bg0ahCpVqjw19sGDB+jUqRP+++8/fPDBB6hTpw6OHDmCoKAg3L59G8uWLUPNmjWxdu3aElNej1eYiF5bAlE5ZWRkCACEvn37lik+Pj5eACCMHj1ao33atGkCACE6Olpss7e3FwAIMTExYltKSoqgVCqFqVOnim1JSUkCAGHJkiUaffr6+gr29vYlxjB37lzh8X/cv/zySwGAcOfOnaeOu/ge33zzjdjWvHlzwdraWrh7967YdubMGUFPT08YPnx4ifuNGjVKo89+/foJ1atXf+o9H38OU1NTQRAEYeDAgULXrl0FQRCEwsJCwdbWVpg3b16p30FOTo5QWFhY4jmUSqUwf/58se3EiRMlnq1Yp06dBABCWFhYqec6deqk0bZv3z4BgLBw4ULh6tWrQtWqVQVvb+/nPuOLunPnjgBAmDt3rtjWtWtXwcXFRcjJyRHbioqKhHbt2gkNGzYU25o1ayb06tXrmf37+/sL5flX4/DhwwUAgqWlpdCvXz/h888/Fy5cuFAibsGCBYKpqanwzz//aLTPnDlT0NfXF27cuPHU5yMiQeCUEJVbZmYmAKBatWplit+7dy8AYMqUKRrtU6dOBYASa12cnZ3Fv/UDQM2aNeHo6IirV6++8JifVLz25ZdffhHL8c9z+/ZtxMfHY8SIEbCyshLbmzZtirffflt8zseNGzdO43OHDh1w9+5d8Tssi6FDh+LgwYNQq9WIjo6GWq0udToIeLTuRU/v0Y91YWEh7t69K053nTp1qsz3VCqVGDlyZJliu3Xrhg8++ADz589H//79YWxsjHXr1pX5XlKlpaUhOjoagwcPxv3795GamorU1FTcvXsXXl5euHTpEv777z8Aj/7cExIScOnSJa3d/5tvvsGqVavg4OCAnTt3Ytq0aXByckLXrl3F+wKPqjAdOnSApaWlOMbU1FR4enqisLAQMTExWhsT0auICQuVm5mZGQDg/v37ZYq/fv069PT0NHZ6AICtrS0sLCxw/fp1jfY6deqU6MPS0hL37t17wRGX9O6778LDwwOjR4+GjY0NfHx88OOPPz4zeSkep6OjY4lzTk5OSE1NRXZ2tkb7k89iaWkJAOV6lp49e6JatWrYtm0bNm/ejNatW5f4LosVFRXhyy+/RMOGDaFUKlGjRg3UrFkTZ8+eRUZGRpnv+cYbb5Rrge3nn38OKysrxMfHY8WKFSXWbpTmzp07UKvV4pGVlVXm+z3u8uXLEAQBs2fPRs2aNTWOuXPnAgBSUlIAAPPnz0d6ejoaNWoEFxcXTJ8+HWfPnn2h+xbT09ODv78/4uLikJqail9++QU9evRAdHQ0fHx8xLhLly4hMjKyxBg9PT01xkhEpeMaFio3MzMzqFQqnD9/vlzXPWvx4uP09fVLbRcE4YXvUVhYqPHZxMQEMTEx+OOPP7Bnzx5ERkZi27ZteOutt7B///6njqG8pDxLMaVSif79+2PTpk24evXqMxdgLlq0CLNnz8aoUaOwYMECWFlZQU9PD4GBgWWuJAGPvp/yOH36tPgf3HPnzmHIkCHPvaZ169YayercuXNfaHFp8XNNmzYNXl5epcYUJ3gdO3bElStX8Msvv2D//v3YsGEDvvzyS4SFhWH06NHlvveTqlevjnfeeQfvvPMOOnfujEOHDuH69euwt7dHUVER3n77bXz00UelXtuoUSPJ9yd6lTFhoRfSu3dvrF+/HrGxsXB3d39mbPG/rC9dugQnJyexPTk5Gen
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testig Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"ax = sns.heatmap(cm_test_xgb, annot=True, cmap='Blues', fmt='g', linewidth=1.5)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 19,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-20 09:26:51 +01:00
"The test dataset contains 16565 shots, with 1891 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 20,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 12:55:53 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAHHCAYAAADONqsSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1wUx//48dfROwKCgiJFEbGBXTG2qIAYRY2xxChY8GPBGnsD7F2iGDXRiDVqTCwxREXE3ntHscfeEVDq/P7gx34571CMJJbM8/HgATs7OzP73uPuZnd2ViWEEEiSJEmSJEmSJEkfDJ333QBJkiRJkiRJkiRJneyoSZIkSZIkSZIkfWBkR02SJEmSJEmSJOkDIztqkiRJkiRJkiRJHxjZUZMkSZIkSZIkSfrAyI6aJEmSJEmSJEnSB0Z21CRJkiRJkiRJkj4wsqMmSZIkSZIkSZL0gZEdNUmSJEmSJEmSpA+M7KhJkiRJ0kckKioKlUrFtWvX3ndTJEmSpH+Q7KhJkiRJH7Scjom2n2HDhv0jde7bt4+wsDCePn36j5T/X5aSkkJYWBg7dux4302RJEn6oOm97wZIkiRJUn6MHTsWFxcXtbTy5cv/I3Xt27eP8PBwgoKCKFSo0D9Sx9/VsWNH2rVrh6Gh4ftuyt+SkpJCeHg4APXr13+/jZEkSfqAyY6aJEmS9FFo0qQJVatWfd/NeCfJycmYmpq+Uxm6urro6uoWUIv+PVlZWaSlpb3vZkiSJH005NBHSZIk6ZPw559/UqdOHUxNTTE3N6dp06acPXtWLc+pU6cICgrC1dUVIyMjihYtSpcuXXj06JGSJywsjMGDBwPg4uKiDLO8du0a165dQ6VSERUVpVG/SqUiLCxMrRyVSsW5c+f4+uuvsbKy4rPPPlPWL1++nCpVqmBsbIy1tTXt2rXj5s2bb9xPbfeoOTs788UXX7Bjxw6qVq2KsbExFSpUUIYX/vbbb1SoUAEjIyOqVKnC8ePH1coMCgrCzMyMK1eu4Ovri6mpKQ4ODowdOxYhhFre5ORkvv32WxwdHTE0NMTd3Z3p06dr5FOpVISEhLBixQrKlSuHoaEh8+fPx9bWFoDw8HAltjlxy8/xyR3bhIQE5aqnpaUlnTt3JiUlRSNmy5cvp3r16piYmGBlZUXdunXZunWrWp78vH4kSZL+TfKKmiRJkvRRePbsGQ8fPlRLK1y4MADLli0jMDAQX19fpkyZQkpKCvPmzeOzzz7j+PHjODs7AxATE8OVK1fo3LkzRYsW5ezZs/zwww+cPXuWAwcOoFKpaNWqFRcvXuTnn39m1qxZSh22trY8ePDgrdv91Vdf4ebmxsSJE5XOzIQJExg9ejRt2rShW7duPHjwgDlz5lC3bl2OHz/+t4ZbJiQk8PXXX/O///2Pb775hunTp9OsWTPmz5/PiBEj6NWrFwCTJk2iTZs2xMfHo6Pzf+drMzMz8fPzo2bNmkydOpXNmzcTGhpKRkYGY8eOBUAIQfPmzYmLi6Nr1654eXmxZcsWBg8ezK1bt5g1a5Zam7Zv386aNWsICQmhcOHCeHp6Mm/ePHr27EnLli1p1aoVABUrVgTyd3xya9OmDS4uLkyaNIljx46xcOFC7OzsmDJlipInPDycsLAwvL29GTt2LAYGBhw8eJDt27fj4+MD5P/1I0mS9K8SkiRJkvQBW7x4sQC0/gghxPPnz0WhQoVEcHCw2nZ3794VlpaWaukpKSka5f/8888CELt27VLSpk2bJgBx9epVtbxXr14VgFi8eLFGOYAIDQ1VlkNDQwUg2rdvr5bv2rVrQldXV0yYMEEt/fTp00JPT08jPa945G6bk5OTAMS+ffuUtC1btghAGBsbi+vXryvpCxYsEICIi4tT0gIDAwUg+vTpo6RlZWWJpk2bCgMDA/HgwQMhhBDr168XgBg/frxam1q3bi1UKpVISEhQi4eOjo44e/asWt4HDx5oxCpHfo9PTmy7dOmilrdly5bCxsZGWb506ZLQ0dERLVu2FJmZmWp5s7KyhBBv9/qRJEn6N8mhj5IkSdJHYe7cucTExKj9QPZVmKdPn9K+fXsePnyo/Ojq6lKjRg3i4uKUMoyNjZW/X758ycOHD6lZsyYAx44d+0fa3aNHD7Xl3377jaysLNq0aaPW3qJFi+Lm5qbW3rdRtmxZatWqpSzXqFEDgM8//5wSJUpopF+5ckWjjJCQEOXvnKGLaWlpbNu2DYDo6Gh0dXXp27ev2nbffvstQgj+/PNPtfR69epRtmzZfO/D2x6fV2Nbp04dHj16RGJiIgDr168nKyuLMWPGqF09zNk/eLvXjyRJ0r9JDn2UJEmSPgrVq1fXOpnIpUuXgOwOiTYWFhbK348fPyY8PJxVq1Zx//59tXzPnj0rwNb+n1dnqrx06RJCCNzc3LTm19fX/1v15O6MAVhaWgLg6OioNf3Jkydq6To6Ori6uqqllS5dGkC5H+769es4ODhgbm6uls/Dw0NZn9ur+/4mb3t8Xt1nKysrIHvfLCwsuHz5Mjo6Oq/tLL7N60eSJOnfJDtqkiRJ0kctKysLyL7PqGjRohrr9fT+76OuTZs27Nu3j8GDB+Pl5YWZmRlZWVn4+fkp5bzOq/dI5cjMzMxzm9xXiXLaq1Kp+PPPP7XO3mhmZvbGdmiT10yQeaWLVyb/+Ce8uu9v8rbHpyD27W1eP5IkSf8m+e4jSZIkfdRKliwJgJ2dHY0aNcoz35MnT4iNjSU8PJwxY8Yo6TlXVHLLq0OWc8Xm1Qdhv3ol6U3tFULg4uKiXLH6EGRlZXHlyhW1Nl28eBFAmUzDycmJbdu28fz5c7WrahcuXFDWv0lesX2b45NfJUuWJCsri3PnzuHl5ZVnHnjz60eSJOnfJu9RkyRJkj5qvr6+WFhYMHHiRNLT0zXW58zUmHP15dWrLRERERrb5Dzr7NUOmYWFBYULF2bXrl1q6d9//32+29uqVSt0dXUJDw/XaIsQQmMq+n9TZGSkWlsiIyPR19enYcOGAPj7+5OZmamWD2DWrFmoVCqaNGnyxjpMTEwAzdi+zfHJrxYtWqCjo8PYsWM1rsjl1JPf148kSdK/TV5RkyRJkj5qFhYWzJs3j44dO1K5cmXatWuHra0tN27c4I8//qB27dpERkZiYWFB3bp1mTp1Kunp6RQrVoytW7dy9epVjTKrVKkCwMiRI2nXrh36+vo0a9YMU1NTunXrxuTJk+nWrRtVq1Zl165dypWn/ChZsiTjx49n+PDhXLt2jRYtWmBubs7Vq1dZt24d3bt3Z9CgQQUWn/wyMjJi8+bNBAYGUqNGDf7880/++OMPRowYoTz7rFmzZjRo0ICRI0dy7do1PD092bp1Kxs2bKB///7K1anXMTY2pmzZsqxevZrSpUtjbW1N+fLlKV++fL6PT36VKlWKkSNHMm7cOOrUqUOrVq0wNDTk8OHDODg4MGnSpHy/fiRJkv5tsqMmSZIkffS+/vprHBwcmDx5MtOmTSM1NZVixYpRp04dOnfurORbuXIlffr0Ye7cuQgh8PHx4c8//8TBwUGtvGrVqjFu3Djmz5/P5s2bycrK4urVq5iamjJmzBgePHjA2rVrWbNmDU2aNOHPP//Ezs4u3+0dNmwYpUuXZtasWYSHhwPZk374+PjQvHnzggnKW9LV1WXz5s307NmTwYMHY25uTmhoqNowRB0dHTZu3MiYMWNYvXo1ixcvxtnZmWnTpvHtt9/mu66FCxfSp08fBgwYQFpaGqGhoZQvXz7fx+dtjB07FhcXF+bMmcPIkSMxMTGhYsWKdOzYUcmT39ePJEnSv0kl/o27iSVJkiRJ+mAFBQWxdu1akpKS3ndTJEmSpP9P3qMmSZIkSZIkSZL0gZEdNUmSJEmSJEmSpA+M7KhJkiRJkiRJkiR9YOQ9apIkSZIkSZIkSR8YeUVNkiRJkiRJkiTpAyM7apIkSZIkSZIkSR8Y+Rw1SfqPy8rK4vbt25ibm6NSqd53cyRJkiRJygchBM+fP8fBwQEdHXnt5VMkO2qS9B93+/ZtHB0d33czJEmSJEn6G27evEnx4sX
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 21,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
2024-01-20 12:55:53 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAv8AAAHHCAYAAAAoOjd/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1QVx//w8felSheQalBQEVARsQsWVARsX3tXxIKxYI01KsWGXaImmiZYY4o1EQsasYuKXRQrwagEC6JApO7zBw/740oRWwSd1zkc2dnd2dmdK/ezs7MzCkmSJARBEARBEARB+OipfOgCCIIgCIIgCILw3xDBvyAIgiAIgiB8IkTwLwiCIAiCIAifCBH8C4IgCIIgCMInQgT/giAIgiAIgvCJEMG/IAiCIAiCIHwiRPAvCIIgCIIgCJ8IEfwLgiAIgiAIwidCBP+CIAiCIAiC8IkQwb8gCIIglCFhYWEoFAri4uI+dFEEQSiDRPAvCIIglGp5wW5hP1OnTn0vxzx+/DiBgYE8ffr0veT/KUtLSyMwMJDIyMgPXRRB+CSpfegCCIIgCEJJzJo1CxsbG6W0WrVqvZdjHT9+nKCgIHx8fChfvvx7OcabGjBgAL1790ZTU/NDF+WNpKWlERQUBICbm9uHLYwgfIJE8C8IgiCUCW3btqV+/fofuhhvJTU1FR0dnbfKQ1VVFVVV1XdUov9OTk4OGRkZH7oYgvDJE91+BEEQhI/C7t27adasGTo6Oujp6dG+fXuuXLmitM3Fixfx8fGhSpUqlCtXDnNzcwYPHszjx4/lbQIDA5k0aRIANjY2chejuLg44uLiUCgUhIWFFTi+QqEgMDBQKR+FQkFMTAx9+/bF0NCQpk2byus3bNhAvXr10NLSwsjIiN69e3P37t1Xnmdhff6tra3p0KEDkZGR1K9fHy0tLRwdHeWuNVu3bsXR0ZFy5cpRr149zp07p5Snj48Purq63L59G09PT3R0dLC0tGTWrFlIkqS0bWpqKl988QVWVlZoampiZ2fH4sWLC2ynUCjw8/Nj48aN1KxZE01NTVavXo2JiQkAQUFB8rXNu24lqZ/81/bmzZvy0xkDAwMGDRpEWlpagWu2YcMGGjZsiLa2NoaGhjRv3px9+/YpbVOSz48gfAxEy78gCIJQJiQnJ/Po0SOltAoVKgCwfv16Bg4ciKenJwsWLCAtLY1Vq1bRtGlTzp07h7W1NQARERHcvn2bQYMGYW5uzpUrV/juu++4cuUKJ0+eRKFQ0LVrV65fv85PP/3EsmXL5GOYmJjw8OHD1y53jx49sLW1Zd68eXKAPHfuXGbOnEnPnj0ZOnQoDx8+ZMWKFTRv3pxz5869UVejmzdv0rdvXz7//HP69+/P4sWL6dixI6tXr+bLL79k5MiRAAQHB9OzZ09iY2NRUfm/NsDs7Gy8vLxo3LgxCxcuZM+ePQQEBJCVlcWsWbMAkCSJ//3vfxw8eJAhQ4ZQp04d9u7dy6RJk7h37x7Lli1TKtOff/7JL7/8gp+fHxUqVMDJyYlVq1YxYsQIunTpQteuXQGoXbs2ULL6ya9nz57Y2NgQHBzM2bNn+eGHHzA1NWXBggXyNkFBQQQGBuLi4sKsWbPQ0NAgKiqKP//8Ew8PD6Dknx9B+ChIgiAIglCKhYaGSkChP5IkSc+fP5fKly8v+fr6Ku2XkJAgGRgYKKWnpaUVyP+nn36SAOnw4cNy2qJFiyRAunPnjtK2d+7ckQApNDS0QD6AFBAQIC8HBARIgNSnTx+l7eLi4iRVVVVp7ty5SumXLl2S1NTUCqQXdT3yl61y5coSIB0/flxO27t3rwRIWlpa0l9//SWnf/vttxIgHTx4UE4bOHCgBEijR4+W03JycqT27dtLGhoa0sOHDyVJkqTt27dLgDRnzhylMnXv3l1SKBTSzZs3la6HioqKdOXKFaVtHz58WOBa5Slp/eRd28GDBytt26VLF8nY2FhevnHjhqSioiJ16dJFys7OVto2JydHkqTX+/wIwsdAdPsRBEEQyoSvv/6aiIgIpR/IbS1++vQpffr04dGjR/KPqqoqjRo14uDBg3IeWlpa8u8vXrzg0aNHNG7cGICzZ8++l3IPHz5caXnr1q3k5OTQs2dPpfKam5tja2urVN7XUaNGDZo0aSIvN2rUCIBWrVpRqVKlAum3b98ukIefn5/8e163nYyMDPbv3w9AeHg4qqqqjBkzRmm/L774AkmS2L17t1J6ixYtqFGjRonP4XXr5+Vr26xZMx4/fsyzZ88A2L59Ozk5Ofj7+ys95cg7P3i9z48gfAxEtx9BEAShTGjYsGGhL/zeuHEDyA1yC6Ovry///uTJE4KCgti8eTOJiYlK2yUnJ7/D0v6fl0counHjBpIkYWtrW+j26urqb3Sc/AE+gIGBAQBWVlaFpiclJSmlq6ioUKVKFaW06tWrA8jvF/z1119YWlqip6entJ2Dg4O8Pr+Xz/1VXrd+Xj5nQ0NDIPfc9PX1uXXrFioqKsXegLzO50cQPgYi+BcEQRDKtJycHCC337a5uXmB9Wpq//dV17NnT44fP86kSZOoU6cOurq65OTk4OXlJedTnJf7nOfJzs4ucp/8rdl55VUoFOzevbvQUXt0dXVfWY7CFDUCUFHp0ksv6L4PL5/7q7xu/byLc3udz48gfAzEJ1oQBEEo06pWrQqAqakp7u7uRW6XlJTEgQMHCAoKwt/fX07Pa/nNr6ggP69l+eXJv15u8X5VeSVJwsbGRm5ZLw1ycnK4ffu2UpmuX78OIL/wWrlyZfbv38/z58+VWv+vXbsmr3+Voq7t69RPSVWtWpWcnBxiYmKoU6dOkdvAqz8/gvCxEH3+BUEQhDLN09MTfX195s2bR2ZmZoH1eSP05LUSv9wqHBISUmCfvLH4Xw7y9fX1qVChAocPH1ZK/+abb0pc3q5du6KqqkpQUFCBskiSVGBYy//SypUrlcqycuVK1NXVad26NQDt2rUjOztbaTuAZcuWoVAoaNu27SuPoa2tDRS8tq9TPyXVuXNnVFRUmDVrVoEnB3nHKennRxA+FqLlXxAEQSjT9PX1WbVqFQMGDKBu3br07t0bExMT4uPj2bVrF66urqxcuRJ9fX2aN2/OwoULyczMpGLFiuzbt487d+4UyLNevXoATJ8+nd69e6Ourk7Hjh3R0dFh6NChzJ8/n6FDh1K/fn0OHz4st5CXRNWqVZkzZw7Tpk0jLi6Ozp07o6enx507d9i2bRvDhg1j4sSJ7+z6lFS5cuXYs2cPAwcOpFGjRuzevZtdu3bx5ZdfymPzd+zYkZYtWzJ9+nTi4uJwcnJi37597Nixg3Hjxsmt6MXR0tKiRo0a/Pzzz1SvXh0jIyNq1apFrVq1Slw/JVWtWjWmT5/O7NmzadasGV27dkVTU5PTp09jaWlJcHBwiT8/gvDR+ECjDAmCIAhCieQNbXn69Olitzt48KDk6ekpGRgYSOXKlZOqVq0q+fj4SGfOnJG3+fvvv6UuXbpI5cuXlwwMDKQePXpI9+/fL3ToydmzZ0sVK1aUVFRUlIbWTEtLk4YMGSIZGBhIenp6Us+ePaXExMQih/rMGybzZVu2bJGaNm0q6ejoSDo6OpK9vb00atQoKTY2tkTX4+WhPtu3b19gW0AaNWqUUlrecKWLFi2S0wYOHCjp6OhIt27dkjw8PCRtbW3JzMxMCggIKDBE5vPnz6Xx48dLlpaWkrq6umRraystWrRIHjqzuGPnOX78uFSvXj1JQ0ND6bqVtH6KuraFXRtJkqQ1a9ZIzs7OkqampmRoaCi1aNFCioiIUNqmJJ8fQfgYKCTpP3jjRxAEQRCEUsvHx4fffvuNlJSUD10UQRDeM9HnXxAEQRAEQRA+ESL4FwRBEARBEIRPhAj+BUEQBEEQBOETIfr8C4IgCIIgCMInQrT8C4IgCIIgCMInQgT/giAIgiAIgvCJEJN8CcInLicnh/v376Onp4dCofjQxREEQRAEoQQkSeL58+dYWlqiolLy9nwR/AvCJ+7+/ftYWVl96GIIgiAIgvAG7t69y2effVbi7UXwLwifOD0
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 22,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Calculating MAE, RMSE and R2 for training and test sets \n",
"mae_train = mean_absolute_error(y_train, y_pred_train)\n",
"rmse_train = mean_squared_error(y_train, y_pred_train, squared=False)\n",
"r2_train = r2_score(y_train, y_pred_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_pred_test)\n",
"rmse_test = mean_squared_error(y_test, y_pred_test, squared=False)\n",
"r2_test = r2_score(y_test, y_pred_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-01-20 12:55:53 +01:00
"#T_295e5_row0_col0, #T_295e5_row0_col1, #T_295e5_row0_col2, #T_295e5_row0_col3, #T_295e5_row0_col4, #T_295e5_row0_col5, #T_295e5_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
2024-01-20 12:55:53 +01:00
"<table id=\"T_295e5\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-01-20 12:55:53 +01:00
" <th id=\"T_295e5_level0_col0\" class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th id=\"T_295e5_level0_col1\" class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th id=\"T_295e5_level0_col2\" class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th id=\"T_295e5_level0_col3\" class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th id=\"T_295e5_level0_col4\" class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th id=\"T_295e5_level0_col5\" class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th id=\"T_295e5_level0_col6\" class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-01-20 12:55:53 +01:00
" <th id=\"T_295e5_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_295e5_row0_col0\" class=\"data row0 col0\" >0.09622</td>\n",
" <td id=\"T_295e5_row0_col1\" class=\"data row0 col1\" >0.31019</td>\n",
" <td id=\"T_295e5_row0_col2\" class=\"data row0 col2\" >0.01693</td>\n",
" <td id=\"T_295e5_row0_col3\" class=\"data row0 col3\" >0.09949</td>\n",
" <td id=\"T_295e5_row0_col4\" class=\"data row0 col4\" >0.31542</td>\n",
" <td id=\"T_295e5_row0_col5\" class=\"data row0 col5\" >0.01620</td>\n",
" <td id=\"T_295e5_row0_col6\" class=\"data row0 col6\" >25.84271</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-01-20 12:55:53 +01:00
"<pandas.io.formats.style.Styler at 0x156feb3d0>"
2023-12-12 15:22:01 +01:00
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 23,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Creating of dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
2024-01-20 09:26:51 +01:00
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'objective': 'binary:logistic',\n",
" 'base_score': None,\n",
" 'booster': None,\n",
" 'callbacks': None,\n",
" 'colsample_bylevel': None,\n",
" 'colsample_bynode': None,\n",
" 'colsample_bytree': None,\n",
" 'device': None,\n",
" 'early_stopping_rounds': None,\n",
" 'enable_categorical': True,\n",
" 'eval_metric': None,\n",
" 'feature_types': None,\n",
" 'gamma': None,\n",
" 'grow_policy': None,\n",
" 'importance_type': None,\n",
" 'interaction_constraints': None,\n",
" 'learning_rate': 0.01,\n",
" 'max_bin': None,\n",
" 'max_cat_threshold': None,\n",
" 'max_cat_to_onehot': None,\n",
" 'max_delta_step': None,\n",
2024-01-20 12:55:53 +01:00
" 'max_depth': 5,\n",
" 'max_leaves': None,\n",
" 'min_child_weight': None,\n",
" 'missing': nan,\n",
" 'monotone_constraints': None,\n",
" 'multi_strategy': None,\n",
2024-01-20 12:55:53 +01:00
" 'n_estimators': 250,\n",
" 'n_jobs': None,\n",
" 'num_parallel_tree': None,\n",
" 'random_state': None,\n",
" 'reg_alpha': None,\n",
" 'reg_lambda': None,\n",
" 'sampling_method': None,\n",
" 'scale_pos_weight': 1,\n",
" 'subsample': None,\n",
" 'tree_method': 'hist',\n",
" 'validate_parameters': None,\n",
" 'verbosity': None}"
]
},
2024-01-20 09:26:51 +01:00
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the model\n",
"best_xgb_model.save_model('xgboost.json')\n",
"\n",
"best_xgb_model.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 69,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"# dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# # Load the model\n",
"# model = load('xgboost.joblib')"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 12:55:53 +01:00
"array([[0.9005616 , 0.09943844],\n",
" [0.938443 , 0.06155698],\n",
" [0.96434194, 0.03565805],\n",
" ...,\n",
2024-01-20 12:55:53 +01:00
" [0.8334278 , 0.1665722 ],\n",
" [0.8697108 , 0.13028917],\n",
" [0.96131223, 0.03868778]], dtype=float32)"
]
},
2024-01-20 12:55:53 +01:00
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_xgb_model.predict_proba(X)"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"new_xgb_model = xgboost.XGBClassifier()\n",
"new_xgb_model.load_model('xgboost.json')"
]
},
{
"cell_type": "code",
2024-01-20 12:55:53 +01:00
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-01-20 12:55:53 +01:00
"array([[0.9005616 , 0.09943844],\n",
" [0.938443 , 0.06155698],\n",
" [0.96434194, 0.03565805],\n",
" ...,\n",
2024-01-20 12:55:53 +01:00
" [0.8334278 , 0.1665722 ],\n",
" [0.8697108 , 0.13028917],\n",
" [0.96131223, 0.03868778]], dtype=float32)"
]
},
2024-01-20 12:55:53 +01:00
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_xgb_model.predict_proba(X)"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-01-20 09:26:51 +01:00
"version": "3.11.6"
2023-12-12 15:22:01 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}