Dodanie nootbooka z modelem regresji logistycznej dla xG

This commit is contained in:
Maciej Chmielarz 2023-12-05 14:40:11 +01:00
parent 562e5e9a43
commit abdac28187
3 changed files with 10772 additions and 0 deletions

10166
notebooks/data4.csv Normal file

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -0,0 +1,606 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Importy"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wczytanie danych"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('data4.csv')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"y = pd.DataFrame(df['isGoal'])\n",
"X = df.drop(['isGoal'], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\anaconda3\\lib\\site-packages\\scipy\\__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>match_minute</th>\n",
" <th>match_second</th>\n",
" <th>position_x</th>\n",
" <th>position_y</th>\n",
" <th>play_type</th>\n",
" <th>BodyPart</th>\n",
" <th>Number_Intervening_Opponents</th>\n",
" <th>Number_Intervening_Teammates</th>\n",
" <th>Interference_on_Shooter</th>\n",
" <th>outcome</th>\n",
" <th>...</th>\n",
" <th>Interference_on_Shooter_Code</th>\n",
" <th>distance_to_goalM</th>\n",
" <th>distance_to_centerM</th>\n",
" <th>angle</th>\n",
" <th>isFoot</th>\n",
" <th>isHead</th>\n",
" <th>header_distance_to_goalM</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Medium</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8767</th>\n",
" <td>13</td>\n",
" <td>28</td>\n",
" <td>9.23</td>\n",
" <td>-2.24</td>\n",
" <td>Open Play</td>\n",
" <td>Head</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>Medium</td>\n",
" <td>Goal</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>9.499168</td>\n",
" <td>2.245283</td>\n",
" <td>13.672174</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>9.499168</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5798</th>\n",
" <td>78</td>\n",
" <td>9</td>\n",
" <td>14.46</td>\n",
" <td>12.72</td>\n",
" <td>Open Play</td>\n",
" <td>Left</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>Low</td>\n",
" <td>Saved</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>19.278332</td>\n",
" <td>12.750000</td>\n",
" <td>41.404002</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6018</th>\n",
" <td>78</td>\n",
" <td>27</td>\n",
" <td>9.73</td>\n",
" <td>14.22</td>\n",
" <td>Open Play</td>\n",
" <td>Left</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>Low</td>\n",
" <td>Missed</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>17.257933</td>\n",
" <td>14.253538</td>\n",
" <td>55.681087</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4961</th>\n",
" <td>34</td>\n",
" <td>34</td>\n",
" <td>34.91</td>\n",
" <td>0.25</td>\n",
" <td>Open Play</td>\n",
" <td>Right</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>Low</td>\n",
" <td>Saved</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>34.910899</td>\n",
" <td>0.250590</td>\n",
" <td>0.411271</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>447</th>\n",
" <td>52</td>\n",
" <td>57</td>\n",
" <td>26.93</td>\n",
" <td>1.00</td>\n",
" <td>Open Play</td>\n",
" <td>Left</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>Medium</td>\n",
" <td>Saved</td>\n",
" <td>...</td>\n",
" <td>2</td>\n",
" <td>26.948648</td>\n",
" <td>1.002358</td>\n",
" <td>2.131616</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0.000000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 29 columns</p>\n",
"</div>"
],
"text/plain": [
" match_minute match_second position_x position_y play_type BodyPart \\\n",
"8767 13 28 9.23 -2.24 Open Play Head \n",
"5798 78 9 14.46 12.72 Open Play Left \n",
"6018 78 27 9.73 14.22 Open Play Left \n",
"4961 34 34 34.91 0.25 Open Play Right \n",
"447 52 57 26.93 1.00 Open Play Left \n",
"\n",
" Number_Intervening_Opponents Number_Intervening_Teammates \\\n",
"8767 3 0 \n",
"5798 3 0 \n",
"6018 2 0 \n",
"4961 4 1 \n",
"447 2 0 \n",
"\n",
" Interference_on_Shooter outcome ... Interference_on_Shooter_Code \\\n",
"8767 Medium Goal ... 2 \n",
"5798 Low Saved ... 1 \n",
"6018 Low Missed ... 1 \n",
"4961 Low Saved ... 1 \n",
"447 Medium Saved ... 2 \n",
"\n",
" distance_to_goalM distance_to_centerM angle isFoot isHead \\\n",
"8767 9.499168 2.245283 13.672174 0 1 \n",
"5798 19.278332 12.750000 41.404002 1 0 \n",
"6018 17.257933 14.253538 55.681087 1 0 \n",
"4961 34.910899 0.250590 0.411271 1 0 \n",
"447 26.948648 1.002358 2.131616 1 0 \n",
"\n",
" header_distance_to_goalM High Low Medium \n",
"8767 9.499168 0 0 1 \n",
"5798 0.000000 0 1 0 \n",
"6018 0.000000 0 1 0 \n",
"4961 0.000000 0 1 0 \n",
"447 0.000000 0 0 1 \n",
"\n",
"[5 rows x 29 columns]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>isGoal</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8767</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5798</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6018</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4961</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>447</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" isGoal\n",
"8767 1\n",
"5798 0\n",
"6018 0\n",
"4961 0\n",
"447 0"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Przygotowanie danych"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['match_minute', 'match_second', 'position_x', 'position_y', 'play_type',\n",
" 'BodyPart', 'Number_Intervening_Opponents',\n",
" 'Number_Intervening_Teammates', 'Interference_on_Shooter', 'outcome',\n",
" 'position_xM', 'position_yM', 'position_xM_r', 'position_yM_r',\n",
" 'position_xM_std', 'position_yM_std', 'position_xM_std_r',\n",
" 'position_yM_std_r', 'BodyPartCode', 'Interference_on_Shooter_Code',\n",
" 'distance_to_goalM', 'distance_to_centerM', 'angle', 'isFoot', 'isHead',\n",
" 'header_distance_to_goalM', 'High', 'Low', 'Medium'],\n",
" dtype='object')"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.columns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uwzględnienie wybranych cech: \n",
"- Współrzędna x strzelającego,\n",
"- Współrzędna y strzelającego,\n",
"- Dystans do bramki,\n",
"- Kąt do bramki,\n",
"- Minuta meczu,\n",
"- Liczba przeciwników przed piłką,\n",
"- Liczba zawodników ze swojej drużyny przed piłką,\n",
"- Część ciała."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"X_train_extracted = X_train[['position_x', 'position_y', 'distance_to_goalM', \n",
" 'angle', 'match_minute', 'Number_Intervening_Opponents', \n",
" 'Number_Intervening_Teammates', 'isFoot', 'isHead']]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X_test_extracted = X_test[['position_x', 'position_y', 'distance_to_goalM', \n",
" 'angle', 'match_minute', 'Number_Intervening_Opponents', \n",
" 'Number_Intervening_Teammates', 'isFoot', 'isHead']]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Trening danych"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:1143: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(max_iter=500)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(max_iter=500)</pre></div></div></div></div></div>"
],
"text/plain": [
"LogisticRegression(max_iter=500)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = LogisticRegression(max_iter=500)\n",
"model.fit(X_train_extracted, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ewaluacja modelu"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Zbiór danych testowych zawiera 2033 oddane strzały, gdzie 236 to strzały trafione.\n",
"Dokładność klasyfikacji, czy strzał jest bramką, czy nie, wynosi 0.89%.\n",
"klasyfikator uzyskał ROC-AUC na poziomie 0.76%.\n"
]
}
],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"\n",
"print(f'Zbiór danych testowych zawiera {len(y_test)} oddane strzały, gdzie {y_test.sum()[\"isGoal\"]} to strzały trafione.')\n",
"print(f'Dokładność klasyfikacji, czy strzał jest bramką, czy nie, wynosi {model.score(X_test_extracted, y_test):.2f}%.')\n",
"print(f'klasyfikator uzyskał ROC-AUC na poziomie {roc_auc_score(y_test, model.predict_proba(X_test_extracted)[:, 1]):.2f}%.')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.89 0.99 0.94 1797\n",
" 1 0.59 0.09 0.16 236\n",
"\n",
" accuracy 0.89 2033\n",
" macro avg 0.74 0.54 0.55 2033\n",
"weighted avg 0.86 0.89 0.85 2033\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(y_test,model.predict(X_test_extracted)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Zapisywanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['regresja_logistyczna.joblib']"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from joblib import dump\n",
"dump(model, 'regresja_logistyczna.joblib') "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Wczytywanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from joblib import load\n",
"\n",
"model2 = load('regresja_logistyczna.joblib')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}