{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n", "from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n", "import xgboost\n", "import time\n", "from joblib import dump, load\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('final_data_new.csv')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n", " 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n", " 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n", " 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n", " 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n", " 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n", " 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n", " 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n", " 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n", " 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n", " 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n", " 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n", " 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n", " 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n", " 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n", " 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n", " 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n", " 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n", " 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n", " 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n", " 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n", " 'y_player_teammate_Goalkeeper'],\n", " dtype='object')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | minute | \n", "position_name | \n", "shot_body_part_name | \n", "shot_technique_name | \n", "shot_type_name | \n", "shot_first_time | \n", "shot_one_on_one | \n", "shot_aerial_won | \n", "shot_open_goal | \n", "shot_follows_dribble | \n", "... | \n", "y_player_teammate_5 | \n", "y_player_teammate_6 | \n", "y_player_teammate_7 | \n", "y_player_teammate_8 | \n", "y_player_teammate_9 | \n", "y_player_teammate_10 | \n", "x_player_opponent_7 | \n", "y_player_opponent_7 | \n", "x_player_teammate_Goalkeeper | \n", "y_player_teammate_Goalkeeper | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2 | \n", "Left Center Forward | \n", "Right Foot | \n", "Half Volley | \n", "Open Play | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "26.6 | \n", "53.1 | \n", "NaN | \n", "NaN | \n", "
1 | \n", "5 | \n", "Left Back | \n", "Left Foot | \n", "Volley | \n", "Open Play | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "20.6 | \n", "32.8 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "23.8 | \n", "31.2 | \n", "NaN | \n", "NaN | \n", "
2 | \n", "15 | \n", "Left Center Forward | \n", "Left Foot | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "29.0 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "29.6 | \n", "55.3 | \n", "NaN | \n", "NaN | \n", "
3 | \n", "16 | \n", "Center Forward | \n", "Head | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "True | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "26.7 | \n", "60.4 | \n", "NaN | \n", "NaN | \n", "
4 | \n", "18 | \n", "Right Center Forward | \n", "Right Foot | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "27.9 | \n", "31.4 | \n", "33.4 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "16.9 | \n", "40.1 | \n", "NaN | \n", "NaN | \n", "
5 rows × 64 columns
\n", "\n", " | minute | \n", "position_name | \n", "shot_body_part_name | \n", "shot_technique_name | \n", "shot_type_name | \n", "shot_first_time | \n", "shot_one_on_one | \n", "shot_aerial_won | \n", "shot_open_goal | \n", "shot_follows_dribble | \n", "... | \n", "y_player_teammate_5 | \n", "y_player_teammate_6 | \n", "y_player_teammate_7 | \n", "y_player_teammate_8 | \n", "y_player_teammate_9 | \n", "y_player_teammate_10 | \n", "x_player_opponent_7 | \n", "y_player_opponent_7 | \n", "x_player_teammate_Goalkeeper | \n", "y_player_teammate_Goalkeeper | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2 | \n", "9 | \n", "3 | \n", "2 | \n", "3 | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "26.6 | \n", "53.1 | \n", "NaN | \n", "NaN | \n", "
1 | \n", "5 | \n", "7 | \n", "1 | \n", "6 | \n", "3 | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "20.6 | \n", "32.8 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "23.8 | \n", "31.2 | \n", "NaN | \n", "NaN | \n", "
2 | \n", "15 | \n", "9 | \n", "1 | \n", "4 | \n", "3 | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "29.0 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "29.6 | \n", "55.3 | \n", "NaN | \n", "NaN | \n", "
3 | \n", "16 | \n", "3 | \n", "0 | \n", "4 | \n", "3 | \n", "False | \n", "False | \n", "True | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "26.7 | \n", "60.4 | \n", "NaN | \n", "NaN | \n", "
4 | \n", "18 | \n", "18 | \n", "3 | \n", "4 | \n", "3 | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "27.9 | \n", "31.4 | \n", "33.4 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "16.9 | \n", "40.1 | \n", "NaN | \n", "NaN | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
82816 | \n", "79 | \n", "0 | \n", "3 | \n", "2 | \n", "3 | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "30.9 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "30.8 | \n", "40.3 | \n", "NaN | \n", "NaN | \n", "
82817 | \n", "80 | \n", "20 | \n", "3 | \n", "4 | \n", "3 | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "60.2 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "31.9 | \n", "47.7 | \n", "NaN | \n", "NaN | \n", "
82818 | \n", "82 | \n", "0 | \n", "3 | \n", "4 | \n", "3 | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
82819 | \n", "84 | \n", "21 | \n", "3 | \n", "4 | \n", "3 | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
82820 | \n", "88 | \n", "8 | \n", "1 | \n", "2 | \n", "3 | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "20.0 | \n", "44.5 | \n", "NaN | \n", "NaN | \n", "
82821 rows × 64 columns
\n", "\n", " | Training MAE | \n", "Training RMSE | \n", "Training R2 | \n", "Testing MAE | \n", "Testing RMSE | \n", "Testing R2 | \n", "Training Time (mins) | \n", "
---|---|---|---|---|---|---|---|
Model Name | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
XG Boost | \n", "0.09646 | \n", "0.31058 | \n", "0.01446 | \n", "0.09919 | \n", "0.31494 | \n", "0.01918 | \n", "23.22242 | \n", "