fantastyczne_gole/notebooks/xgboost_dla_xG.ipynb

1465 lines
364 KiB
Plaintext
Raw Normal View History

2023-12-12 15:22:01 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 39,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n",
"from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n",
"import xgboost\n",
"import time\n",
2024-01-12 20:39:14 +01:00
"from joblib import dump, load\n",
"import os"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 40,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
2024-01-12 20:39:14 +01:00
"df = pd.read_csv('final_data.csv')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 41,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n",
" 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n",
2024-01-12 20:39:14 +01:00
" 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n",
" 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n",
" 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n",
" 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n",
" 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n",
" 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n",
" 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n",
" 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n",
" 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n",
" 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n",
" 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n",
" 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n",
" 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n",
" 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n",
" 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n",
" 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n",
" 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n",
" 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n",
" 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n",
" 'y_player_teammate_Goalkeeper'],\n",
" dtype='object')"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
2023-12-12 15:22:01 +01:00
"source": [
"df.columns"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 42,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
2024-01-12 20:39:14 +01:00
" <th>shot_follows_dribble</th>\n",
2023-12-12 15:22:01 +01:00
" <th>...</th>\n",
2024-01-12 20:39:14 +01:00
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>Right Center Forward</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Half Volley</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
2024-01-12 20:39:14 +01:00
" <td>48.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>Left Center Midfield</td>\n",
" <td>Right Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Right Center Back</td>\n",
" <td>Left Foot</td>\n",
" <td>Normal</td>\n",
2023-12-12 15:22:01 +01:00
" <td>Open Play</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
2023-12-12 15:22:01 +01:00
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-01-12 20:39:14 +01:00
" <td>NaN</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-01-12 20:39:14 +01:00
"<p>5 rows × 64 columns</p>\n",
2023-12-12 15:22:01 +01:00
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 Right Center Forward Right Foot Normal \n",
"1 5 Right Center Forward Left Foot Normal \n",
"2 5 Center Midfield Right Foot Half Volley \n",
"3 5 Left Center Midfield Right Foot Normal \n",
"4 5 Right Center Back Left Foot Normal \n",
2023-12-12 15:22:01 +01:00
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 Open Play False False False \n",
"1 Open Play False False False \n",
"2 Open Play True False False \n",
"3 Open Play False False False \n",
"4 Open Play True False False \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... 48.9 \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
2023-12-12 15:22:01 +01:00
"\n",
2024-01-12 20:39:14 +01:00
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"\n",
"[5 rows x 64 columns]"
2023-12-12 15:22:01 +01:00
]
},
"execution_count": 42,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data preparation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 43,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
2023-12-12 15:22:01 +01:00
"source": [
"# df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']] = df[['minute', \n",
"# 'number_of_players_opponents', \n",
"# 'number_of_players_teammates']].astype(float)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Text-valued shot descriptors that are replaced by ordinal codes.\n",
"# The fitted encoder is saved with joblib in a later cell, so the column order is kept fixed.\n",
"label_columns = ['position_name',\n",
"                 'shot_technique_name',\n",
"                 'shot_type_name',\n",
"                 'shot_body_part_name']\n",
"\n",
"# Fit the encoder on these columns and overwrite them with their codes in one step\n",
"enc = OrdinalEncoder()\n",
"df[label_columns] = enc.fit_transform(df[label_columns])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# Store the match minute and the four encoded descriptors as plain integers\n",
"integer_columns = ['minute',\n",
"                   'position_name',\n",
"                   'shot_technique_name',\n",
"                   'shot_type_name',\n",
"                   'shot_body_part_name']\n",
"df[integer_columns] = df[integer_columns].astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>minute</th>\n",
" <th>position_name</th>\n",
" <th>shot_body_part_name</th>\n",
" <th>shot_technique_name</th>\n",
" <th>shot_type_name</th>\n",
" <th>shot_first_time</th>\n",
" <th>shot_one_on_one</th>\n",
" <th>shot_aerial_won</th>\n",
" <th>shot_open_goal</th>\n",
" <th>shot_follows_dribble</th>\n",
" <th>...</th>\n",
" <th>y_player_teammate_5</th>\n",
" <th>y_player_teammate_6</th>\n",
" <th>y_player_teammate_7</th>\n",
" <th>y_player_teammate_8</th>\n",
" <th>y_player_teammate_9</th>\n",
" <th>y_player_teammate_10</th>\n",
" <th>x_player_opponent_7</th>\n",
" <th>y_player_opponent_7</th>\n",
" <th>x_player_teammate_Goalkeeper</th>\n",
" <th>y_player_teammate_Goalkeeper</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>18</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>10</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>17</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38337</th>\n",
" <td>61</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.3</td>\n",
" <td>42.6</td>\n",
" <td>54.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.3</td>\n",
" <td>50.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38338</th>\n",
" <td>66</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>47.6</td>\n",
" <td>39.4</td>\n",
" <td>43.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>19.0</td>\n",
" <td>45.8</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38339</th>\n",
" <td>73</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>48.9</td>\n",
" <td>48.1</td>\n",
" <td>41.1</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.7</td>\n",
" <td>29.6</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38340</th>\n",
" <td>75</td>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>29.1</td>\n",
" <td>33.6</td>\n",
" <td>40.9</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>21.2</td>\n",
" <td>32.4</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38341</th>\n",
" <td>90</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>62.6</td>\n",
" <td>51.0</td>\n",
" <td>66.7</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>23.0</td>\n",
" <td>45.5</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>38342 rows × 64 columns</p>\n",
"</div>"
],
"text/plain": [
" minute position_name shot_body_part_name shot_technique_name \\\n",
"0 0 18 3 4 \n",
"1 5 18 1 4 \n",
"2 5 4 3 2 \n",
"3 5 10 3 4 \n",
"4 5 17 1 4 \n",
"... ... ... ... ... \n",
"38337 61 3 3 4 \n",
"38338 66 0 3 4 \n",
"38339 73 3 1 6 \n",
"38340 75 13 1 4 \n",
"38341 90 3 3 4 \n",
"\n",
" shot_type_name shot_first_time shot_one_on_one shot_aerial_won \\\n",
"0 3 False False False \n",
"1 3 False False False \n",
"2 3 True False False \n",
"3 3 False False False \n",
"4 3 True False False \n",
"... ... ... ... ... \n",
"38337 3 True False False \n",
"38338 3 True False False \n",
"38339 3 True False False \n",
"38340 3 False False False \n",
"38341 3 False False False \n",
"\n",
" shot_open_goal shot_follows_dribble ... y_player_teammate_5 \\\n",
"0 False False ... NaN \n",
"1 False False ... NaN \n",
"2 False False ... 48.9 \n",
"3 False False ... NaN \n",
"4 False False ... NaN \n",
"... ... ... ... ... \n",
"38337 False False ... 47.3 \n",
"38338 False False ... 47.6 \n",
"38339 False False ... 48.9 \n",
"38340 False False ... 29.1 \n",
"38341 False False ... 62.6 \n",
"\n",
" y_player_teammate_6 y_player_teammate_7 y_player_teammate_8 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 42.6 54.7 NaN \n",
"38338 39.4 43.1 NaN \n",
"38339 48.1 41.1 NaN \n",
"38340 33.6 40.9 NaN \n",
"38341 51.0 66.7 NaN \n",
"\n",
" y_player_teammate_9 y_player_teammate_10 x_player_opponent_7 \\\n",
"0 NaN NaN NaN \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"... ... ... ... \n",
"38337 NaN NaN 21.3 \n",
"38338 NaN NaN 19.0 \n",
"38339 NaN NaN 21.7 \n",
"38340 NaN NaN 21.2 \n",
"38341 NaN NaN 23.0 \n",
"\n",
" y_player_opponent_7 x_player_teammate_Goalkeeper \\\n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN \n",
"... ... ... \n",
"38337 50.9 NaN \n",
"38338 45.8 NaN \n",
"38339 29.6 NaN \n",
"38340 32.4 NaN \n",
"38341 45.5 NaN \n",
"\n",
" y_player_teammate_Goalkeeper \n",
"0 NaN \n",
"1 NaN \n",
"2 NaN \n",
"3 NaN \n",
"4 NaN \n",
"... ... \n",
"38337 NaN \n",
"38338 NaN \n",
"38339 NaN \n",
"38340 NaN \n",
"38341 NaN \n",
"\n",
"[38342 rows x 64 columns]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['labelEncoder.joblib']"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dump(enc,'labelEncoder.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"enc2 = load('labelEncoder.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']] = enc2.transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name', \n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# enc.inverse_transform(df[['position_name', \n",
"# 'shot_technique_name', \n",
"# 'shot_type_name',\n",
"# 'shot_body_part_name']])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# ############### NEW ################\n",
"# from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# le_posiotion_name = LabelEncoder()\n",
"# le_shot_technique_name = LabelEncoder()\n",
"# le_shot_type_name = LabelEncoder()\n",
"# le_shot_body_part_name = LabelEncoder()\n",
"\n",
"# df['position_name'] = le_posiotion_name.fit_transform(df['position_name'])\n",
"# df['shot_technique_name'] = le_shot_technique_name.fit_transform(df['shot_technique_name'])\n",
"# df['shot_type_name'] = le_shot_type_name.fit_transform(df['shot_type_name'])\n",
"# df['shot_body_part_name'] = le_shot_body_part_name.fit_transform(df['shot_body_part_name'])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Mark the discrete features as pandas 'category' columns; the XGBoost model\n",
"# defined later is created with enable_categorical=True.\n",
"# NOTE(review): 'minute' and the two player counts are numeric quantities - confirm\n",
"# that treating them as unordered categories is intended.\n",
"categorical_columns = ['minute',\n",
"                       'position_name',\n",
"                       'shot_technique_name',\n",
"                       'shot_type_name',\n",
"                       'number_of_players_opponents',\n",
"                       'number_of_players_teammates',\n",
"                       'shot_body_part_name']\n",
"df[categorical_columns] = df[categorical_columns].astype('category')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 54,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Separate the target (is_goal) from the predictors; y stays a one-column DataFrame\n",
"y = pd.DataFrame(df['is_goal'])\n",
"X = df.drop(['is_goal'], axis=1)\n",
"\n",
"# Hold out 20% of the shots as a test set. Goals are the minority class, so the split\n",
"# is stratified on the target to keep the goal rate the same in both parts - consistent\n",
"# with the stratified folds used for cross-validation below.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=1, stratify=y['is_goal'])\n",
"\n",
"# 5-fold stratified cross-validation, shuffled with a fixed seed for reproducibility\n",
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 55,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2023-12-12 15:22:01 +01:00
"output_type": "stream",
"text": [
"Shots attempted in the training set: 27085\n",
"Goals scored in the training set: 3588\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Count goals and non-goals explicitly instead of unpacking value_counts(), whose\n",
"# order depends on which class happens to be more frequent.\n",
"is_goal_train = y_train['is_goal'].astype(bool)\n",
"count_class_1 = int(is_goal_train.sum())     # shots that ended in a goal\n",
"count_class_0 = int((~is_goal_train).sum())  # shots that did not\n",
"\n",
"# Display the count of shots without a goal in the training set\n",
"# (this is not the total number of shots: total = non-goals + goals)\n",
"print('Non-goal shots in the training set:', count_class_0)\n",
"\n",
"# Display the count of successful goals in the training set\n",
"print('Goals scored in the training set:', count_class_1)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 56,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Class imbalance in training data: 7.549\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Ratio of non-goals to goals; offered to the grid search as a scale_pos_weight candidate\n",
"scale_pos_weight = count_class_0 / count_class_1\n",
"imbalance_report = f' Class imbalance in training data: {scale_pos_weight:.3f}'\n",
"print(imbalance_report)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training XGBoost model "
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 57,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Base XGBoost classifier; the remaining hyper-parameters are supplied by the grid search\n",
"xgb_settings = {'enable_categorical': True,  # accept the pandas 'category' columns\n",
"                'tree_method': 'hist',\n",
"                'objective': 'binary:logistic'}\n",
"xgb_model = xgboost.XGBClassifier(**xgb_settings)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 58,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Hyper-parameter grid for XGBoost. Only scale_pos_weight has more than one candidate:\n",
"# no re-weighting (1) versus the class ratio measured on the training set.\n",
"param_grid_xgb = dict(learning_rate=[0.01],\n",
"                      max_depth=[3],\n",
"                      n_estimators=[300],\n",
2023-12-12 15:22:01 +01:00
"                      scale_pos_weight=[1, scale_pos_weight])"
]
},
{
"cell_type": "code",
"execution_count": 59,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Starting the timer\n",
"start_time = time.time()\n",
"\n",
"# Grid search with cross-validation. An xG model is judged on its predicted goal\n",
"# probabilities, so candidates are ranked by the Brier score (mean squared error of the\n",
"# predicted probabilities). 'neg_mean_squared_error' would score the hard 0/1 output of\n",
"# predict(), i.e. the plain error rate, which is misleading on imbalanced data.\n",
"grid_xg = GridSearchCV(xgb_model, param_grid=param_grid_xgb, cv=cv, scoring='neg_brier_score', n_jobs=-1)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Run the search; with the default refit=True the best candidate is refitted on the whole training set\n",
"grid_xg.fit(X_train, y_train)\n",
"\n",
"# Keep the refitted estimator that uses the best parameters\n",
"best_xgb_model = grid_xg.best_estimator_\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Stopping the timer\n",
"stop_time = time.time()\n",
"\n",
"# Training time in seconds (search plus final refit)\n",
"xgb_training_time = stop_time - start_time"
]
},
{
"cell_type": "code",
"execution_count": 60,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-01-12 20:39:14 +01:00
"Best parameters: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 300, 'scale_pos_weight': 1}\n",
"Model Training Time: 16.402 seconds\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Report the winning parameter combination and how long the search took\n",
"best_params = grid_xg.best_params_\n",
"print(\"Best parameters: \", best_params)\n",
"training_time_report = f\"Model Training Time: {xgb_training_time:.3f} seconds\"\n",
"print(training_time_report)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model evaluation"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 61,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAi0AAAHFCAYAAAA+FskAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABLNElEQVR4nO3de3yP9f/H8efHTraxsc02cyZkTdIU01eOOc7oqFZjxSiihRJyqL4ZOkhEIubYOoiU2pcS5esscuygnJbNHGZs2NZcvz/8fL592rDpusyHx73bdbv5vK/39b7e1zX08nq9r+tjMwzDEAAAwDWuVElPAAAAoCgIWgAAgFMgaAEAAE6BoAUAADgFghYAAOAUCFoAAIBTIGgBAABOgaAFAAA4BYIWAADgFAhaUCzbtm3T448/rho1aqh06dIqU6aMbr/9do0fP17Hjx+39NxbtmxR8+bN5evrK5vNprfeesv0c9hsNo0ePdr0cS8nMTFRNptNNptNK1euLLDfMAzddNNNstlsatGixRWdY8qUKUpMTCzWMStXrrzonKwUGxtrvx+X2mJjY//Refbt2yebzVbs+3IphmEoKSlJzZo1U2BgoEqXLq3KlSurXbt2mjFjxhWNeSU/O+B6ZOM1/iiq6dOnq2/fvqpbt6769u2r0NBQ5eXladOmTZo+fboaNGigRYsWWXb+hg0bKjs7WxMnTlT58uVVvXp1BQcHm3qOdevWqXLlyqpcubKp415OYmKiHn/8cZUtW1ZdunTR3LlzHfavXLlSLVu2VNmyZXX77bdfURARFhamgICAYh178uRJ7dq1S6GhofLx8Sn2Oa/Ub7/9piNHjtg///DDD+rXr5/GjBmjli1b2tsrVKigWrVqXfF5cnJytGXLFtWqVUsVKlT4R3O+4IUXXtC4ceMUFxenyMhIlS1bVvv379eKFSuUkZGhzz//vNhjXsnPDrguGUARrFmzxnBxcTHat29vnD17tsD+nJwc47PPPrN0Dq6ursZTTz1l6TlKyqxZswxJRq9evQxPT08jMzPTYf9jjz1mREREGLfccovRvHnzKzpHcY7Nzc018vLyrug8Vvj2228NScbHH398yX6nT582zp07d5VmVfj5PTw8jO7duxe6Pz8//4rG/Sc/d+B6QnkIRTJmzBjZbDa999578vDwKLDf3d1dUVFR9s/nzp3T+PHjdfPNN8vDw0OBgYHq3r27UlJSHI5r0aKFwsLCtHHjRjVr1kxeXl6qWbOmxo4dq3Pnzkn6X+nkzz//1NSpU+2lAUkaPXq0/dd/deGYffv22dtWrFihFi1ayN/fX56enqpataruv/9+nT592t6nsPLQjh071KVLF5UvX16lS5fWbbfdptmzZzv0uVBG+eCDDzR8+HCFhITIx8dHbdq00c8//1y0myzpkUcekSR98MEH9rbMzEwtXLhQTzzxRKHHvPTSS2rcuLH8/Pzk4+Oj22+/Xe+//76MvyRRq1evrp07d2rVqlX2+1e9enWHuc+dO1eDBg1SpUqV5OHhoT179hQoDx09elRVqlRR06ZNlZeXZx9/165d8vb2VkxMTJGv9Z+68DNetmyZnnjiCVWoUEFeXl7KycnRnj179Pjjj6t27dry8vJSpUqV1LlzZ23fvt1hjMLKQxd+T+3cuVOPPPKIfH19FRQUpCeeeEKZmZmXnFN2drZycnJUsWLFQveXKuX4V25ubq7+/e9/2/+cVKhQQY8//rhDlulSPzvgRkPQgsvKz8/XihUrFB4eripVqhTpmKeeekpDhgzRPffcoyVLluiVV15RcnKymjZtqqNHjzr0TUtL06OPPqrHHntMS5YsUYcOHTR06FDNmzdPktSpUyetXbtWkvTAAw9o7dq19s9FtW/fPnXq1Enu7u6aOXOmkpOTNXbsWHl7eys3N/eix/38889q2rSpdu7cqbfffluffvqpQkNDFRsbq/HjxxfoP2zYMO3fv18zZszQe+
+9p19//VWdO3dWfn5+kebp4+OjBx54QDNnzrS3ffDBBypVqpS6det20Wvr06ePPvroI3366ae677771L9/f73yyiv2PosWLVLNmjXVsGFD+/37eylv6NChOnDggN599119/vnnCgwMLHCugIAAJSUlaePGjRoyZIgk6fTp03rwwQdVtWpVvfvuu0W6TjM98cQTcnNz09y5c/XJJ5/Izc1Nhw4dkr+/v8aOHavk5GS98847cnV1VePGjYscRN5///2qU6eOFi5cqBdeeEELFizQs88+e8ljAgICdNNNN2nKlCl688039dNPPzkEj3917tw5denSRWPHjlV0dLSWLl2qsWPHavny5WrRooXOnDkjqWg/O+CGUdKpHlz70tLSDEnGww8/XKT+u3fvNiQZffv2dWhfv369IckYNmyYva158+aGJGP9+vUOfUNDQ4127do5tEky+vXr59A2atQoo7DfxhfKLXv37jUMwzA++eQTQ5KxdevWS85dkjFq1Cj754cfftjw8PAwDhw44NCvQ4cOhpeXl3HixAnDMP5XvujYsaNDv48++siQZKxdu/aS570w340bN9rH2rFjh2EYhnHHHXcYsbGxhmFcvkyQn59v5OXlGS+//LLh7+/vUCq52LEXznf33XdfdN+3337r0D5u3DhDkrFo0SKjR48ehqenp7Ft27ZLXuM/UVh56MI9u1gp5q/+/PNPIzc316hdu7bx7LPP2tv37t1rSDJmzZplb7vwe2r8+PEOY/Tt29coXbr0ZctPGzZsMKpWrWpIMiQZZcuWNSIjI405c+Y4HPvBBx8YkoyFCxc6HL9x40ZDkjFlyhR7G+Uh4DwyLTDdt99+K0kFnuy48847Va9ePX3zzTcO7cHBwbrzzjsd2m699Vbt37/ftDnddtttcnd3V+/evTV79mz9/vvvRTpuxYoVat26dYEMU2xsrE6fPl0g4/PXEpl0/jokFetamjdvrlq1amnmzJnavn27Nm7ceNHS0IU5tmnTRr6+vnJxcZGbm5tGjhypY8eOKT09vcjnvf/++4vc97nnnlOnTp30yCOPaPbs2Zo0aZLq169/2eP+/PNPh80w4TmAwub9559/asyYMQoNDZW7u7tcXV3l7u6uX3/9Vbt37y7SuIX9LM+ePXvZe3rHHXdoz549Sk5O1rBhwxQREaFvvvlG3bt3V1RUlP2av/jiC5UrV06dO3d2uCe33XabgoODWXQLFIKgBZcVEBAgLy8v7d27t0j9jx07JkmF1vVDQkLs+y/w9/cv0M/Dw8OeHjdDrVq19PXXXyswMFD9+vVTrVq1VKtWLU2cOPGSxx07duyi13Fh/1/9/VourP8pzrXYbDY9/vjjmjdvnt59913VqVNHzZo1K7Tvhg0b1LZtW0nnn+7673//q40bN2r48OHFPu/F1mFcbI6xsbE6e/asgoODi7SWZd++fXJzc3PYVq1aVeRzXkxh8x44cKBGjBihrl276vPPP9f69eu1ceNGNWjQoMj35J/8LN3c3NSuXTu9+uqr+s9//qODBw+qRYsW+uKLL/TVV19Jkg4fPqwTJ07I3d29wH1JS0srUEYFILmW9ARw7XNxcVHr1q311VdfKSUl5bKPA1/4yz41NbVA30OHDikgIMC0uZUuXVrS+UdX/7pAuLC/8Js1a6ZmzZopPz9fmzZt0qRJkxQfH6+goCA9/PDDhY7v7++v1NTUAu2HDh2SJFOv5a9iY2M1cuRIvfvuu3r11Vcv2i8pKUlubm764osv7PdCkhYvXlzscxa2oPliUlNT1a9fP912223auXOnBg8erLfffvuSx4SEhGjjxo0ObXXr1i32PP+usHnPmzdP3bt315gxYxzajx49qnLlyv3jcxaXv7+/4uPjtXLlSu3YsUMdO3ZUQECA/P39lZycXOgxZcuWvcqzBK59ZFpQJEOHDpVhGIqLiyt04WpeXp79/ROtWrWSJPtC2gs2btyo3bt3q3Xr1qbN68JTFNu2bX
Nov9S7MFxcXNS4cWO98847ks6/A+RiWrdurRUrVtiDlAvmzJkjLy8vNWnS5ApnfmmVKlXSc889p86dO6tHjx4X7We
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Hard class predictions of the tuned model for the training shots\n",
"y_pred_train = best_xgb_model.predict(X_train)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion matrix for the training data, drawn as an annotated heatmap\n",
"cm_train_xg = confusion_matrix(y_train, y_pred_train)\n",
2023-12-12 15:22:01 +01:00
"fig, ax = plt.subplots()\n",
"sns.heatmap(cm_train_xg, annot=True, cmap='BuPu', fmt='g', linewidth=1.5, ax=ax)\n",
"ax.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix - Train Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test set"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 62,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAHFCAYAAADCA+LKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABI00lEQVR4nO3df3zN9f//8fuxH8c222GbbSaERn63qBlvIb/D8u4HpRYlhGhFeS/v0DvvLSoqIj+bpNSn8Ebam1J715sxssqP9APJO/Nzhpltzev7h69Tx8bZdF5es27XLq/LxXm9Hud5nufkx2OPx/P5OjbDMAwBAABYqJLVEwAAACAhAQAAliMhAQAAliMhAQAAliMhAQAAliMhAQAAliMhAQAAliMhAQAAliMhAQAAliMhwWX7+uuv9eCDD6pu3bqqXLmyqlSpohtvvFFTpkzRsWPHTH3trVu3qn379nI4HLLZbHr55Zc9/ho2m00TJ070+LjupKSkyGazyWaz6bPPPit23TAMXXfddbLZbOrQocNlvcbMmTOVkpJSpud89tlnF52TmQYOHOj8PC51DBw40COv9/bbb5fp91NhYaFmz56tm266ScHBwfL391edOnV0++23a9myZZc1h6SkJC1fvvyyngtcrWzcOh6XY+7cuRo+fLgaNmyo4cOHq3HjxiosLNTmzZs1d+5ctWjR4rL/Mi6N6Oho5ebm6pVXXlG1atV07bXXKiIiwqOvkZ6ermuuuUbXXHONR8d1JyUlRQ8++KACAwN1++23a9GiRS7XP/vsM3Xs2FGBgYG68cYbLytBaNq0qUJDQ8v03BMnTmjHjh1q3LixgoKCyvyal+vHH3/U4cOHnY+//PJLjRgxQklJSerYsaPzfPXq1VW/fv0//Hq9evXStm3btHfv3lLF33PPPVq6dKkSEhLUoUMH2e127d69W6mpqapevbpef/31Ms+hSpUquuuuu8qcNAJXM2+rJ4Crz4YNGzRs2DB16dJFy5cvl91ud17r0qWLRo8erdTUVFPnsG3bNg0ePFg9evQw7TVat25t2til0a9fPy1evFivvfaaSwIwf/58xcbG6sSJE1dkHoWFhbLZbAoKCrLkM6lfv75LonHmzBlJUlRUlOX/j/bs2aN3331X48eP17PPPus836lTJw0ePFhnz561cHbA1YWWDcosKSlJNptNc+bMcUlGzvP19VVcXJzz8dmzZzVlyhRdf/31stvtCgsL0wMPPKD9+/e7PK9Dhw5q2rSpMjIy1K5dO/n7+6tevXp6/vnnnX+xn29n/Prrr5o1a5azXC9JEydOdP76984/5/c/8a5bt04dOnRQSEiI/Pz8VLt2bd155506ffq0M6akls22bdt0++23q1q1aqpcubJuuOEGLVy40CXmfGvjnXfe0bhx4xQZGamgoCB17txZu3btKt2HLOnee++VJL3zzjvOczk5Ofrggw/00EMPlficZ599VjExMQoODlZQUJBuvPFGzZ8/X78vhF577bXavn270tLSnJ/ftdde6zL3RYsWafTo0apZs6bsdrt++OGHYi2bI0eOqFatWmrTpo0KCwud4+/YsUMBAQGKj48v9Xv1hI8//lidOnVSUFCQ/P391bZtW33yyScuMYcPH9aQIUNUq1Yt2e12Va9eXW3bttXHH38s6dzvwQ8//FA//fSTSzvoYo4ePSpJqlGjRonXK1Vy/Sv2xIkTGjNmjOrWrStfX1/VrFlTCQkJys3NdcbYbDbl5uZq4cKFzte/3NYccDUhIUGZFBUVad26dWrZsqVq1apVqucMGzZMY8eOVZcuXbRixQo999xzSk1NVZs2bXTkyBGX2KysLN133326//77tWLFCvXo0UOJiYl66623JEk9e/bUhg0bJEl33XWXNmzY4HxcWnv37lXPnj3l6+urBQsWKDU1Vc8//7wCAgJUUFBw0eft2rVLbdq00fbt2/Xqq69q6dKlaty4sQYOHK
gpU6YUi3/66af1008/ad68eZozZ46+//579e7dW0VFRaWaZ1BQkO666y4tWLDAee6dd95RpUqV1K9fv4u+t6FDh+q9997T0qVLdccdd2jkyJF67rnnnDHLli1TvXr1FB0d7fz8LmyvJSYmat++fXr99de1cuVKhYWFFXut0NBQLVmyRBkZGRo7dqwk6fTp07r77rtVu3bty2pVXK633npLXbt2VVBQkBYuXKj33ntPwcHB6tatm0tSEh8fr+XLl2v8+PFas2aN5s2bp86dOzsTi5kzZ6pt27aKiIhwfjaX+v3VqFEjVa1aVc8++6zmzJlzyTbP6dOn1b59ey1cuFCjRo3SRx99pLFjxyolJUVxcXHOpHHDhg3y8/PTbbfd5nz9mTNneuaDAsozAyiDrKwsQ5Jxzz33lCp+586dhiRj+PDhLuc3btxoSDKefvpp57n27dsbkoyNGze6xDZu3Njo1q2byzlJxogRI1zOTZgwwSjpt/Qbb7xhSDL27NljGIZhvP/++4YkIzMz85Jzl2RMmDDB+fiee+4x7Ha7sW/fPpe4Hj16GP7+/sbx48cNwzCMTz/91JBk3HbbbS5x7733niHJ2LBhwyVf9/x8MzIynGNt27bNMAzDuOmmm4yBAwcahmEYTZo0Mdq3b3/RcYqKiozCwkLjH//4hxESEmKcPXvWee1izz3/erfccstFr3366acu5ydPnmxIMpYtW2YMGDDA8PPzM77++utLvsc/4vw8/u///s8wDMPIzc01goODjd69e7vEFRUVGS1atDBuvvlm57kqVaoYCQkJlxy/Z8+eRp06dUo9nw8//NAIDQ01JBmSjJCQEOPuu+82VqxY4RKXnJxsVKpUycjIyHA5f/734+rVq53nAgICjAEDBpR6DkBFQIUEpvr0008lqdgOiJtvvlmNGjUqVlKPiIjQzTff7HKuefPm+umnnzw2pxtuuEG+vr4aMmSIFi5cqN27d5fqeevWrVOnTp2KVYYGDhyo06dPF/tJ+vdtK+nc+5BUpvfSvn171a9fXwsWLNA333yjjIyMi7Zrzs+xc+fOcjgc8vLyko+Pj8aPH6+jR4/q0KFDpX7dO++8s9SxTz75pHr27Kl7771XCxcu1PTp09WsWTO3z/v1119dDuMy19evX79ex44d04ABA1zGO3v2rLp3766MjAxnS+Tmm29WSkqKJk2apPT0dJdW0+W67bbbtG/fPi1btkxjxoxRkyZNtHz5csXFxenRRx91xq1atUpNmzbVDTfc4DLPbt26WbJ7CShvSEhQJqGhofL399eePXtKFX+pHntkZKTz+nkhISHF4ux2u/Ly8i5jtiWrX7++Pv74Y4WFhWnEiBHORZOvvPLKJZ939OjRi76P89d/78L3cn69TVnei81m04MPPqi33npLr7/+uho0aKB27dqVGLtp0yZ17dpV0rldUP/973+VkZGhcePGlfl1L7Ym4mJzHDhwoM6cOaOIiIhSrR3Zu3evfHx8XI60tLRSv+bvHTx4UNK5Ft6FY06ePFmGYTi3ob/77rsaMGCA5s2bp9jYWAUHB+uBBx5QVlbWZb32eX5+furTp49eeOEFpaWl6YcfflDjxo312muvafv27c55fv3118XmGBgYKMMwirUvgT8bdtmgTLy8vNSpUyd99NFH2r9/v9stsef/UT5w4ECx2F9++UWhoaEem1vlypUlSfn5+S6LbUv6i75du3Zq166dioqKtHnzZk2fPl0JCQkKDw/XPffcU+L4ISEhOnDgQLHzv/zyiyR59L383sCBAzV+/Hi9/vrr+uc//3nRuCVLlsjHx0erVq1yfhaSLut+FpdayHmhAwcOaMSIEbrhhhu0fft2jRkzRq+++uolnxMZGamMjAyXcw0bNizzPKXfPvfp06dfdNdNeHi4M/bll1/Wyy+/rH379mnFihX629/+pkOHDnl0Z1jt2rU1ZMgQJSQkaPv27WrSpIlCQ0Pl5+fnsiaopPcB/FlRIU
GZJSYmyjAMDR48uMRFoIWFhVq5cqUk6dZbb5Uk56LU8zIyMrRz50516tTJY/M6v1Pk66+/djl/fi4l8fLyUkxMjF5
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Evaluate the model on test set\n",
"y_pred_test = best_xgb_model.predict(X_test)\n",
2023-12-12 15:22:01 +01:00
"\n",
"# Confusion Matrix for Testing Data\n",
"cm_test_xgb = confusion_matrix(y_test, y_pred_test)\n",
2023-12-12 15:22:01 +01:00
"# Draw the test-set confusion matrix as an annotated heatmap\n",
"ax = sns.heatmap(cm_test_xgb, annot=True, fmt='g', cmap='Blues', linewidth=1.5)\n",
"ax.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix - Test Set')\n",
"plt.show()"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 63,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The test dataset contains 7669 shots, with 914 of them being goals.\n"
2023-12-12 15:22:01 +01:00
]
}
],
"source": [
"# Number of goals in test set\n",
"print(f'The test dataset contains {len(y_test)} shots, with {y_test.sum()[\"is_goal\"]} of them being goals.')"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "markdown",
2023-12-12 15:22:01 +01:00
"metadata": {},
"source": [
"## Feature importance"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 64,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2MAAAHFCAYAAAB/6yHTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxNz/8H8Nfttt32XSEtaEVlS5GytFqyZ4t8CB/7nmwlW1lDdsm+R0IiVLYkCRFRJEu2Sit1687vj36dr+PeFvL5xMc8H48e3DlzZua8z61z586cORxCCAFFURRFURRFURT1rxKr7wZQFEVRFEVRFEX9iWhnjKIoiqIoiqIoqh7QzhhFURRFURRFUVQ9oJ0xiqIoiqIoiqKoekA7YxRFURRFURRFUfWAdsYoiqIoiqIoiqLqAe2MURRFURRFURRF1QPaGaMoiqIoiqIoiqoHtDNGURRFURRFURRVD2hnjKIoivpl7d69GxwOR+TPrFmz/pE6U1JS4Ovri4yMjH+k/LrIyMgAh8PB7t2767spPywiIgK+vr713QyKoqhfgnh9N4CiKIqiahISEgIjIyNWWsOGDf+RulJSUrB48WLY2dlBV1f3H6njR2lpaSEuLg5Nmzat76b8sIiICGzatIl2yCiKokA7YxRFUdRvoEWLFmjbtm19N6NO+Hw+OBwOxMV//NIrJSWFDh06/MRW/XuKi4shIyNT382gKIr6pdBpihRFUdRv78iRI7CysoKsrCzk5OTg6OiIpKQkVp7bt29j8ODB0NXVBY/Hg66uLoYMGYIXL14weXbv3o2BAwcCALp06cJMiaycFqirqwsPDw+h+u3s7GBnZ8e8jomJAYfDwb59+zBz5kw0atQIUlJSSEtLAwBcvHgR3bp1g4KCAmRkZNCxY0dcunSpxuMUNU3R19cXHA4H9+/fx8CBA6GoqAgVFRXMmDEDZWVlSE1NhZOTE+Tl5aGrq4uVK1eyyqxs6/79+zFjxgxoamqCx+PB1tZWKIYAEB4eDisrK8jIyEBeXh729vaIi4tj5als0507dzBgwAAoKyujadOm8PDwwKZNmwCANeW0ckropk2b0LlzZ2hoaEBWVhYtW7bEypUrwefzheLdokULJCQkwMbGBjIyMtDX14e/vz8EAgEr76dPnzBz5kzo6+tDSkoKGhoacHFxwePHj5k8paWlWLp0KYyMjCAlJQV1dXWMGjUKHz58qPGcUBRF1QXtjFEURVG/vPLycpSVlbF+Ki1fvhxDhgyBiYkJjh49in379qGgoAA2NjZISUlh8mVkZMDQ0BCBgYE4f/48AgICkJWVhXbt2uHjx48AgB49emD58uUAKjoGcXFxiIuLQ48ePX6o3d7e3sjMzMTWrVtx+vRpaGhoYP/+/XBwcICCggL27NmDo0ePQkVFBY6OjrXqkFVl0KBBMDMzQ2hoKDw9PbFu3TpMnz4dffr0QY8ePXDy5El07doVXl5eOHHihND+8+bNw7Nnz7Bz507s3LkTb968gZ2dHZ49e8bkOXjwIFxdXaGgoIBDhw4hODgYubm5sLOzw7Vr14TK7NevH5o1a4Zjx45h69atWLhwIQYMGAAATGzj4uKgpaUFAEhPT8fQoUOxb98+nDlzBqNHj8aqVaswbtw4obLfvn2LYcOGYfjw4QgPD4ezszO8vb2xf/9+Jk9BQQE6deqEbdu2YdSoUTh9+jS2bt0KAwMDZGVlAQAEAgFcXV3h7++PoUOH4uzZs/D390dUVBTs7Ozw+fPnHz4nFEVRNSIURVEU9YsKCQkhAET+8Pl8kpmZScTFxcnkyZNZ+xUUFBBNTU0yaNCgKssuKysjhYWFRFZWlqxfv55JP3bsGAFAoqOjhfbR0dEhI0eOFEq3tbUltra2zOvo6GgCgHTu3JmVr6ioiKioqJBevXqx0svLy4mZmRlp3759NdEg5Pnz5wQACQkJYdJ8fHwIALJmzRpWXnNzcwKAnDhxgk
nj8/lEXV2d9OvXT6itrVu3JgKBgEnPyMggEhISZMyYMUwbGzZsSFq2bEnKy8uZfAUFBURDQ4NYW1sLtWnRokVCxzBx4kRSm48f5eXlhM/nk7179xIul0tycnKYbba2tgQAiY+PZ+1jYmJCHB0dmdd+fn4EAImKiqqynkOHDhEAJDQ0lJWekJBAAJDNmzfX2FaKoqgfRUfGKIqiqF/e3r17kZCQwPoRFxfH+fPnUVZWhhEjRrBGzaSlpWFra4uYmBimjMLCQnh5eaFZs2YQFxeHuLg45OTkUFRUhEePHv0j7e7fvz/r9Y0bN5CTk4ORI0ey2isQCODk5ISEhAQUFRX9UF09e/ZkvTY2NgaHw4GzszOTJi4ujmbNmrGmZlYaOnQoOBwO81pHRwfW1taIjo4GAKSmpuLNmzdwd3eHmNj/Pj7Iycmhf//+uHnzJoqLi6s9/pokJSWhd+/eUFVVBZfLhYSEBEaMGIHy8nI8efKElVdTUxPt27dnpbVq1Yp1bOfOnYOBgQG6d+9eZZ1nzpyBkpISevXqxTon5ubm0NTUZL2HKIqifja6gAdFURT1yzM2Nha5gMe7d+8AAO3atRO539edhqFDh+LSpUtYuHAh2rVrBwUFBXA4HLi4uPxjU9Eqp999297KqXqi5OTkQFZW9rvrUlFRYb2WlJSEjIwMpKWlhdLz8/OF9tfU1BSZdu/ePQBAdnY2AOFjAipWthQIBMjNzWUt0iEqb1UyMzNhY2MDQ0NDrF+/Hrq6upCWlsatW7cwceJEoXOkqqoqVIaUlBQr34cPH9CkSZNq63337h0+ffoESUlJkdsrp7BSFEX9E2hnjKIoivptqampAQCOHz8OHR2dKvPl5eXhzJkz8PHxwdy5c5n0kpIS5OTk1Lo+aWlplJSUCKV//PiRacvXvh5p+rq9GzdurHJVxAYNGtS6PT/T27dvRaZVdnoq/6281+prb968gZiYGJSVlVnp3x5/dcLCwlBUVIQTJ06wzuXdu3drXca31NXV8erVq2rzqKmpQVVVFZGRkSK3y8vL/3D9FEVRNaGdMYqiKOq35ejoCHFxcaSnp1c7JY7D4YAQAikpKVb6zp07UV5ezkqrzCNqtExXVxf3799npT158gSpqakiO2Pf6tixI5SUlJCSkoJJkybVmP/fdOjQIcyYMYPpQL148QI3btzAiBEjAACGhoZo1KgRDh48iFmzZjH5ioqKEBoayqywWJOv48vj8Zj0yvK+PkeEEOzYseOHj8nZ2RmLFi3C5cuX0bVrV5F5evbsicOHD6O8vByWlpY/XBdFUdSPoJ0xiqIo6relq6sLPz8/zJ8/H8+ePYOTkxOUlZXx7t073Lp1C7Kysli8eDEUFBTQuXNnrFq1CmpqatDV1UVsbCyCg4OhpKTEKrNFixYAgO3bt0NeXh7S0tLQ09ODqqoq3N3dMXz4cEyYMAH9+/fHixcvsHLlSqirq9eqvXJycti4cSNGjhyJnJwcDBgwABoaGvjw4QPu3buHDx8+YMuWLT87TLXy/v179O3bF56ensjLy4OPjw+kpaXh7e0NoGLK58qVKzFs2DD07NkT48aNQ0lJCVatWoVPnz7B39+/VvW0bNkSABAQEABnZ2dwuVy0atUK9vb2kJSUxJAhQzBnzhx8+fIFW7ZsQW5u7g8f07Rp03DkyBG4urpi7ty5aN++PT5//ozY2Fj07NkTXbp0weDBg3HgwAG4uLhg6tSpaN++PSQkJPDq1StER0fD1dUVffv2/eE2UBRFVYcu4EFRFEX91ry9vXH8+HE8efIEI0eOhKOjI+bMmYMXL16gc+fOTL6DBw+iS5cumDNnDvr164fbt28jKioKioqKrPL09PQQGBiIe/fuwc7ODu3atcPp06cBVNx3tnLlSpw/fx49e/bEli1bsGXLFhgYGNS6vcOHD0d0dDQKCwsxbtw4dO/eHVOnTsWdO3fQrVu3nxOUH7B8+XLo6Ohg1KhR+Ouvv6ClpYXo6G
g0bdqUyTN06FCEhYUhOzsbbm5uGDVqFBQUFBAdHY1OnTrVqp6hQ4dizJgx2Lx5M6ysrNCuXTu8efMGRkZGCA0NRW5
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Gain\n",
"xgboost.plot_importance(best_xgb_model, importance_type='gain', xlabel='Gain', max_num_features=20)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 65,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvwAAAHFCAYAAACOxS13AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVyO2f/48ddd0l6SJCayVyT7kK2GSg2yJ0uyhLFvI8aWZcjeYDDMKGOZzAyDoRkTypqQ9YOxRGJk7LKMtNy/P/p1fd1aEVnez8fjfnCd61znOtf7vrnPfa5znaNSq9VqhBBCCCGEEB8krcKugBBCCCGEEOLNkQa/EEIIIYQQHzBp8AshhBBCCPEBkwa/EEIIIYQQHzBp8AshhBBCCPEBkwa/EEIIIYQQHzBp8AshhBBCCPEBkwa/EEIIIYQQHzBp8AshhBBCCPEBkwa/EEKId1ZoaCgqlSrb1+jRo9/IOc+cOUNgYCDx8fFvpPzXER8fj0qlIjQ0tLCr8srCw8MJDAws7GoI8VEpUtgVEEIIIfISEhKCra2tRlrp0qXfyLnOnDnDlClTcHZ2xsbG5o2c41VZWVkRHR1NxYoVC7sqryw8PJxvv/1WGv1CvEXS4BdCCPHOq169OnXr1i3saryWlJQUVCoVRYq8+levrq4uDRo0KMBavT1PnjzBwMCgsKshxEdJhvQIIYR4761fv56GDRtiaGiIkZER7u7uHDt2TCPPkSNH6NKlCzY2Nujr62NjY4OPjw9XrlxR8oSGhtKpUycAXFxclOFDmUNobGxs8PPzy3J+Z2dnnJ2dle2oqChUKhWrV69m1KhRlClTBl1dXS5evAjAjh07aN68OSYmJhgYGNCoUSN27tyZ53VmN6QnMDAQlUrFyZMn6dSpE6amphQvXpyRI0eSmprKuXPnaNmyJcbGxtjY2DB79myNMjPrumbNGkaOHEmpUqXQ19enWbNmWWIIsGXLFho2bIiBgQHGxsa4uroSHR2tkSezTkePHqVjx46YmZlRsWJF/Pz8+PbbbwE0hmdlDp/69ttvadq0KSVLlsTQ0BAHBwdmz55NSkpKlnhXr16dw4cP06RJEwwMDKhQoQJBQUGkp6dr5L1//z6jRo2iQoUK6OrqUrJkSTw9Pfn777+VPM+ePWP69OnY2tqiq6uLhYUFvXr14tatW3m+J0K8D6TBL4QQ4p2XlpZGamqqxivTjBkz8PHxwd7enp9//pnVq1fz8OFDmjRpwpkzZ5R88fHxVK1aleDgYLZv386sWbNITEykXr163L59G4DPP/+cGTNmABmNz+joaKKjo/n8889fqd7jxo0jISGBZcuW8fvvv1OyZEnWrFmDm5sbJiYmrFq1ip9//pnixYvj7u6er0Z/Tjp37oyjoyMbNmzA39+fBQsWMGLECNq2bcvnn3/Ob7/9xmeffUZAQAAbN27McvxXX33FpUuX+P777/n++++5fv06zs7OXLp0Scmzbt06vLy8MDEx4aeffuKHH37g3r17ODs7s2/fvixltm/fnkqVKvHLL7+wbNkyJk6cSMeOHQGU2EZHR2NlZQVAXFwcXbt2ZfXq1WzdupU+ffowZ84c+vfvn6XsGzdu0K1bN7p3786WLVvw8PBg3LhxrFmzRsnz8OFDGjduzHfffUevXr34/fffWbZsGVWqVCExMRGA9PR0vLy8CAoKomvXrmzbto2goCAiIiJwdnbmv//+e+X3RIh3hloIIYR4R4WEhKiBbF8pKSnqhIQEdZEiRdRDhgzROO7hw4fqUqVKqTt37pxj2ampqepHjx6pDQ0N1d98842S/ssvv6gBdWRkZJZjypUrp+7Zs2eW9GbNmqmbNWumbEdGRqoBddOmTTXyPX78WF28eHF169atNdLT0tLUjo6O6vr16+cSDbX68uXLakAdEhKipE2ePFkNqOfNm6eRt2bNmmpAvXHjRiUtJSVFbWFhoW7fvn2WutauXVudnp6upM
fHx6t1dHTUffv2VepYunRptYODgzotLU3J9/DhQ3XJkiXVTk5OWeo0adKkLNcwaNAgdX6aH2lpaeqUlBT1jz/+qNbW1lbfvXtX2desWTM1oI6JidE4xt7eXu3u7q5sT506VQ2oIyIicjzPTz/9pAbUGzZs0Eg/fPiwGlAvWbIkz7oK8a6THn4hhBDvvB9//JHDhw9rvIoUKcL27dtJTU3F19dXo/dfT0+PZs2aERUVpZTx6NEjAgICqFSpEkWKFKFIkSIYGRnx+PFjzp49+0bq3aFDB43tAwcOcPfuXXr27KlR3/T0dFq2bMnhw4d5/PjxK52rVatWGtt2dnaoVCo8PDyUtCJFilCpUiWNYUyZunbtikqlUrbLlSuHk5MTkZGRAJw7d47r16/To0cPtLT+r/lgZGREhw4dOHjwIE+ePMn1+vNy7Ngx2rRpg7m5Odra2ujo6ODr60taWhrnz5/XyFuqVCnq16+vkVajRg2Na/vjjz+oUqUKLVq0yPGcW7dupVixYrRu3VrjPalZsyalSpXS+AwJ8b6Sh3aFEEK88+zs7LJ9aPfff/8FoF69etke93zDtGvXruzcuZOJEydSr149TExMUKlUeHp6vrFhG5lDVV6sb+awluzcvXsXQ0PDlz5X8eLFNbaLFi2KgYEBenp6WdKTkpKyHF+qVKls006cOAHAnTt3gKzXBBkzJqWnp3Pv3j2NB3Ozy5uThIQEmjRpQtWqVfnmm2+wsbFBT0+PQ4cOMWjQoCzvkbm5eZYydHV1NfLdunWLsmXL5nref//9l/v371O0aNFs92cO9xLifSYNfiGEEO+tEiVKAPDrr79Srly5HPM9ePCArVu3MnnyZMaOHaukJycnc/fu3XyfT09Pj+Tk5Czpt2/fVuryvOd7zJ+v76JFi3KcbcfS0jLf9SlIN27cyDYts2Gd+Wfm2PfnXb9+HS0tLczMzDTSX7z+3GzatInHjx+zceNGjffy+PHj+S7jRRYWFly7di3XPCVKlMDc3Jw///wz2/3GxsavfH4h3hXS4BdCCPHecnd3p0iRIsTFxeU6fESlUqFWq9HV1dVI//7770lLS9NIy8yTXa+/jY0NJ0+e1Eg7f/48586dy7bB/6JGjRpRrFgxzpw5w+DBg/PM/zb99NNPjBw5UmmkX7lyhQMHDuDr6wtA1apVKVOmDOvWrWP06NFKvsePH7NhwwZl5p68PB9ffX19JT2zvOffI7VazYoVK175mjw8PJg0aRK7du3is88+yzZPq1atCAsLIy0tjU8//fSVzyXEu0wa/EIIId5bNjY2TJ06lfHjx3Pp0iVatmyJmZkZ//77L4cOHcLQ0JApU6ZgYmJC06ZNmTNnDiVKlMDGxobdu3fzww8/UKxYMY0yq1evDsDy5csxNjZGT0+P8uXLY25uTo8ePejevTsDBw6kQ4cOXLlyhdmzZ2NhYZGv+hoZGbFo0SJ69uzJ3bt36dixIyVLluTWrVucOHGCW7dusXTp0oIOU77cvHmTdu3a4e/vz4MHD5g8eTJ6enqMGzcOyBgeNXv2bLp160arVq3o378/ycnJzJkzh/v37xMUFJSv8zg4OAAwa9YsPDw80NbWpkaNGri6ulK0aFF8fHwYM2YMT58+ZenSpdy7d++Vr2n48OGsX78eLy8vxo4dS/369fnvv//YvXs3rVq1wsXFhS5durB27Vo8PT0ZNmwY9evXR0dHh2vXrhEZGYmXlxft2rV75ToI8S6Qh3aFEEK818aNG8evv/7K+fPn6dmzJ+7u7owZM4YrV67QtGlTJd+6detwcXFhzJgxtG/fniNHjhAREYGpqalGeeXLlyc4OJgTJ07g7OxMvXr1+P3334GM5wBmz57N9u3badWqFUuXLmXp0qVUqVIl3/Xt3r07kZGRPHr0iP79+9OiRQuGDRvG0aNHad68ecEE5RXMmDGDcuXK0atXL3r37o2VlRWRkZEaq/p27dqVTZs2cefOHby9venVqxcmJi
ZERkbSuHHjfJ2na9eu9O3blyVLltCwYUPq1avH9evXsbW1ZcOGDdy7d4/27dszZMgQatasycKFC1/5moyNjdm3bx9
2023-12-12 15:22:01 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importance with Weight\n",
"xgboost.plot_importance(best_xgb_model, importance_type='weight', xlabel='Weight', max_num_features=30)\n",
2023-12-12 15:22:01 +01:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 66,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Regression-style metrics (MAE, RMSE, R2) for the training and test sets.\n",
"# xG is a probability, so score the predicted goal probabilities rather than the\n",
"# hard 0/1 labels from .predict(): on hard labels MAE is just the error rate and\n",
"# RMSE its square root, which says nothing about the quality of the xG values.\n",
"y_proba_train = best_xgb_model.predict_proba(X_train)[:, 1]\n",
"y_proba_test = best_xgb_model.predict_proba(X_test)[:, 1]\n",
"\n",
"mae_train = mean_absolute_error(y_train, y_proba_train)\n",
"# RMSE as sqrt(MSE): the `squared=False` argument is deprecated since scikit-learn 1.4 and later removed\n",
"rmse_train = mean_squared_error(y_train, y_proba_train) ** 0.5\n",
"r2_train = r2_score(y_train, y_proba_train)\n",
"\n",
"mae_test = mean_absolute_error(y_test, y_proba_test)\n",
"rmse_test = mean_squared_error(y_test, y_proba_test) ** 0.5\n",
"r2_test = r2_score(y_test, y_proba_test)"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 67,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6e8cc_row0_col0, #T_6e8cc_row0_col1, #T_6e8cc_row0_col2, #T_6e8cc_row0_col3, #T_6e8cc_row0_col4, #T_6e8cc_row0_col5, #T_6e8cc_row0_col6 {\n",
2023-12-12 15:22:01 +01:00
" font-weight: bold;\n",
" border: 2.0px solid grey;\n",
" color: white;\n",
"}\n",
"</style>\n",
"<table id=\"T_6e8cc_\">\n",
2023-12-12 15:22:01 +01:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th class=\"col_heading level0 col0\" >Training MAE</th>\n",
" <th class=\"col_heading level0 col1\" >Training RMSE</th>\n",
" <th class=\"col_heading level0 col2\" >Training R2</th>\n",
" <th class=\"col_heading level0 col3\" >Testing MAE</th>\n",
" <th class=\"col_heading level0 col4\" >Testing RMSE</th>\n",
" <th class=\"col_heading level0 col5\" >Testing R2</th>\n",
" <th class=\"col_heading level0 col6\" >Training Time (mins)</th>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Model Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6e8cc_level0_row0\" class=\"row_heading level0 row0\" >XG Boost</th>\n",
" <td id=\"T_6e8cc_row0_col0\" class=\"data row0 col0\" >0.10309</td>\n",
" <td id=\"T_6e8cc_row0_col1\" class=\"data row0 col1\" >0.32107</td>\n",
" <td id=\"T_6e8cc_row0_col2\" class=\"data row0 col2\" >0.00199</td>\n",
" <td id=\"T_6e8cc_row0_col3\" class=\"data row0 col3\" >0.10549</td>\n",
" <td id=\"T_6e8cc_row0_col4\" class=\"data row0 col4\" >0.32479</td>\n",
" <td id=\"T_6e8cc_row0_col5\" class=\"data row0 col5\" >-0.00488</td>\n",
" <td id=\"T_6e8cc_row0_col6\" class=\"data row0 col6\" >0.27337</td>\n",
2023-12-12 15:22:01 +01:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x265390f1dc0>"
2023-12-12 15:22:01 +01:00
]
},
"execution_count": 67,
2023-12-12 15:22:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a dataframe of summary results\n",
"summary_df = pd.DataFrame({'Model Name':['XG Boost'],\n",
" 'Training MAE': mae_train, \n",
" 'Training RMSE': rmse_train,\n",
" 'Training R2':r2_train,\n",
" 'Testing MAE': mae_test, \n",
" 'Testing RMSE': rmse_test,\n",
" 'Testing R2':r2_test,\n",
" 'Training Time (mins)': xgb_training_time/60})\n",
2023-12-12 15:22:01 +01:00
"summary_df.set_index('Model Name', inplace=True)\n",
"\n",
2023-12-12 15:22:01 +01:00
"# Displaying summary of results\n",
"summary_df.style.format(precision =5).set_properties(**{'font-weight': 'bold',\n",
2023-12-12 15:22:01 +01:00
" 'border': '2.0px solid grey','color': 'white'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keeping the xgboost model"
2023-12-12 15:22:01 +01:00
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'objective': 'binary:logistic',\n",
" 'base_score': None,\n",
" 'booster': None,\n",
" 'callbacks': None,\n",
" 'colsample_bylevel': None,\n",
" 'colsample_bynode': None,\n",
" 'colsample_bytree': None,\n",
" 'device': None,\n",
" 'early_stopping_rounds': None,\n",
" 'enable_categorical': True,\n",
" 'eval_metric': None,\n",
" 'feature_types': None,\n",
" 'gamma': None,\n",
" 'grow_policy': None,\n",
" 'importance_type': None,\n",
" 'interaction_constraints': None,\n",
" 'learning_rate': 0.01,\n",
" 'max_bin': None,\n",
" 'max_cat_threshold': None,\n",
" 'max_cat_to_onehot': None,\n",
" 'max_delta_step': None,\n",
" 'max_depth': 3,\n",
" 'max_leaves': None,\n",
" 'min_child_weight': None,\n",
" 'missing': nan,\n",
" 'monotone_constraints': None,\n",
" 'multi_strategy': None,\n",
" 'n_estimators': 300,\n",
" 'n_jobs': None,\n",
" 'num_parallel_tree': None,\n",
" 'random_state': None,\n",
" 'reg_alpha': None,\n",
" 'reg_lambda': None,\n",
" 'sampling_method': None,\n",
" 'scale_pos_weight': 1,\n",
" 'subsample': None,\n",
" 'tree_method': 'hist',\n",
" 'validate_parameters': None,\n",
" 'verbosity': None}"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the model\n",
"best_xgb_model.save_model('xgboost.json')\n",
"\n",
"best_xgb_model.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 69,
2023-12-12 15:22:01 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Alternative (disabled): save the model with joblib\n",
"# dump(best_xgb_model, 'xgboost.joblib') \n",
2023-12-12 15:22:01 +01:00
"\n",
"# # Load the model\n",
"# model = load('xgboost.joblib')"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.9334918 , 0.06650815],\n",
" [0.9150994 , 0.0849006 ],\n",
" [0.9478227 , 0.05217729],\n",
" ...,\n",
" [0.77691543, 0.2230846 ],\n",
" [0.9318629 , 0.06813709],\n",
" [0.95634604, 0.04365399]], dtype=float32)"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_xgb_model.predict_proba(X)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"new_xgb_model = xgboost.XGBClassifier()\n",
"new_xgb_model.load_model('xgboost.json')"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.9334918 , 0.06650815],\n",
" [0.9150994 , 0.0849006 ],\n",
" [0.9478227 , 0.05217729],\n",
" ...,\n",
" [0.77691543, 0.2230846 ],\n",
" [0.9318629 , 0.06813709],\n",
" [0.95634604, 0.04365399]], dtype=float32)"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Round-trip check: the model reloaded from JSON must reproduce the original probabilities\n",
"# (a tolerance is used instead of comparing the truncated array reprs by eye)\n",
"new_proba = new_xgb_model.predict_proba(X)\n",
"assert abs(new_proba - best_xgb_model.predict_proba(X)).max() < 1e-6, 'Reloaded model predictions differ from the original model'\n",
"new_proba"
2023-12-12 15:22:01 +01:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
2023-12-12 15:22:01 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}