ium_478831/IUM_main.ipynb
JulianZablonski c9c24167fe pytorch
2022-04-24 20:42:38 +02:00

604 lines
21 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: kaggle in c:\\users\\user\\anaconda3\\lib\\site-packages (1.5.12)\n",
"Requirement already satisfied: python-dateutil in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (2.8.2)\n",
"Requirement already satisfied: python-slugify in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (5.0.2)\n",
"Requirement already satisfied: urllib3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (1.26.7)\n",
"Requirement already satisfied: certifi in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (2021.10.8)\n",
"Requirement already satisfied: tqdm in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (4.62.3)\n",
"Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (2.26.0)\n",
"Requirement already satisfied: six>=1.10 in c:\\users\\user\\anaconda3\\lib\\site-packages (from kaggle) (1.16.0)\n",
"Requirement already satisfied: text-unidecode>=1.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from python-slugify->kaggle) (1.3)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->kaggle) (2.0.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->kaggle) (3.2)\n",
"Requirement already satisfied: colorama in c:\\users\\user\\anaconda3\\lib\\site-packages (from tqdm->kaggle) (0.4.4)\n",
"Requirement already satisfied: pandas in c:\\users\\user\\anaconda3\\lib\\site-packages (1.3.4)\n",
"Requirement already satisfied: pytz>=2017.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from pandas) (2021.3)\n",
"Requirement already satisfied: numpy>=1.17.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from pandas) (1.20.3)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\n",
"Requirement already satisfied: seaborn in c:\\users\\user\\anaconda3\\lib\\site-packages (0.11.2)\n",
"Requirement already satisfied: scipy>=1.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from seaborn) (1.7.1)\n",
"Requirement already satisfied: numpy>=1.15 in c:\\users\\user\\anaconda3\\lib\\site-packages (from seaborn) (1.20.3)\n",
"Requirement already satisfied: matplotlib>=2.2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from seaborn) (3.4.3)\n",
"Requirement already satisfied: pandas>=0.23 in c:\\users\\user\\anaconda3\\lib\\site-packages (from seaborn) (1.3.4)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from matplotlib>=2.2->seaborn) (1.3.1)\n",
"Requirement already satisfied: pillow>=6.2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from matplotlib>=2.2->seaborn) (8.4.0)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from matplotlib>=2.2->seaborn) (3.0.4)\n",
"Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\user\\anaconda3\\lib\\site-packages (from matplotlib>=2.2->seaborn) (2.8.2)\n",
"Requirement already satisfied: cycler>=0.10 in c:\\users\\user\\anaconda3\\lib\\site-packages (from matplotlib>=2.2->seaborn) (0.10.0)\n",
"Requirement already satisfied: six in c:\\users\\user\\anaconda3\\lib\\site-packages (from cycler>=0.10->matplotlib>=2.2->seaborn) (1.16.0)\n",
"Requirement already satisfied: pytz>=2017.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from pandas>=0.23->seaborn) (2021.3)\n"
]
}
],
"source": [
"!pip install kaggle\n",
"!pip install pandas\n",
"!pip install seaborn"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"401 - Unauthorized\n"
]
}
],
"source": [
"!kaggle datasets download -d wenruliu/adult-income-dataset\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'unzip' is not recognized as an internal or external command,\n",
"operable program or batch file.\n"
]
}
],
"source": [
"!unzip -o adult-income-dataset.zip"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>educational-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>income</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Never-married</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>12</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Protective-serv</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>7688</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
" <td>Some-college</td>\n",
" <td>10</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48837</th>\n",
" <td>27</td>\n",
" <td>Private</td>\n",
" <td>257302</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>12</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Tech-support</td>\n",
" <td>Wife</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>38</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48838</th>\n",
" <td>40</td>\n",
" <td>Private</td>\n",
" <td>154374</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48839</th>\n",
" <td>58</td>\n",
" <td>Private</td>\n",
" <td>151910</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Widowed</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Unmarried</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48840</th>\n",
" <td>22</td>\n",
" <td>Private</td>\n",
" <td>201490</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Never-married</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>20</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48841</th>\n",
" <td>52</td>\n",
" <td>Self-emp-inc</td>\n",
" <td>287927</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Wife</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>15024</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;50K</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>48842 rows × 15 columns</p>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education educational-num \\\n",
"0 25 Private 226802 11th 7 \n",
"1 38 Private 89814 HS-grad 9 \n",
"2 28 Local-gov 336951 Assoc-acdm 12 \n",
"3 44 Private 160323 Some-college 10 \n",
"4 18 ? 103497 Some-college 10 \n",
"... ... ... ... ... ... \n",
"48837 27 Private 257302 Assoc-acdm 12 \n",
"48838 40 Private 154374 HS-grad 9 \n",
"48839 58 Private 151910 HS-grad 9 \n",
"48840 22 Private 201490 HS-grad 9 \n",
"48841 52 Self-emp-inc 287927 HS-grad 9 \n",
"\n",
" marital-status occupation relationship race gender \\\n",
"0 Never-married Machine-op-inspct Own-child Black Male \n",
"1 Married-civ-spouse Farming-fishing Husband White Male \n",
"2 Married-civ-spouse Protective-serv Husband White Male \n",
"3 Married-civ-spouse Machine-op-inspct Husband Black Male \n",
"4 Never-married ? Own-child White Female \n",
"... ... ... ... ... ... \n",
"48837 Married-civ-spouse Tech-support Wife White Female \n",
"48838 Married-civ-spouse Machine-op-inspct Husband White Male \n",
"48839 Widowed Adm-clerical Unmarried White Female \n",
"48840 Never-married Adm-clerical Own-child White Male \n",
"48841 Married-civ-spouse Exec-managerial Wife White Female \n",
"\n",
" capital-gain capital-loss hours-per-week native-country income \n",
"0 0 0 40 United-States <=50K \n",
"1 0 0 50 United-States <=50K \n",
"2 0 0 40 United-States >50K \n",
"3 7688 0 40 United-States >50K \n",
"4 0 0 30 United-States <=50K \n",
"... ... ... ... ... ... \n",
"48837 0 0 38 United-States <=50K \n",
"48838 0 0 40 United-States >50K \n",
"48839 0 0 40 United-States <=50K \n",
"48840 0 0 20 United-States <=50K \n",
"48841 15024 0 40 United-States >50K \n",
"\n",
"[48842 rows x 15 columns]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df=pd.read_csv('adult-income-dataset.csv')\n",
"df\n"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"df['income_if_<=50k'] = df['income'].apply(lambda x: True if x == '<=50K' else False)\n",
"df['if_male'] = df['gender'].apply(lambda x: True if x == 'Male' else False)\n",
"df['income_if_<=50k']= df['income_if_<=50k'].astype(int)\n",
"df['if_male']= df['if_male'].astype(int)\n"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"#usunięcie nie pełnych danych \n",
"df = df[df.workclass != '?']\n",
"df = df.reset_index()"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split\n",
"X, y = df[['age']], df['income_if_<=50k']\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=37)\n",
"n_samples, n_features = X.shape"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"X_train = np.array(X_train).reshape(-1,1)\n",
"X_test = np.array(X_test).reshape(-1,1)\n",
"y_train = np.array(y_train).reshape(-1,1)\n",
"y_test = np.array(y_test).reshape(-1,1)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"sc = StandardScaler()\n",
"X_train = sc.fit_transform(X_train)\n",
"X_test = sc.fit_transform(X_test)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"torch.from_file\n",
"X_train = torch.from_numpy(X_train.astype(np.float32))\n",
"X_test = torch.from_numpy(X_test.astype(np.float32))\n",
"y_train = torch.from_numpy(y_train.astype(np.float32))\n",
"y_test = torch.from_numpy(y_test.astype(np.float32))\n",
"\n",
"y_train = y_train.view(y_train.shape[0], 1)\n",
"y_test= y_test.view(y_test.shape[0], 1)\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"class LogisticRegresion(nn.Module):\n",
" def __init__(self, n_input_featuers):\n",
" super(LogisticRegresion, self).__init__()\n",
" self.linear = nn.Linear(n_input_featuers, 1)\n",
" \n",
" def forward(self, x):\n",
" y_predicted = torch.sigmoid(self.linear(x))\n",
" return y_predicted\n",
"\n",
"model = LogisticRegresion(n_features)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCELoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1,loss = 1.0032\n",
"epoch:101,loss = 0.8295\n",
"epoch:201,loss = 0.7194\n",
"epoch:301,loss = 0.6511\n",
"epoch:401,loss = 0.6088\n",
"epoch:501,loss = 0.5823\n",
"epoch:601,loss = 0.5656\n",
"epoch:701,loss = 0.5548\n",
"epoch:801,loss = 0.5478\n",
"epoch:901,loss = 0.5431\n",
"epoch:1001,loss = 0.5400\n",
"epoch:1101,loss = 0.5378\n",
"epoch:1201,loss = 0.5363\n",
"epoch:1301,loss = 0.5353\n",
"epoch:1401,loss = 0.5346\n"
]
}
],
"source": [
"num_epochs = 1500\n",
"for epoch in range(num_epochs):\n",
" y_predicted = model(X_train)\n",
" loss = criterion(y_predicted,y_train)\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" if (epoch%100==0):\n",
" print(f'epoch:{epoch+1},loss = {loss.item():.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7395\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" y_predicted = model(X_test)\n",
" y_predicted_cls = y_predicted.round()\n",
" acc = y_predicted_cls.eq(y_test).sum()/float(y_test.shape[0])\n",
" print(f'{acc:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = open(\"result_pytorch\",'w+')\n",
"result.write(f'acc:{acc:.4f}')"
]
}
],
"metadata": {
"interpreter": {
"hash": "2647ea34e536f865ab67ff9ddee7fd78773d956cec0cab53c79b32cd10da5d83"
},
"kernelspec": {
"display_name": "Python 3.9.11 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}