{ "cells": [ { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "# %pip (unlike !pip) installs into the environment of the running kernel.\n", "%pip install -q kaggle pandas seaborn" ] },
{ "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [
{ "name": "stdout", "output_type": "stream", "text": [ "401 - Unauthorized\n" ] } ], "source": [ "# Needs Kaggle API credentials (~/.kaggle/kaggle.json or the KAGGLE_USERNAME / KAGGLE_KEY\n", "# environment variables); the recorded '401 - Unauthorized' means they were missing.\n", "!kaggle datasets download -d wenruliu/adult-income-dataset" ] },
{ "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "import zipfile\n", "\n", "# Extract with the standard library: the `unzip` shell tool is not available on Windows.\n", "# extractall() overwrites existing files, like `unzip -o`.\n", "with zipfile.ZipFile('adult-income-dataset.zip') as archive:\n", "    archive.extractall()" ] },
{ "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducational-nummarital-statusoccupationrelationshipracegendercapital-gaincapital-losshours-per-weeknative-countryincome
025Private22680211th7Never-marriedMachine-op-inspctOwn-childBlackMale0040United-States<=50K
138Private89814HS-grad9Married-civ-spouseFarming-fishingHusbandWhiteMale0050United-States<=50K
228Local-gov336951Assoc-acdm12Married-civ-spouseProtective-servHusbandWhiteMale0040United-States>50K
344Private160323Some-college10Married-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States>50K
418?103497Some-college10Never-married?Own-childWhiteFemale0030United-States<=50K
................................................
4883727Private257302Assoc-acdm12Married-civ-spouseTech-supportWifeWhiteFemale0038United-States<=50K
4883840Private154374HS-grad9Married-civ-spouseMachine-op-inspctHusbandWhiteMale0040United-States>50K
4883958Private151910HS-grad9WidowedAdm-clericalUnmarriedWhiteFemale0040United-States<=50K
4884022Private201490HS-grad9Never-marriedAdm-clericalOwn-childWhiteMale0020United-States<=50K
4884152Self-emp-inc287927HS-grad9Married-civ-spouseExec-managerialWifeWhiteFemale15024040United-States>50K
\n", "

48842 rows × 15 columns

\n", "
" ], "text/plain": [ " age workclass fnlwgt education educational-num \\\n", "0 25 Private 226802 11th 7 \n", "1 38 Private 89814 HS-grad 9 \n", "2 28 Local-gov 336951 Assoc-acdm 12 \n", "3 44 Private 160323 Some-college 10 \n", "4 18 ? 103497 Some-college 10 \n", "... ... ... ... ... ... \n", "48837 27 Private 257302 Assoc-acdm 12 \n", "48838 40 Private 154374 HS-grad 9 \n", "48839 58 Private 151910 HS-grad 9 \n", "48840 22 Private 201490 HS-grad 9 \n", "48841 52 Self-emp-inc 287927 HS-grad 9 \n", "\n", " marital-status occupation relationship race gender \\\n", "0 Never-married Machine-op-inspct Own-child Black Male \n", "1 Married-civ-spouse Farming-fishing Husband White Male \n", "2 Married-civ-spouse Protective-serv Husband White Male \n", "3 Married-civ-spouse Machine-op-inspct Husband Black Male \n", "4 Never-married ? Own-child White Female \n", "... ... ... ... ... ... \n", "48837 Married-civ-spouse Tech-support Wife White Female \n", "48838 Married-civ-spouse Machine-op-inspct Husband White Male \n", "48839 Widowed Adm-clerical Unmarried White Female \n", "48840 Never-married Adm-clerical Own-child White Male \n", "48841 Married-civ-spouse Exec-managerial Wife White Female \n", "\n", " capital-gain capital-loss hours-per-week native-country income \n", "0 0 0 40 United-States <=50K \n", "1 0 0 50 United-States <=50K \n", "2 0 0 40 United-States >50K \n", "3 7688 0 40 United-States >50K \n", "4 0 0 30 United-States <=50K \n", "... ... ... ... ... ... 
\n", "48837 0 0 38 United-States <=50K \n", "48838 0 0 40 United-States >50K \n", "48839 0 0 40 United-States <=50K \n", "48840 0 0 20 United-States <=50K \n", "48841 15024 0 40 United-States >50K \n", "\n", "[48842 rows x 15 columns]" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df=pd.read_csv('adult-income-dataset.csv')\n", "df\n" ] },
{ "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "# 0/1 encodings via vectorised comparisons (same values as apply(lambda ...) + astype(int)).\n", "df['income_if_<=50k'] = (df['income'] == '<=50K').astype(int)\n", "df['if_male'] = (df['gender'] == 'Male').astype(int)\n" ] },
{ "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "# Drop incomplete rows: missing workclass values are encoded as '?'.\n", "df = df[df.workclass != '?']\n", "# drop=True discards the old index instead of inserting it as an 'index' column,\n", "# which also keeps this cell safe to re-run.\n", "df = df.reset_index(drop=True)" ] },
{ "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.model_selection import train_test_split\n", "X, y = df[['age']], df['income_if_<=50k']\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=37)\n", "n_samples, n_features = X.shape" ] },
{ "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "X_train = np.array(X_train).reshape(-1,1)\n", "X_test = np.array(X_test).reshape(-1,1)\n", "y_train = np.array(y_train).reshape(-1,1)\n", "y_test = np.array(y_test).reshape(-1,1)" ] },
{ "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "# Fit the scaler on the training split only and reuse its mean/std for the test split;\n", "# refitting on X_test would leak test statistics and scale the two splits inconsistently.\n", "sc = StandardScaler()\n", "X_train = sc.fit_transform(X_train)\n", "X_test = sc.transform(X_test)" ] },
{ "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "import torch\n", "X_train = torch.from_numpy(X_train.astype(np.float32))\n", "X_test = torch.from_numpy(X_test.astype(np.float32))\n", "y_train = torch.from_numpy(y_train.astype(np.float32))\n", "y_test = torch.from_numpy(y_test.astype(np.float32))\n", "\n", "y_train = y_train.view(y_train.shape[0], 1)\n", "y_test= y_test.view(y_test.shape[0], 1)\n" ] },
{ "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "class LogisticRegresion(nn.Module):\n", "    \"\"\"Logistic regression: one linear layer (n_input_featuers -> 1) followed by a sigmoid.\"\"\"\n", "\n", "    def __init__(self, n_input_featuers):\n", "        super(LogisticRegresion, self).__init__()\n", "        self.linear = nn.Linear(n_input_featuers, 1)\n", "\n", "    def forward(self, x):\n", "        y_predicted = torch.sigmoid(self.linear(x))\n", "        return y_predicted\n", "\n", "# Seed before creating the model so the initial weights (and the training run) are reproducible.\n", "torch.manual_seed(37)\n", "model = LogisticRegresion(n_features)" ] },
{ "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "criterion = nn.BCELoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)" ] },
{ "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch:1,loss = 1.0032\n", "epoch:101,loss = 0.8295\n", "epoch:201,loss = 0.7194\n", "epoch:301,loss = 0.6511\n", "epoch:401,loss = 0.6088\n", "epoch:501,loss = 0.5823\n", "epoch:601,loss = 0.5656\n", "epoch:701,loss = 0.5548\n", "epoch:801,loss = 0.5478\n", "epoch:901,loss = 0.5431\n", "epoch:1001,loss = 0.5400\n", "epoch:1101,loss = 0.5378\n", "epoch:1201,loss = 0.5363\n", "epoch:1301,loss = 0.5353\n", "epoch:1401,loss = 0.5346\n" ] } ], "source": [ "num_epochs = 1500\n", "for epoch in range(num_epochs):\n", "    y_predicted = model(X_train)\n", "    loss = criterion(y_predicted,y_train)\n", "    loss.backward()\n", "    optimizer.step()\n", "    optimizer.zero_grad()\n", "\n", "    if (epoch%100==0):\n", "        print(f'epoch:{epoch+1},loss = {loss.item():.4f}')" ] },
{ "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"0.7395\n" ] } ], "source": [ "with torch.no_grad():\n", "    y_predicted = model(X_test)\n", "    y_predicted_cls = y_predicted.round()\n", "    acc = y_predicted_cls.eq(y_test).sum()/float(y_test.shape[0])\n", "    print(f'{acc:.4f}')" ] },
{ "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "# The with-block closes (and therefore flushes) the file even if write() raises.\n", "with open(\"result_pytorch\", 'w') as result:\n", "    result.write(f'acc:{acc:.4f}')" ] } ], "metadata": { "interpreter": { "hash": "2647ea34e536f865ab67ff9ddee7fd78773d956cec0cab53c79b32cd10da5d83" }, "kernelspec": { "display_name": "Python 3.9.11 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }