{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV\n", "from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error, r2_score\n", "import xgboost\n", "import time\n", "from joblib import dump, load\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('final_data.csv')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['minute', 'position_name', 'shot_body_part_name', 'shot_technique_name',\n", " 'shot_type_name', 'shot_first_time', 'shot_one_on_one',\n", " 'shot_aerial_won', 'shot_open_goal', 'shot_follows_dribble',\n", " 'shot_redirect', 'x1', 'y1', 'number_of_players_opponents',\n", " 'number_of_players_teammates', 'is_goal', 'angle', 'distance',\n", " 'x_player_opponent_Goalkeeper', 'x_player_opponent_8',\n", " 'x_player_opponent_1', 'x_player_opponent_2', 'x_player_opponent_3',\n", " 'x_player_teammate_1', 'x_player_opponent_4', 'x_player_opponent_5',\n", " 'x_player_opponent_6', 'x_player_teammate_2', 'x_player_opponent_9',\n", " 'x_player_opponent_10', 'x_player_opponent_11', 'x_player_teammate_3',\n", " 'x_player_teammate_4', 'x_player_teammate_5', 'x_player_teammate_6',\n", " 'x_player_teammate_7', 'x_player_teammate_8', 'x_player_teammate_9',\n", " 'x_player_teammate_10', 'y_player_opponent_Goalkeeper',\n", " 'y_player_opponent_8', 'y_player_opponent_1', 'y_player_opponent_2',\n", " 'y_player_opponent_3', 'y_player_teammate_1', 'y_player_opponent_4',\n", " 'y_player_opponent_5', 'y_player_opponent_6', 'y_player_teammate_2',\n", " 'y_player_opponent_9', 'y_player_opponent_10', 'y_player_opponent_11',\n", " 'y_player_teammate_3', 'y_player_teammate_4', 'y_player_teammate_5',\n", " 'y_player_teammate_6', 'y_player_teammate_7', 'y_player_teammate_8',\n", " 'y_player_teammate_9', 'y_player_teammate_10', 'x_player_opponent_7',\n", " 'y_player_opponent_7', 'x_player_teammate_Goalkeeper',\n", " 'y_player_teammate_Goalkeeper'],\n", " dtype='object')" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | minute | \n", "position_name | \n", "shot_body_part_name | \n", "shot_technique_name | \n", "shot_type_name | \n", "shot_first_time | \n", "shot_one_on_one | \n", "shot_aerial_won | \n", "shot_open_goal | \n", "shot_follows_dribble | \n", "... | \n", "y_player_teammate_5 | \n", "y_player_teammate_6 | \n", "y_player_teammate_7 | \n", "y_player_teammate_8 | \n", "y_player_teammate_9 | \n", "y_player_teammate_10 | \n", "x_player_opponent_7 | \n", "y_player_opponent_7 | \n", "x_player_teammate_Goalkeeper | \n", "y_player_teammate_Goalkeeper | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "Right Center Forward | \n", "Right Foot | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
1 | \n", "5 | \n", "Right Center Forward | \n", "Left Foot | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
2 | \n", "5 | \n", "Center Midfield | \n", "Right Foot | \n", "Half Volley | \n", "Open Play | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "48.9 | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
3 | \n", "5 | \n", "Left Center Midfield | \n", "Right Foot | \n", "Normal | \n", "Open Play | \n", "False | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
4 | \n", "5 | \n", "Right Center Back | \n", "Left Foot | \n", "Normal | \n", "Open Play | \n", "True | \n", "False | \n", "False | \n", "False | \n", "False | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "
5 rows × 64 columns
\n", "\n", " | Training MAE | \n", "Training RMSE | \n", "Training R2 | \n", "Testing MAE | \n", "Testing RMSE | \n", "Testing R2 | \n", "Training Time (mins) | \n", "
---|---|---|---|---|---|---|---|
Model Name | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
XG Boost | \n", "0.10305 | \n", "0.32102 | \n", "0.00230 | \n", "0.10497 | \n", "0.32399 | \n", "0.00009 | \n", "15.20037 | \n", "