{
{
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0,
|
||
"metadata": {
|
||
"colab": {
|
||
"name": "UMA_projekt.ipynb",
|
||
"provenance": [],
|
||
"collapsed_sections": []
|
||
},
|
||
"kernelspec": {
|
||
"name": "python3",
|
||
"display_name": "Python 3"
|
||
},
|
||
"language_info": {
|
||
"name": "python"
|
||
},
|
||
"accelerator": "GPU",
|
||
"gpuClass": "standard"
|
||
},
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import time, gc\n",
"\n",
"# Timing utilities: wall-clock time + peak CUDA tensor memory for a region.\n",
"start_time = None\n",
"\n",
"def start_timer():\n",
"    \"\"\"Reset CUDA memory statistics and record the start wall-clock time.\"\"\"\n",
"    global start_time\n",
"    gc.collect()\n",
"    torch.cuda.empty_cache()\n",
"    # reset_max_memory_allocated() is deprecated and merely forwards to\n",
"    # reset_peak_memory_stats() (see the FutureWarning in the cell outputs\n",
"    # below); call the replacement directly.\n",
"    torch.cuda.reset_peak_memory_stats()\n",
"    torch.cuda.synchronize()\n",
"    start_time = time.time()\n",
"\n",
"def end_timer_and_print(local_msg):\n",
"    \"\"\"Print elapsed time since start_timer() and peak CUDA tensor memory.\"\"\"\n",
"    torch.cuda.synchronize()\n",
"    end_time = time.time()\n",
"    print(\"\\n\" + local_msg)\n",
"    print(\"Total execution time = {:.3f} sec\".format(end_time - start_time))\n",
"    print(\"Max memory used by tensors = {} bytes\".format(torch.cuda.max_memory_allocated()))"
|
||
],
|
||
"metadata": {
|
||
"id": "tWf7BQXI3Epz"
|
||
},
|
||
"execution_count": 232,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 233,
|
||
"metadata": {
|
||
"id": "OFdF8yc6z9QK",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "6b2863d8-cbd3-40c8-f356-c57efb696aae"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"output_type": "execute_result",
|
||
"data": {
|
||
"text/plain": [
|
||
"[name: \"/device:CPU:0\"\n",
|
||
" device_type: \"CPU\"\n",
|
||
" memory_limit: 268435456\n",
|
||
" locality {\n",
|
||
" }\n",
|
||
" incarnation: 7116988186229065702\n",
|
||
" xla_global_id: -1, name: \"/device:GPU:0\"\n",
|
||
" device_type: \"GPU\"\n",
|
||
" memory_limit: 14465892352\n",
|
||
" locality {\n",
|
||
" bus_id: 1\n",
|
||
" links {\n",
|
||
" }\n",
|
||
" }\n",
|
||
" incarnation: 10048785647988876421\n",
|
||
" physical_device_desc: \"device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5\"\n",
|
||
" xla_global_id: 416903419]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"execution_count": 233
|
||
}
|
||
],
|
||
"source": [
|
||
"from tensorflow.python.client import device_lib\n",
|
||
"device_lib.list_local_devices()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from sklearn.datasets import fetch_20newsgroups\n",
|
||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"import torch\n",
|
||
"import scipy"
|
||
],
|
||
"metadata": {
|
||
"id": "TIdeqZPs0aON"
|
||
},
|
||
"execution_count": 234,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# !unzip real-or-fake-fake-jobposting-prediction.zip"
|
||
],
|
||
"metadata": {
|
||
"id": "Rf2cOL69qJ7D"
|
||
},
|
||
"execution_count": 235,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Load the job postings and keep only the profile text and the fraud label.\n",
"data = pd.read_csv('fake_job_postings.csv', engine='python')\n",
"data = data[[\"company_profile\", \"fraudulent\"]]\n",
"# Shuffle reproducibly — the seed matches random_state=1 used in the\n",
"# train_test_split calls below, so reruns produce identical splits.\n",
"data = data.sample(frac=1, random_state=1)\n",
"data = data.dropna()"
|
||
],
|
||
"metadata": {
|
||
"id": "NO98S-QDsV6j"
|
||
},
|
||
"execution_count": 236,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"data"
|
||
],
|
||
"metadata": {
|
||
"id": "3Xd0Uvi4stMg",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/",
|
||
"height": 424
|
||
},
|
||
"outputId": "55e073ac-74f2-44f9-90de-d39094f84369"
|
||
},
|
||
"execution_count": 237,
|
||
"outputs": [
|
||
{
|
||
"output_type": "execute_result",
|
||
"data": {
|
||
"text/plain": [
|
||
" company_profile fraudulent\n",
|
||
"16503 At Hayes-Corp, we create the fun stuff. With ... 0\n",
|
||
"16706 Tribal Worldwide Athens is a digitally centric... 0\n",
|
||
"3364 About ECHOING GREEN: Echoing Green unleashes ... 0\n",
|
||
"16856 Daily Secret is the fastest growing digital me... 0\n",
|
||
"1566 ding* is the world’s largest top-up provider. ... 0\n",
|
||
"... ... ...\n",
|
||
"7607 Established on the principles that full time e... 0\n",
|
||
"682 AGOGO creates a personalized audio channel by ... 0\n",
|
||
"2759 We are a family run business that has been in ... 0\n",
|
||
"5751 We have aggressive growth plans in place for t... 1\n",
|
||
"3629 Want to build a 21st century financial service... 0\n",
|
||
"\n",
|
||
"[14572 rows x 2 columns]"
|
||
],
|
||
"text/html": [
|
||
"\n",
|
||
" <div id=\"df-8fdfd669-e598-4796-a993-469e8991a57c\">\n",
|
||
" <div class=\"colab-df-container\">\n",
|
||
" <div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>company_profile</th>\n",
|
||
" <th>fraudulent</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>16503</th>\n",
|
||
" <td>At Hayes-Corp, we create the fun stuff. With ...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16706</th>\n",
|
||
" <td>Tribal Worldwide Athens is a digitally centric...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3364</th>\n",
|
||
" <td>About ECHOING GREEN: Echoing Green unleashes ...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16856</th>\n",
|
||
" <td>Daily Secret is the fastest growing digital me...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1566</th>\n",
|
||
" <td>ding* is the world’s largest top-up provider. ...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>7607</th>\n",
|
||
" <td>Established on the principles that full time e...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>682</th>\n",
|
||
" <td>AGOGO creates a personalized audio channel by ...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2759</th>\n",
|
||
" <td>We are a family run business that has been in ...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5751</th>\n",
|
||
" <td>We have aggressive growth plans in place for t...</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3629</th>\n",
|
||
" <td>Want to build a 21st century financial service...</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>14572 rows × 2 columns</p>\n",
|
||
"</div>\n",
|
||
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-8fdfd669-e598-4796-a993-469e8991a57c')\"\n",
|
||
" title=\"Convert this dataframe to an interactive table.\"\n",
|
||
" style=\"display:none;\">\n",
|
||
" \n",
|
||
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
|
||
" width=\"24px\">\n",
|
||
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
|
||
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
|
||
" </svg>\n",
|
||
" </button>\n",
|
||
" \n",
|
||
" <style>\n",
|
||
" .colab-df-container {\n",
|
||
" display:flex;\n",
|
||
" flex-wrap:wrap;\n",
|
||
" gap: 12px;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .colab-df-convert {\n",
|
||
" background-color: #E8F0FE;\n",
|
||
" border: none;\n",
|
||
" border-radius: 50%;\n",
|
||
" cursor: pointer;\n",
|
||
" display: none;\n",
|
||
" fill: #1967D2;\n",
|
||
" height: 32px;\n",
|
||
" padding: 0 0 0 0;\n",
|
||
" width: 32px;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .colab-df-convert:hover {\n",
|
||
" background-color: #E2EBFA;\n",
|
||
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
|
||
" fill: #174EA6;\n",
|
||
" }\n",
|
||
"\n",
|
||
" [theme=dark] .colab-df-convert {\n",
|
||
" background-color: #3B4455;\n",
|
||
" fill: #D2E3FC;\n",
|
||
" }\n",
|
||
"\n",
|
||
" [theme=dark] .colab-df-convert:hover {\n",
|
||
" background-color: #434B5C;\n",
|
||
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
|
||
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
|
||
" fill: #FFFFFF;\n",
|
||
" }\n",
|
||
" </style>\n",
|
||
"\n",
|
||
" <script>\n",
|
||
" const buttonEl =\n",
|
||
" document.querySelector('#df-8fdfd669-e598-4796-a993-469e8991a57c button.colab-df-convert');\n",
|
||
" buttonEl.style.display =\n",
|
||
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
|
||
"\n",
|
||
" async function convertToInteractive(key) {\n",
|
||
" const element = document.querySelector('#df-8fdfd669-e598-4796-a993-469e8991a57c');\n",
|
||
" const dataTable =\n",
|
||
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
|
||
" [key], {});\n",
|
||
" if (!dataTable) return;\n",
|
||
"\n",
|
||
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
|
||
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
|
||
" + ' to learn more about interactive tables.';\n",
|
||
" element.innerHTML = '';\n",
|
||
" dataTable['output_type'] = 'display_data';\n",
|
||
" await google.colab.output.renderOutput(dataTable, element);\n",
|
||
" const docLink = document.createElement('div');\n",
|
||
" docLink.innerHTML = docLinkHtml;\n",
|
||
" element.appendChild(docLink);\n",
|
||
" }\n",
|
||
" </script>\n",
|
||
" </div>\n",
|
||
" </div>\n",
|
||
" "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"execution_count": 237
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"data_train, data_test = train_test_split(data, test_size=2000, random_state=1)\n",
|
||
"data_dev, data_test = train_test_split(data_test, test_size=1000, random_state=1)\n",
|
||
"len(data_train), len(data_dev), len(data_test)"
|
||
],
|
||
"metadata": {
|
||
"id": "z02lLLumsyIY",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "fd85eca6-ba3f-4717-f428-3e84b4f5ec37"
|
||
},
|
||
"execution_count": 238,
|
||
"outputs": [
|
||
{
|
||
"output_type": "execute_result",
|
||
"data": {
|
||
"text/plain": [
|
||
"(12572, 1000, 1000)"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"execution_count": 238
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Separate features and labels for each split as numpy arrays.\n",
"x_train = np.array(data_train[\"company_profile\"])\n",
"x_dev = np.array(data_dev[\"company_profile\"])\n",
"x_test = np.array(data_test[\"company_profile\"])\n",
"\n",
"y_train = np.array(data_train[\"fraudulent\"])\n",
"y_dev = np.array(data_dev[\"fraudulent\"])\n",
"y_test = np.array(data_test[\"fraudulent\"])\n",
"\n",
"# Keep numpy copies of the labels for the sklearn baseline below;\n",
"# the y_* names themselves are later replaced with GPU tensors.\n",
"y_train_np = np.array(y_train)\n",
"y_dev_np = np.array(y_dev)\n",
"y_test_np = np.array(y_test)"
|
||
],
|
||
"metadata": {
|
||
"id": "MRiv9loJtPTv"
|
||
},
|
||
"execution_count": 239,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"vectorizer = TfidfVectorizer()"
|
||
],
|
||
"metadata": {
|
||
"id": "8DbBdHRztYW-"
|
||
},
|
||
"execution_count": 240,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import copy\n",
|
||
"x_train = vectorizer.fit_transform(x_train)\n",
|
||
"x_dev = vectorizer.transform(x_dev)\n",
|
||
"x_test = vectorizer.transform(x_test)\n",
|
||
"\n",
|
||
"x_train_np = x_train.copy()\n",
|
||
"x_dev_np = x_dev.copy()\n",
|
||
"x_test_np = x_test.copy()"
|
||
],
|
||
"metadata": {
|
||
"id": "o_o0IxMGtizH"
|
||
},
|
||
"execution_count": 241,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"device = 'cuda'"
|
||
],
|
||
"metadata": {
|
||
"id": "ptyeXfURyKKj"
|
||
},
|
||
"execution_count": 242,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Densify the sparse TF-IDF matrices and move everything to the GPU.\n",
"# .toarray() returns a plain ndarray; todense() returns np.matrix, which\n",
"# scipy discourages and which interacts poorly with torch.tensor.\n",
"x_train = torch.tensor(x_train.toarray(), device=device).float()\n",
"x_dev = torch.tensor(x_dev.toarray(), device=device).float()\n",
"x_test = torch.tensor(x_test.toarray(), device=device).float()\n",
"\n",
"y_train = torch.tensor(y_train, device=device)\n",
"y_dev = torch.tensor(y_dev, device=device)\n",
"y_test = torch.tensor(y_test, device=device)"
|
||
],
|
||
"metadata": {
|
||
"id": "P6xaoW-Itxjo"
|
||
},
|
||
"execution_count": 243,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from sklearn.linear_model import LogisticRegression\n",
|
||
"start_timer()\n",
|
||
"reg = LogisticRegression().fit(x_train_np, y_train_np)\n",
|
||
"end_timer_and_print(\"Logistic regression: \")"
|
||
],
|
||
"metadata": {
|
||
"id": "n26vQSlY3un8",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "fdb865d0-c6e9-446c-9fb7-6e152d1789bb"
|
||
},
|
||
"execution_count": 244,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stderr",
|
||
"text": [
|
||
"/usr/local/lib/python3.7/dist-packages/torch/cuda/memory.py:274: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
|
||
" FutureWarning)\n"
|
||
]
|
||
},
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"\n",
|
||
"Logistic regression: \n",
|
||
"Total execution time = 0.365 sec\n",
|
||
"Max memory used by tensors = 2335263744 bytes\n"
|
||
]
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from sklearn.metrics import f1_score\n",
|
||
"from sklearn.metrics import accuracy_score\n",
|
||
"\n",
|
||
"y_pred_np = reg.predict(x_test_np)\n",
|
||
"print('F-score: ', f1_score(y_test_np, y_pred_np, average='macro'))\n",
|
||
"\n",
|
||
"print('Accuracy: ', accuracy_score(y_test_np, y_pred_np))"
|
||
],
|
||
"metadata": {
|
||
"id": "525TSZ_T35C5",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "e72d98a0-56bf-4b4e-b00a-b733bea25fe8"
|
||
},
|
||
"execution_count": 245,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"F-score: 0.8685964220682922\n",
|
||
"Accuracy: 0.993\n"
|
||
]
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"device=\"cuda\""
|
||
],
|
||
"metadata": {
|
||
"id": "HN7ZIaL17IAt"
|
||
},
|
||
"execution_count": 246,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def prepare_batches(X, Y, batch_size):\n",
"    \"\"\"Split X and Y into aligned mini-batches of size batch_size.\n",
"\n",
"    A trailing batch smaller than batch_size is dropped so every batch has\n",
"    a uniform size. Returns (X_batches, Y_batches); each Y batch is\n",
"    reshaped to a column vector for the loss function.\n",
"    \"\"\"\n",
"    data_X = []\n",
"    data_Y = []\n",
"    for i in range(0, len(X), batch_size):\n",
"        data_X.append(X[i:i+batch_size])\n",
"        data_Y.append(Y[i:i+batch_size].reshape(-1,1))\n",
"    # Only drop the final batch when it is incomplete. The previous version\n",
"    # discarded the last batch unconditionally, throwing away a *full* batch\n",
"    # of data whenever len(X) was an exact multiple of batch_size.\n",
"    if data_X and len(data_X[-1]) < batch_size:\n",
"        data_X = data_X[:-1]\n",
"        data_Y = data_Y[:-1]\n",
"    return data_X, data_Y"
|
||
],
|
||
"metadata": {
|
||
"id": "EvHChAIeICJu"
|
||
},
|
||
"execution_count": 247,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"size = 512\n",
|
||
"epochs = 150\n",
|
||
"\n"
|
||
],
|
||
"metadata": {
|
||
"id": "tOrxxJU93N22"
|
||
},
|
||
"execution_count": 248,
|
||
"outputs": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from torch import nn\n",
"from torch import optim\n",
"\n",
"# Fully-connected classifier over the TF-IDF features: input layer, six\n",
"# hidden layers of `size` units with ReLU, log-softmax over label classes.\n",
"model = nn.Sequential(\n",
"    nn.Linear(x_train.shape[1], size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, data_train[\"fraudulent\"].nunique()),\n",
"\n",
"    nn.LogSoftmax(dim=1))\n",
"model.cuda()\n",
"\n",
"# NLLLoss pairs with the LogSoftmax output layer above.\n",
"criterion = nn.NLLLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.002)\n",
"\n",
"train_losses = []\n",
"test_losses = []\n",
"test_accuracies = []\n",
"start_timer()\n",
"for e in range(epochs):\n",
"    optimizer.zero_grad()\n",
"\n",
"    # Full-batch gradient step; call the module (not .forward()) so hooks run.\n",
"    output = model(x_train)\n",
"    loss = criterion(output, y_train)\n",
"    loss.backward()\n",
"    train_loss = loss.item()\n",
"    train_losses.append(train_loss)\n",
"\n",
"    optimizer.step()\n",
"\n",
"    # Turn off gradients for validation, saves memory and computations\n",
"    with torch.no_grad():\n",
"        model.eval()\n",
"        log_ps = model(x_dev)\n",
"        test_loss = criterion(log_ps, y_dev)\n",
"        # Store plain floats, not tensors, so the history lists do not\n",
"        # keep GPU memory alive across epochs.\n",
"        test_losses.append(test_loss.item())\n",
"\n",
"        ps = torch.exp(log_ps)\n",
"        top_p, top_class = ps.topk(1, dim=1)\n",
"        equals = top_class == y_dev.view(*top_class.shape)\n",
"        test_accuracy = torch.mean(equals.float())\n",
"        test_accuracies.append(test_accuracy.item())\n",
"\n",
"        model.train()\n",
"\n",
"    print(f\"Epoch: {e + 1}/{epochs}.. \",\n",
"          f\"Training Loss: {train_loss:.3f}.. \",\n",
"          f\"Test Loss: {test_loss:.3f}.. \",\n",
"          f\"Test Accuracy: {test_accuracy:.3f}\")\n",
"# This cell trains in full (default) precision; the actual mixed-precision\n",
"# run is the separate use_amp cell below, so label the timing accordingly.\n",
"end_timer_and_print(\"Default precision:\")"
|
||
],
|
||
"metadata": {
|
||
"id": "RQGLH1lbxSqD",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "7b748a39-45f1-4d2c-eb4e-1bbd72dcb580"
|
||
},
|
||
"execution_count": 249,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stderr",
|
||
"text": [
|
||
"/usr/local/lib/python3.7/dist-packages/torch/cuda/memory.py:274: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
|
||
" FutureWarning)\n"
|
||
]
|
||
},
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"Epoch: 1/150.. Training Loss: 0.666.. Test Loss: 0.580.. Test Accuracy: 0.983\n",
|
||
"Epoch: 2/150.. Training Loss: 0.581.. Test Loss: 0.454.. Test Accuracy: 0.983\n",
|
||
"Epoch: 3/150.. Training Loss: 0.455.. Test Loss: 0.191.. Test Accuracy: 0.983\n",
|
||
"Epoch: 4/150.. Training Loss: 0.195.. Test Loss: 0.103.. Test Accuracy: 0.983\n",
|
||
"Epoch: 5/150.. Training Loss: 0.115.. Test Loss: 0.177.. Test Accuracy: 0.983\n",
|
||
"Epoch: 6/150.. Training Loss: 0.193.. Test Loss: 0.166.. Test Accuracy: 0.983\n",
|
||
"Epoch: 7/150.. Training Loss: 0.178.. Test Loss: 0.122.. Test Accuracy: 0.983\n",
|
||
"Epoch: 8/150.. Training Loss: 0.131.. Test Loss: 0.085.. Test Accuracy: 0.983\n",
|
||
"Epoch: 9/150.. Training Loss: 0.093.. Test Loss: 0.072.. Test Accuracy: 0.983\n",
|
||
"Epoch: 10/150.. Training Loss: 0.079.. Test Loss: 0.091.. Test Accuracy: 0.983\n",
|
||
"Epoch: 11/150.. Training Loss: 0.096.. Test Loss: 0.098.. Test Accuracy: 0.983\n",
|
||
"Epoch: 12/150.. Training Loss: 0.103.. Test Loss: 0.081.. Test Accuracy: 0.983\n",
|
||
"Epoch: 13/150.. Training Loss: 0.086.. Test Loss: 0.063.. Test Accuracy: 0.983\n",
|
||
"Epoch: 14/150.. Training Loss: 0.067.. Test Loss: 0.059.. Test Accuracy: 0.983\n",
|
||
"Epoch: 15/150.. Training Loss: 0.062.. Test Loss: 0.063.. Test Accuracy: 0.983\n",
|
||
"Epoch: 16/150.. Training Loss: 0.062.. Test Loss: 0.067.. Test Accuracy: 0.983\n",
|
||
"Epoch: 17/150.. Training Loss: 0.061.. Test Loss: 0.068.. Test Accuracy: 0.983\n",
|
||
"Epoch: 18/150.. Training Loss: 0.058.. Test Loss: 0.067.. Test Accuracy: 0.983\n",
|
||
"Epoch: 19/150.. Training Loss: 0.053.. Test Loss: 0.064.. Test Accuracy: 0.983\n",
|
||
"Epoch: 20/150.. Training Loss: 0.047.. Test Loss: 0.061.. Test Accuracy: 0.983\n",
|
||
"Epoch: 21/150.. Training Loss: 0.041.. Test Loss: 0.057.. Test Accuracy: 0.983\n",
|
||
"Epoch: 22/150.. Training Loss: 0.037.. Test Loss: 0.054.. Test Accuracy: 0.983\n",
|
||
"Epoch: 23/150.. Training Loss: 0.033.. Test Loss: 0.051.. Test Accuracy: 0.983\n",
|
||
"Epoch: 24/150.. Training Loss: 0.030.. Test Loss: 0.048.. Test Accuracy: 0.983\n",
|
||
"Epoch: 25/150.. Training Loss: 0.027.. Test Loss: 0.045.. Test Accuracy: 0.983\n",
|
||
"Epoch: 26/150.. Training Loss: 0.025.. Test Loss: 0.044.. Test Accuracy: 0.983\n",
|
||
"Epoch: 27/150.. Training Loss: 0.023.. Test Loss: 0.042.. Test Accuracy: 0.983\n",
|
||
"Epoch: 28/150.. Training Loss: 0.021.. Test Loss: 0.041.. Test Accuracy: 0.983\n",
|
||
"Epoch: 29/150.. Training Loss: 0.020.. Test Loss: 0.042.. Test Accuracy: 0.983\n",
|
||
"Epoch: 30/150.. Training Loss: 0.019.. Test Loss: 0.043.. Test Accuracy: 0.983\n",
|
||
"Epoch: 31/150.. Training Loss: 0.017.. Test Loss: 0.044.. Test Accuracy: 0.983\n",
|
||
"Epoch: 32/150.. Training Loss: 0.016.. Test Loss: 0.047.. Test Accuracy: 0.983\n",
|
||
"Epoch: 33/150.. Training Loss: 0.015.. Test Loss: 0.050.. Test Accuracy: 0.993\n",
|
||
"Epoch: 34/150.. Training Loss: 0.013.. Test Loss: 0.053.. Test Accuracy: 0.997\n",
|
||
"Epoch: 35/150.. Training Loss: 0.012.. Test Loss: 0.056.. Test Accuracy: 0.997\n",
|
||
"Epoch: 36/150.. Training Loss: 0.008.. Test Loss: 0.058.. Test Accuracy: 0.997\n",
|
||
"Epoch: 37/150.. Training Loss: 0.003.. Test Loss: 0.062.. Test Accuracy: 0.996\n",
|
||
"Epoch: 38/150.. Training Loss: 0.000.. Test Loss: 0.069.. Test Accuracy: 0.996\n",
|
||
"Epoch: 39/150.. Training Loss: 0.000.. Test Loss: 0.086.. Test Accuracy: 0.995\n",
|
||
"Epoch: 40/150.. Training Loss: 0.001.. Test Loss: 0.104.. Test Accuracy: 0.995\n",
|
||
"Epoch: 41/150.. Training Loss: 0.001.. Test Loss: 0.122.. Test Accuracy: 0.995\n",
|
||
"Epoch: 42/150.. Training Loss: 0.001.. Test Loss: 0.138.. Test Accuracy: 0.996\n",
|
||
"Epoch: 43/150.. Training Loss: 0.002.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 44/150.. Training Loss: 0.002.. Test Loss: 0.169.. Test Accuracy: 0.996\n",
|
||
"Epoch: 45/150.. Training Loss: 0.001.. Test Loss: 0.181.. Test Accuracy: 0.996\n",
|
||
"Epoch: 46/150.. Training Loss: 0.000.. Test Loss: 0.192.. Test Accuracy: 0.997\n",
|
||
"Epoch: 47/150.. Training Loss: 0.000.. Test Loss: 0.214.. Test Accuracy: 0.997\n",
|
||
"Epoch: 48/150.. Training Loss: 0.000.. Test Loss: 0.236.. Test Accuracy: 0.997\n",
|
||
"Epoch: 49/150.. Training Loss: 0.002.. Test Loss: 0.182.. Test Accuracy: 0.997\n",
|
||
"Epoch: 50/150.. Training Loss: 0.000.. Test Loss: 0.129.. Test Accuracy: 0.997\n",
|
||
"Epoch: 51/150.. Training Loss: 0.000.. Test Loss: 0.101.. Test Accuracy: 0.996\n",
|
||
"Epoch: 52/150.. Training Loss: 0.000.. Test Loss: 0.083.. Test Accuracy: 0.996\n",
|
||
"Epoch: 53/150.. Training Loss: 0.000.. Test Loss: 0.077.. Test Accuracy: 0.995\n",
|
||
"Epoch: 54/150.. Training Loss: 0.000.. Test Loss: 0.072.. Test Accuracy: 0.995\n",
|
||
"Epoch: 55/150.. Training Loss: 0.000.. Test Loss: 0.070.. Test Accuracy: 0.995\n",
|
||
"Epoch: 56/150.. Training Loss: 0.001.. Test Loss: 0.077.. Test Accuracy: 0.995\n",
|
||
"Epoch: 57/150.. Training Loss: 0.001.. Test Loss: 0.080.. Test Accuracy: 0.995\n",
|
||
"Epoch: 58/150.. Training Loss: 0.000.. Test Loss: 0.080.. Test Accuracy: 0.995\n",
|
||
"Epoch: 59/150.. Training Loss: 0.000.. Test Loss: 0.079.. Test Accuracy: 0.995\n",
|
||
"Epoch: 60/150.. Training Loss: 0.000.. Test Loss: 0.078.. Test Accuracy: 0.995\n",
|
||
"Epoch: 61/150.. Training Loss: 0.000.. Test Loss: 0.078.. Test Accuracy: 0.995\n",
|
||
"Epoch: 62/150.. Training Loss: 0.000.. Test Loss: 0.079.. Test Accuracy: 0.995\n",
|
||
"Epoch: 63/150.. Training Loss: 0.000.. Test Loss: 0.081.. Test Accuracy: 0.995\n",
|
||
"Epoch: 64/150.. Training Loss: 0.000.. Test Loss: 0.084.. Test Accuracy: 0.995\n",
|
||
"Epoch: 65/150.. Training Loss: 0.000.. Test Loss: 0.089.. Test Accuracy: 0.995\n",
|
||
"Epoch: 66/150.. Training Loss: 0.000.. Test Loss: 0.095.. Test Accuracy: 0.995\n",
|
||
"Epoch: 67/150.. Training Loss: 0.000.. Test Loss: 0.101.. Test Accuracy: 0.995\n",
|
||
"Epoch: 68/150.. Training Loss: 0.000.. Test Loss: 0.107.. Test Accuracy: 0.995\n",
|
||
"Epoch: 69/150.. Training Loss: 0.000.. Test Loss: 0.112.. Test Accuracy: 0.995\n",
|
||
"Epoch: 70/150.. Training Loss: 0.000.. Test Loss: 0.116.. Test Accuracy: 0.995\n",
|
||
"Epoch: 71/150.. Training Loss: 0.000.. Test Loss: 0.120.. Test Accuracy: 0.995\n",
|
||
"Epoch: 72/150.. Training Loss: 0.000.. Test Loss: 0.124.. Test Accuracy: 0.995\n",
|
||
"Epoch: 73/150.. Training Loss: 0.000.. Test Loss: 0.127.. Test Accuracy: 0.995\n",
|
||
"Epoch: 74/150.. Training Loss: 0.000.. Test Loss: 0.129.. Test Accuracy: 0.995\n",
|
||
"Epoch: 75/150.. Training Loss: 0.000.. Test Loss: 0.132.. Test Accuracy: 0.995\n",
|
||
"Epoch: 76/150.. Training Loss: 0.000.. Test Loss: 0.134.. Test Accuracy: 0.996\n",
|
||
"Epoch: 77/150.. Training Loss: 0.000.. Test Loss: 0.136.. Test Accuracy: 0.996\n",
|
||
"Epoch: 78/150.. Training Loss: 0.000.. Test Loss: 0.138.. Test Accuracy: 0.996\n",
|
||
"Epoch: 79/150.. Training Loss: 0.000.. Test Loss: 0.139.. Test Accuracy: 0.996\n",
|
||
"Epoch: 80/150.. Training Loss: 0.000.. Test Loss: 0.141.. Test Accuracy: 0.996\n",
|
||
"Epoch: 81/150.. Training Loss: 0.000.. Test Loss: 0.142.. Test Accuracy: 0.996\n",
|
||
"Epoch: 82/150.. Training Loss: 0.000.. Test Loss: 0.144.. Test Accuracy: 0.996\n",
|
||
"Epoch: 83/150.. Training Loss: 0.000.. Test Loss: 0.145.. Test Accuracy: 0.996\n",
|
||
"Epoch: 84/150.. Training Loss: 0.000.. Test Loss: 0.146.. Test Accuracy: 0.996\n",
|
||
"Epoch: 85/150.. Training Loss: 0.000.. Test Loss: 0.147.. Test Accuracy: 0.996\n",
|
||
"Epoch: 86/150.. Training Loss: 0.000.. Test Loss: 0.148.. Test Accuracy: 0.996\n",
|
||
"Epoch: 87/150.. Training Loss: 0.000.. Test Loss: 0.148.. Test Accuracy: 0.996\n",
|
||
"Epoch: 88/150.. Training Loss: 0.000.. Test Loss: 0.149.. Test Accuracy: 0.996\n",
|
||
"Epoch: 89/150.. Training Loss: 0.000.. Test Loss: 0.150.. Test Accuracy: 0.996\n",
|
||
"Epoch: 90/150.. Training Loss: 0.000.. Test Loss: 0.150.. Test Accuracy: 0.996\n",
|
||
"Epoch: 91/150.. Training Loss: 0.000.. Test Loss: 0.151.. Test Accuracy: 0.996\n",
|
||
"Epoch: 92/150.. Training Loss: 0.000.. Test Loss: 0.151.. Test Accuracy: 0.996\n",
|
||
"Epoch: 93/150.. Training Loss: 0.000.. Test Loss: 0.152.. Test Accuracy: 0.996\n",
|
||
"Epoch: 94/150.. Training Loss: 0.000.. Test Loss: 0.152.. Test Accuracy: 0.996\n",
|
||
"Epoch: 95/150.. Training Loss: 0.000.. Test Loss: 0.152.. Test Accuracy: 0.996\n",
|
||
"Epoch: 96/150.. Training Loss: 0.000.. Test Loss: 0.153.. Test Accuracy: 0.996\n",
|
||
"Epoch: 97/150.. Training Loss: 0.000.. Test Loss: 0.153.. Test Accuracy: 0.996\n",
|
||
"Epoch: 98/150.. Training Loss: 0.000.. Test Loss: 0.153.. Test Accuracy: 0.996\n",
|
||
"Epoch: 99/150.. Training Loss: 0.000.. Test Loss: 0.153.. Test Accuracy: 0.996\n",
|
||
"Epoch: 100/150.. Training Loss: 0.000.. Test Loss: 0.153.. Test Accuracy: 0.996\n",
|
||
"Epoch: 101/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 102/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 103/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 104/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 105/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 106/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 107/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 108/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 109/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 110/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 111/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 112/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 113/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 114/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 115/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 116/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 117/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 118/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 119/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 120/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 121/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 122/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 123/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 124/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 125/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 126/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 127/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 128/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 129/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 130/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 131/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 132/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 133/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 134/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 135/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 136/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 137/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 138/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 139/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 140/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 141/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 142/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 143/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 144/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 145/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 146/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 147/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 148/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 149/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"Epoch: 150/150.. Training Loss: 0.000.. Test Loss: 0.155.. Test Accuracy: 0.996\n",
|
||
"\n",
|
||
"Mixed precision:\n",
|
||
"Total execution time = 21.202 sec\n",
|
||
"Max memory used by tensors = 2485789184 bytes\n"
|
||
]
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Default model\n",
"# Evaluate the baseline model on the held-out test set.\n",
"model.eval()\n",
"with torch.no_grad():  # inference only — no gradient tracking needed\n",
"    output = model(x_test)\n",
"ps = torch.exp(output)  # LogSoftmax outputs -> probabilities\n",
"top_p, top_class = ps.topk(1, dim=1)\n",
"# Flatten to a 1-D label vector; sklearn metrics expect 1-D, not (n, 1) columns\n",
"y_pred = top_class.cpu().detach().numpy().ravel()\n",
"y_true = [int(d) for d in y_test]\n",
"print('F-score: ', f1_score(y_true, y_pred, average='macro'))\n",
"print('Accuracy: ', accuracy_score(y_true, y_pred))"
|
||
],
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "FOx-H5UI5Hxa",
|
||
"outputId": "9ce6a2bf-26e6-47e6-f946-20c65f202768"
|
||
},
|
||
"execution_count": 250,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"F-score: 0.9845942906441127\n",
|
||
"Accuracy: 0.999\n"
|
||
]
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Mixed precision model\n",
"use_amp = True\n",
"\n",
"# Same MLP architecture as the baseline: input -> 7 hidden (Linear+ReLU) layers -> LogSoftmax\n",
"model = nn.Sequential(\n",
"    nn.Linear(x_train.shape[1], size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"\n",
"    nn.Linear(size, size),\n",
"    nn.ReLU(),\n",
"    nn.Linear(size, data_train[\"fraudulent\"].nunique()),\n",
"\n",
"    nn.LogSoftmax(dim=1))\n",
"model.cuda()\n",
"\n",
"# NLLLoss pairs with the LogSoftmax output layer\n",
"criterion = nn.NLLLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.002)\n",
"\n",
"train_losses = []\n",
"test_losses = []\n",
"test_accuracies = []\n",
"scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
"start_timer()\n",
"for e in range(epochs):\n",
"    optimizer.zero_grad()\n",
"    # Forward pass under autocast; backward() and optimizer step run OUTSIDE\n",
"    # the autocast region, as required by the PyTorch AMP recipe.\n",
"    with torch.cuda.amp.autocast(enabled=use_amp):\n",
"        output = model(x_train)\n",
"        loss = criterion(output, y_train)\n",
"    scaler.scale(loss).backward()\n",
"    train_loss = loss.item()  # .item() detaches to a Python float\n",
"    train_losses.append(train_loss)\n",
"    scaler.step(optimizer)\n",
"    scaler.update()\n",
"\n",
"    # Validation: gradients off saves memory and computation\n",
"    with torch.no_grad():\n",
"        model.eval()\n",
"        log_ps = model(x_dev)\n",
"        test_loss = criterion(log_ps, y_dev).item()\n",
"        test_losses.append(test_loss)\n",
"\n",
"        ps = torch.exp(log_ps)\n",
"        top_p, top_class = ps.topk(1, dim=1)\n",
"        equals = top_class == y_dev.view(*top_class.shape)\n",
"        test_accuracy = torch.mean(equals.float()).item()\n",
"        test_accuracies.append(test_accuracy)\n",
"\n",
"    model.train()\n",
"\n",
"    print(f\"Epoch: {e + 1}/{epochs}.. \",\n",
"          f\"Training Loss: {train_loss:.3f}.. \",\n",
"          f\"Test Loss: {test_loss:.3f}.. \",\n",
"          f\"Test Accuracy: {test_accuracy:.3f}\")\n",
"end_timer_and_print(\"Mixed precision:\")"
|
||
],
|
||
"metadata": {
|
||
"id": "eyBj73tj0D8R",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "bdbc1189-4d98-41e4-8769-2abd0a35e718"
|
||
},
|
||
"execution_count": 251,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stderr",
|
||
"text": [
|
||
"/usr/local/lib/python3.7/dist-packages/torch/cuda/memory.py:274: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
|
||
" FutureWarning)\n"
|
||
]
|
||
},
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"Epoch: 1/150.. Training Loss: 0.729.. Test Loss: 0.643.. Test Accuracy: 0.983\n",
|
||
"Epoch: 2/150.. Training Loss: 0.644.. Test Loss: 0.518.. Test Accuracy: 0.983\n",
|
||
"Epoch: 3/150.. Training Loss: 0.519.. Test Loss: 0.245.. Test Accuracy: 0.983\n",
|
||
"Epoch: 4/150.. Training Loss: 0.249.. Test Loss: 0.087.. Test Accuracy: 0.983\n",
|
||
"Epoch: 5/150.. Training Loss: 0.098.. Test Loss: 0.171.. Test Accuracy: 0.983\n",
|
||
"Epoch: 6/150.. Training Loss: 0.187.. Test Loss: 0.178.. Test Accuracy: 0.983\n",
|
||
"Epoch: 7/150.. Training Loss: 0.191.. Test Loss: 0.135.. Test Accuracy: 0.983\n",
|
||
"Epoch: 8/150.. Training Loss: 0.145.. Test Loss: 0.093.. Test Accuracy: 0.983\n",
|
||
"Epoch: 9/150.. Training Loss: 0.101.. Test Loss: 0.070.. Test Accuracy: 0.983\n",
|
||
"Epoch: 10/150.. Training Loss: 0.077.. Test Loss: 0.088.. Test Accuracy: 0.983\n",
|
||
"Epoch: 11/150.. Training Loss: 0.093.. Test Loss: 0.100.. Test Accuracy: 0.983\n",
|
||
"Epoch: 12/150.. Training Loss: 0.104.. Test Loss: 0.080.. Test Accuracy: 0.983\n",
|
||
"Epoch: 13/150.. Training Loss: 0.085.. Test Loss: 0.061.. Test Accuracy: 0.983\n",
|
||
"Epoch: 14/150.. Training Loss: 0.065.. Test Loss: 0.059.. Test Accuracy: 0.983\n",
|
||
"Epoch: 15/150.. Training Loss: 0.061.. Test Loss: 0.063.. Test Accuracy: 0.983\n",
|
||
"Epoch: 16/150.. Training Loss: 0.062.. Test Loss: 0.066.. Test Accuracy: 0.983\n",
|
||
"Epoch: 17/150.. Training Loss: 0.060.. Test Loss: 0.066.. Test Accuracy: 0.983\n",
|
||
"Epoch: 18/150.. Training Loss: 0.056.. Test Loss: 0.064.. Test Accuracy: 0.983\n",
|
||
"Epoch: 19/150.. Training Loss: 0.051.. Test Loss: 0.060.. Test Accuracy: 0.983\n",
|
||
"Epoch: 20/150.. Training Loss: 0.044.. Test Loss: 0.057.. Test Accuracy: 0.983\n",
|
||
"Epoch: 21/150.. Training Loss: 0.039.. Test Loss: 0.053.. Test Accuracy: 0.983\n",
|
||
"Epoch: 22/150.. Training Loss: 0.034.. Test Loss: 0.050.. Test Accuracy: 0.983\n",
|
||
"Epoch: 23/150.. Training Loss: 0.031.. Test Loss: 0.047.. Test Accuracy: 0.983\n",
|
||
"Epoch: 24/150.. Training Loss: 0.027.. Test Loss: 0.045.. Test Accuracy: 0.983\n",
|
||
"Epoch: 25/150.. Training Loss: 0.025.. Test Loss: 0.043.. Test Accuracy: 0.983\n",
|
||
"Epoch: 26/150.. Training Loss: 0.022.. Test Loss: 0.041.. Test Accuracy: 0.983\n",
|
||
"Epoch: 27/150.. Training Loss: 0.020.. Test Loss: 0.040.. Test Accuracy: 0.983\n",
|
||
"Epoch: 28/150.. Training Loss: 0.019.. Test Loss: 0.040.. Test Accuracy: 0.983\n",
|
||
"Epoch: 29/150.. Training Loss: 0.017.. Test Loss: 0.040.. Test Accuracy: 0.983\n",
|
||
"Epoch: 30/150.. Training Loss: 0.016.. Test Loss: 0.041.. Test Accuracy: 0.983\n",
|
||
"Epoch: 31/150.. Training Loss: 0.015.. Test Loss: 0.043.. Test Accuracy: 0.994\n",
|
||
"Epoch: 32/150.. Training Loss: 0.013.. Test Loss: 0.045.. Test Accuracy: 0.996\n",
|
||
"Epoch: 33/150.. Training Loss: 0.012.. Test Loss: 0.047.. Test Accuracy: 0.996\n",
|
||
"Epoch: 34/150.. Training Loss: 0.009.. Test Loss: 0.049.. Test Accuracy: 0.996\n",
|
||
"Epoch: 35/150.. Training Loss: 0.005.. Test Loss: 0.054.. Test Accuracy: 0.996\n",
|
||
"Epoch: 36/150.. Training Loss: 0.001.. Test Loss: 0.064.. Test Accuracy: 0.996\n",
|
||
"Epoch: 37/150.. Training Loss: 0.000.. Test Loss: 0.077.. Test Accuracy: 0.996\n",
|
||
"Epoch: 38/150.. Training Loss: 0.001.. Test Loss: 0.094.. Test Accuracy: 0.995\n",
|
||
"Epoch: 39/150.. Training Loss: 0.001.. Test Loss: 0.113.. Test Accuracy: 0.995\n",
|
||
"Epoch: 40/150.. Training Loss: 0.001.. Test Loss: 0.131.. Test Accuracy: 0.995\n",
|
||
"Epoch: 41/150.. Training Loss: 0.002.. Test Loss: 0.144.. Test Accuracy: 0.996\n",
|
||
"Epoch: 42/150.. Training Loss: 0.002.. Test Loss: 0.158.. Test Accuracy: 0.996\n",
|
||
"Epoch: 43/150.. Training Loss: 0.001.. Test Loss: 0.170.. Test Accuracy: 0.996\n",
|
||
"Epoch: 44/150.. Training Loss: 0.001.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 45/150.. Training Loss: 0.000.. Test Loss: 0.195.. Test Accuracy: 0.997\n",
|
||
"Epoch: 46/150.. Training Loss: 0.000.. Test Loss: 0.216.. Test Accuracy: 0.997\n",
|
||
"Epoch: 47/150.. Training Loss: 0.000.. Test Loss: 0.237.. Test Accuracy: 0.997\n",
|
||
"Epoch: 48/150.. Training Loss: 0.002.. Test Loss: 0.174.. Test Accuracy: 0.997\n",
|
||
"Epoch: 49/150.. Training Loss: 0.000.. Test Loss: 0.126.. Test Accuracy: 0.997\n",
|
||
"Epoch: 50/150.. Training Loss: 0.000.. Test Loss: 0.090.. Test Accuracy: 0.997\n",
|
||
"Epoch: 51/150.. Training Loss: 0.000.. Test Loss: 0.062.. Test Accuracy: 0.997\n",
|
||
"Epoch: 52/150.. Training Loss: 0.000.. Test Loss: 0.045.. Test Accuracy: 0.996\n",
|
||
"Epoch: 53/150.. Training Loss: 0.000.. Test Loss: 0.035.. Test Accuracy: 0.996\n",
|
||
"Epoch: 54/150.. Training Loss: 0.000.. Test Loss: 0.031.. Test Accuracy: 0.996\n",
|
||
"Epoch: 55/150.. Training Loss: 0.000.. Test Loss: 0.042.. Test Accuracy: 0.996\n",
|
||
"Epoch: 56/150.. Training Loss: 0.000.. Test Loss: 0.053.. Test Accuracy: 0.996\n",
|
||
"Epoch: 57/150.. Training Loss: 0.000.. Test Loss: 0.063.. Test Accuracy: 0.996\n",
|
||
"Epoch: 58/150.. Training Loss: 0.000.. Test Loss: 0.072.. Test Accuracy: 0.996\n",
|
||
"Epoch: 59/150.. Training Loss: 0.000.. Test Loss: 0.081.. Test Accuracy: 0.996\n",
|
||
"Epoch: 60/150.. Training Loss: 0.000.. Test Loss: 0.089.. Test Accuracy: 0.996\n",
|
||
"Epoch: 61/150.. Training Loss: 0.000.. Test Loss: 0.097.. Test Accuracy: 0.996\n",
|
||
"Epoch: 62/150.. Training Loss: 0.000.. Test Loss: 0.104.. Test Accuracy: 0.996\n",
|
||
"Epoch: 63/150.. Training Loss: 0.000.. Test Loss: 0.110.. Test Accuracy: 0.996\n",
|
||
"Epoch: 64/150.. Training Loss: 0.000.. Test Loss: 0.117.. Test Accuracy: 0.996\n",
|
||
"Epoch: 65/150.. Training Loss: 0.000.. Test Loss: 0.122.. Test Accuracy: 0.996\n",
|
||
"Epoch: 66/150.. Training Loss: 0.000.. Test Loss: 0.127.. Test Accuracy: 0.996\n",
|
||
"Epoch: 67/150.. Training Loss: 0.000.. Test Loss: 0.132.. Test Accuracy: 0.996\n",
|
||
"Epoch: 68/150.. Training Loss: 0.000.. Test Loss: 0.136.. Test Accuracy: 0.996\n",
|
||
"Epoch: 69/150.. Training Loss: 0.000.. Test Loss: 0.140.. Test Accuracy: 0.996\n",
|
||
"Epoch: 70/150.. Training Loss: 0.000.. Test Loss: 0.143.. Test Accuracy: 0.996\n",
|
||
"Epoch: 71/150.. Training Loss: 0.000.. Test Loss: 0.147.. Test Accuracy: 0.996\n",
|
||
"Epoch: 72/150.. Training Loss: 0.000.. Test Loss: 0.149.. Test Accuracy: 0.996\n",
|
||
"Epoch: 73/150.. Training Loss: 0.000.. Test Loss: 0.152.. Test Accuracy: 0.996\n",
|
||
"Epoch: 74/150.. Training Loss: 0.000.. Test Loss: 0.154.. Test Accuracy: 0.996\n",
|
||
"Epoch: 75/150.. Training Loss: 0.000.. Test Loss: 0.156.. Test Accuracy: 0.996\n",
|
||
"Epoch: 76/150.. Training Loss: 0.000.. Test Loss: 0.158.. Test Accuracy: 0.996\n",
|
||
"Epoch: 77/150.. Training Loss: 0.000.. Test Loss: 0.160.. Test Accuracy: 0.996\n",
|
||
"Epoch: 78/150.. Training Loss: 0.000.. Test Loss: 0.162.. Test Accuracy: 0.996\n",
|
||
"Epoch: 79/150.. Training Loss: 0.000.. Test Loss: 0.163.. Test Accuracy: 0.996\n",
|
||
"Epoch: 80/150.. Training Loss: 0.000.. Test Loss: 0.164.. Test Accuracy: 0.996\n",
|
||
"Epoch: 81/150.. Training Loss: 0.000.. Test Loss: 0.166.. Test Accuracy: 0.996\n",
|
||
"Epoch: 82/150.. Training Loss: 0.000.. Test Loss: 0.167.. Test Accuracy: 0.996\n",
|
||
"Epoch: 83/150.. Training Loss: 0.000.. Test Loss: 0.168.. Test Accuracy: 0.996\n",
|
||
"Epoch: 84/150.. Training Loss: 0.000.. Test Loss: 0.169.. Test Accuracy: 0.996\n",
|
||
"Epoch: 85/150.. Training Loss: 0.000.. Test Loss: 0.169.. Test Accuracy: 0.996\n",
|
||
"Epoch: 86/150.. Training Loss: 0.000.. Test Loss: 0.170.. Test Accuracy: 0.996\n",
|
||
"Epoch: 87/150.. Training Loss: 0.000.. Test Loss: 0.171.. Test Accuracy: 0.996\n",
|
||
"Epoch: 88/150.. Training Loss: 0.000.. Test Loss: 0.171.. Test Accuracy: 0.996\n",
|
||
"Epoch: 89/150.. Training Loss: 0.000.. Test Loss: 0.172.. Test Accuracy: 0.996\n",
|
||
"Epoch: 90/150.. Training Loss: 0.000.. Test Loss: 0.172.. Test Accuracy: 0.996\n",
|
||
"Epoch: 91/150.. Training Loss: 0.000.. Test Loss: 0.173.. Test Accuracy: 0.996\n",
|
||
"Epoch: 92/150.. Training Loss: 0.000.. Test Loss: 0.173.. Test Accuracy: 0.996\n",
|
||
"Epoch: 93/150.. Training Loss: 0.000.. Test Loss: 0.174.. Test Accuracy: 0.996\n",
|
||
"Epoch: 94/150.. Training Loss: 0.000.. Test Loss: 0.174.. Test Accuracy: 0.996\n",
|
||
"Epoch: 95/150.. Training Loss: 0.000.. Test Loss: 0.174.. Test Accuracy: 0.996\n",
|
||
"Epoch: 96/150.. Training Loss: 0.000.. Test Loss: 0.175.. Test Accuracy: 0.996\n",
|
||
"Epoch: 97/150.. Training Loss: 0.000.. Test Loss: 0.175.. Test Accuracy: 0.996\n",
|
||
"Epoch: 98/150.. Training Loss: 0.000.. Test Loss: 0.175.. Test Accuracy: 0.996\n",
|
||
"Epoch: 99/150.. Training Loss: 0.000.. Test Loss: 0.175.. Test Accuracy: 0.996\n",
|
||
"Epoch: 100/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 101/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 102/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 103/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 104/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 105/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 106/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 107/150.. Training Loss: 0.000.. Test Loss: 0.176.. Test Accuracy: 0.996\n",
|
||
"Epoch: 108/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 109/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 110/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 111/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 112/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 113/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 114/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 115/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 116/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 117/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 118/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 119/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 120/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 121/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 122/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 123/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 124/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 125/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 126/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 127/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 128/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 129/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 130/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 131/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 132/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 133/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 134/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 135/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 136/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 137/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 138/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 139/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 140/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 141/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 142/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 143/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 144/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 145/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 146/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 147/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 148/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 149/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"Epoch: 150/150.. Training Loss: 0.000.. Test Loss: 0.177.. Test Accuracy: 0.996\n",
|
||
"\n",
|
||
"Mixed precision:\n",
|
||
"Total execution time = 7.507 sec\n",
|
||
"Max memory used by tensors = 2737311232 bytes\n"
|
||
]
|
||
}
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Mixed precision model\n",
"# Evaluate the mixed-precision model on the held-out test set.\n",
"model.eval()\n",
"with torch.no_grad():  # inference only — no gradient tracking needed\n",
"    output = model(x_test)\n",
"ps = torch.exp(output)  # LogSoftmax outputs -> probabilities\n",
"top_p, top_class = ps.topk(1, dim=1)\n",
"# Flatten to a 1-D label vector; sklearn metrics expect 1-D, not (n, 1) columns\n",
"y_pred = top_class.cpu().detach().numpy().ravel()\n",
"y_true = [int(d) for d in y_test]\n",
"print('F-score: ', f1_score(y_true, y_pred, average='macro'))\n",
"print('Accuracy: ', accuracy_score(y_true, y_pred))"
|
||
],
|
||
"metadata": {
|
||
"id": "UCNKR4tE1Ign",
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"outputId": "6e981ca3-f1ad-4b1d-f601-64a50432f425"
|
||
},
|
||
"execution_count": 252,
|
||
"outputs": [
|
||
{
|
||
"output_type": "stream",
|
||
"name": "stdout",
|
||
"text": [
|
||
"F-score: 0.9845942906441127\n",
|
||
"Accuracy: 0.999\n"
|
||
]
|
||
}
|
||
]
|
||
}
|
||
]
|
||
} |