727 lines
22 KiB
Plaintext
727 lines
22 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "74100403-147c-42cd-8285-e30778c0fb66",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import torch\n",
|
||
"import csv\n",
|
||
"import lzma\n",
|
||
"import gensim.downloader\n",
|
||
"from nltk import word_tokenize"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "cbe60d7b-850e-4838-b4ce-672f13bf2bb2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "bf211ece-e27a-4119-a1b9-9a9a610cfb46",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "def predict_year(x, path_out, model):\n",
  "    \"\"\"Predict with ``model`` on ``x`` and write the predictions to ``path_out``.\n",
  "\n",
  "    Every element returned by ``model.predict(x)`` is converted with ``str`` and\n",
  "    written on its own line; an existing file at ``path_out`` is overwritten.\n",
  "    NOTE(review): despite its name nothing here is year-specific, and this\n",
  "    helper is not called anywhere in this notebook.\n",
  "    \"\"\"\n",
  "    predictions = model.predict(x)\n",
  "    with open(path_out, 'wt') as out_file:\n",
  "        for prediction in predictions:\n",
  "            # print() writes str(prediction) followed by a newline\n",
  "            print(prediction, file=out_file)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "1ec57d97-a852-490e-8da4-d1e4c9676cd6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "def read_file(filename):\n",
  "    \"\"\"Return the first tab-separated field of every line in ``filename``.\n",
  "\n",
  "    The file is decoded as UTF-8 and each field is stripped of surrounding\n",
  "    whitespace. NOTE(review): not called anywhere in this notebook.\n",
  "    \"\"\"\n",
  "    with open(filename, 'r', encoding=\"utf-8\") as handle:\n",
  "        return [row.split(\"\\t\")[0].strip() for row in handle]"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "86fbfb79-76e7-49f5-b722-2827f93cb03f",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>0</th>\n",
|
||
" <th>1</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>have you had an medical issues recently?</td>\n",
|
||
" <td>1335187994</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>It's supposedly aluminum, barium, and strontiu...</td>\n",
|
||
" <td>1346187161</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>Nobel prizes don't make you rich.</td>\n",
|
||
" <td>1337160218</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>I came for the article, I stayed for the doctor.</td>\n",
|
||
" <td>1277674344</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>you resorted to insults AND got owned directly...</td>\n",
|
||
" <td>1348538535</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199995</th>\n",
|
||
" <td>It's really sad. My sister used to believe tha...</td>\n",
|
||
" <td>1334111989</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199996</th>\n",
|
||
" <td>I don't mean it in a dickish way, I'm being se...</td>\n",
|
||
" <td>1322700456</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199997</th>\n",
|
||
" <td>Fair enough, I stand corrected.</td>\n",
|
||
" <td>1354646212</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199998</th>\n",
|
||
" <td>Right. Scientists tend to think and conclude l...</td>\n",
|
||
" <td>1348777201</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199999</th>\n",
|
||
" <td>Because they are illiterate</td>\n",
|
||
" <td>1249579722</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>200000 rows × 2 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" 0 1\n",
|
||
"0 have you had an medical issues recently? 1335187994\n",
|
||
"1 It's supposedly aluminum, barium, and strontiu... 1346187161\n",
|
||
"2 Nobel prizes don't make you rich. 1337160218\n",
|
||
"3 I came for the article, I stayed for the doctor. 1277674344\n",
|
||
"4 you resorted to insults AND got owned directly... 1348538535\n",
|
||
"... ... ...\n",
|
||
"199995 It's really sad. My sister used to believe tha... 1334111989\n",
|
||
"199996 I don't mean it in a dickish way, I'm being se... 1322700456\n",
|
||
"199997 Fair enough, I stand corrected. 1354646212\n",
|
||
"199998 Right. Scientists tend to think and conclude l... 1348777201\n",
|
||
"199999 Because they are illiterate 1249579722\n",
|
||
"\n",
|
||
"[200000 rows x 2 columns]"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
  "# Quoting is disabled (csv.QUOTE_NONE == 3) so quote characters are kept as\n",
  "# text and every line is exactly one row; keep_default_na=False keeps comments\n",
  "# such as 'NA' or 'null' as text instead of NaN (NaN would break tokenisation).\n",
  "# Only the first 200000 rows are used (must match the y_train cell), so nrows\n",
  "# avoids parsing the rest of the file.\n",
  "x_train = pd.read_table('train/in.tsv', sep='\\t', header=None,\n",
  "                        quoting=csv.QUOTE_NONE, keep_default_na=False,\n",
  "                        nrows=200000)\n",
  "x_train"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "8960c975-f756-4e36-a1ce-e9fd5fdf8fe3",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>0</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199995</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199996</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199997</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199998</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>199999</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>200000 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" 0\n",
|
||
"0 1\n",
|
||
"1 0\n",
|
||
"2 0\n",
|
||
"3 0\n",
|
||
"4 0\n",
|
||
"... ..\n",
|
||
"199995 0\n",
|
||
"199996 0\n",
|
||
"199997 1\n",
|
||
"199998 1\n",
|
||
"199999 0\n",
|
||
"\n",
|
||
"[200000 rows x 1 columns]"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
  "# Training labels (one integer per line), truncated to the same first\n",
  "# 200000 rows as x_train; presumably aligned line-by-line with train/in.tsv.\n",
  "y_train = pd.read_csv('train/expected.tsv', sep='\\t', header=None, encoding='utf8')\n",
  "y_train = y_train.iloc[:200000]\n",
  "y_train"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "6b27e6ce-e9fd-41a1-aacf-53a5fde0a7c1",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>0</th>\n",
|
||
" <th>1</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>In which case, tell them I'm in work, or dead,...</td>\n",
|
||
" <td>1328302967</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>Put me down as another for Mysterious Universe...</td>\n",
|
||
" <td>1347836881</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>The military of any country would never admit ...</td>\n",
|
||
" <td>1331905826</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>An example would have been more productive tha...</td>\n",
|
||
" <td>1315584834</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>sorry, but the authors of this article admit t...</td>\n",
|
||
" <td>1347389166</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5267</th>\n",
|
||
" <td>Your fault for going at all. That's how we get...</td>\n",
|
||
" <td>1308176634</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5268</th>\n",
|
||
" <td>EVP....that's a shot in the GH drinking game.</td>\n",
|
||
" <td>1354408646</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5269</th>\n",
|
||
" <td>i think a good hard massage is good for you. t...</td>\n",
|
||
" <td>1305726318</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5270</th>\n",
|
||
" <td>Interesting theory. Makes my imagination run w...</td>\n",
|
||
" <td>1339839088</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5271</th>\n",
|
||
" <td>Tampering of candy? More like cooking somethin...</td>\n",
|
||
" <td>1320262659</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>5272 rows × 2 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" 0 1\n",
|
||
"0 In which case, tell them I'm in work, or dead,... 1328302967\n",
|
||
"1 Put me down as another for Mysterious Universe... 1347836881\n",
|
||
"2 The military of any country would never admit ... 1331905826\n",
|
||
"3 An example would have been more productive tha... 1315584834\n",
|
||
"4 sorry, but the authors of this article admit t... 1347389166\n",
|
||
"... ... ...\n",
|
||
"5267 Your fault for going at all. That's how we get... 1308176634\n",
|
||
"5268 EVP....that's a shot in the GH drinking game. 1354408646\n",
|
||
"5269 i think a good hard massage is good for you. t... 1305726318\n",
|
||
"5270 Interesting theory. Makes my imagination run w... 1339839088\n",
|
||
"5271 Tampering of candy? More like cooking somethin... 1320262659\n",
|
||
"\n",
|
||
"[5272 rows x 2 columns]"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
  "# Parse dev inputs exactly like train/in.tsv. With quoting disabled\n",
  "# (csv.QUOTE_NONE) every line becomes one row; the default would treat a field\n",
  "# starting with '\"' as quoted text and could merge several lines into one row,\n",
  "# leaving fewer predictions than input lines. keep_default_na=False keeps\n",
  "# comments such as 'NA' or 'null' as text instead of NaN (NaN would break\n",
  "# tokenisation later).\n",
  "x_dev = pd.read_csv('dev-0/in.tsv', sep='\\t', header=None, encoding='utf8',\n",
  "                    quoting=csv.QUOTE_NONE, keep_default_na=False)\n",
  "x_dev"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "99ae526d-9b7c-493f-be4f-f95b1c8f4b81",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>0</th>\n",
|
||
" <th>1</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>Gentleman, I believe we can agree that this is...</td>\n",
|
||
" <td>1304170330</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>The problem is that it will just turn it r/nos...</td>\n",
|
||
" <td>1353763204</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>Well, according to some Christian apologists, ...</td>\n",
|
||
" <td>1336314173</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>Don't know if this is what you are looking for...</td>\n",
|
||
" <td>1348860314</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>I respect what you're saying completely. I jus...</td>\n",
|
||
" <td>1341285952</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5147</th>\n",
|
||
" <td>GAMBIT</td>\n",
|
||
" <td>1326441107</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5148</th>\n",
|
||
" <td>&gt;Joe Rogan is no snake oil salesman.\\n\\nHe ...</td>\n",
|
||
" <td>1319464245</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5149</th>\n",
|
||
" <td>Reading further, Sagan does seem to agree with...</td>\n",
|
||
" <td>1322126150</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5150</th>\n",
|
||
" <td>Notice that they never invoke god, or any othe...</td>\n",
|
||
" <td>1307679295</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5151</th>\n",
|
||
" <td>They might co-ordinate an anniversary attack o...</td>\n",
|
||
" <td>1342409261</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>5152 rows × 2 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" 0 1\n",
|
||
"0 Gentleman, I believe we can agree that this is... 1304170330\n",
|
||
"1 The problem is that it will just turn it r/nos... 1353763204\n",
|
||
"2 Well, according to some Christian apologists, ... 1336314173\n",
|
||
"3 Don't know if this is what you are looking for... 1348860314\n",
|
||
"4 I respect what you're saying completely. I jus... 1341285952\n",
|
||
"... ... ...\n",
|
||
"5147 GAMBIT 1326441107\n",
|
||
"5148 >Joe Rogan is no snake oil salesman.\\n\\nHe ... 1319464245\n",
|
||
"5149 Reading further, Sagan does seem to agree with... 1322126150\n",
|
||
"5150 Notice that they never invoke god, or any othe... 1307679295\n",
|
||
"5151 They might co-ordinate an anniversary attack o... 1342409261\n",
|
||
"\n",
|
||
"[5152 rows x 2 columns]"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
  "# Parse test inputs exactly like train/in.tsv. With quoting disabled\n",
  "# (csv.QUOTE_NONE) every line becomes one row; the default would treat a field\n",
  "# starting with '\"' as quoted text and could merge several lines into one row,\n",
  "# leaving fewer predictions than input lines. keep_default_na=False keeps\n",
  "# comments such as 'NA' or 'null' as text instead of NaN (NaN would break\n",
  "# tokenisation later).\n",
  "x_test = pd.read_csv('test-A/in.tsv', sep='\\t', header=None, encoding='utf8',\n",
  "                     quoting=csv.QUOTE_NONE, keep_default_na=False)\n",
  "x_test"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "dba17668-971f-47f8-99ce-fc840b5cb74a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "class NeuralNetworkModel(torch.nn.Module):\n",
  "    \"\"\"Two-layer feed-forward binary classifier over fixed-size feature vectors.\n",
  "\n",
  "    Architecture: input_dim -> hidden_dim (ReLU) -> 1 (sigmoid). The defaults\n",
  "    (300) match the 300-dimensional word2vec-google-news-300 document vectors\n",
  "    used in this notebook. ``forward`` returns probabilities in [0, 1], which is\n",
  "    what ``torch.nn.BCELoss`` expects.\n",
  "\n",
  "    Args:\n",
  "        input_dim: size of each input feature vector.\n",
  "        hidden_dim: width of the hidden layer.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, input_dim=300, hidden_dim=300):\n",
  "        super().__init__()\n",
  "        # Attribute names are kept as l01/l02 so existing state_dicts still load.\n",
  "        self.l01 = torch.nn.Linear(input_dim, hidden_dim)\n",
  "        self.l02 = torch.nn.Linear(hidden_dim, 1)\n",
  "\n",
  "    def forward(self, x):\n",
  "        \"\"\"Map a (..., input_dim) float tensor to (..., 1) probabilities.\"\"\"\n",
  "        hidden = torch.relu(self.l01(x))\n",
  "        return torch.sigmoid(self.l02(hidden))\n"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "1a275c1d-75bc-4290-9332-56396d16a0f2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "def _lowercase_tokens(frame):\n",
  "    \"\"\"Lower-case the text in column 0 of ``frame`` and NLTK-tokenize each entry.\"\"\"\n",
  "    return [word_tokenize(text) for text in frame[0].str.lower()]\n",
  "\n",
  "\n",
  "# NOTE: this cell rebinds x_* and y_train, so it is not re-runnable without\n",
  "# first re-running the loading cells.\n",
  "y_train = y_train[0]  # keep only the label column, as a Series\n",
  "x_train = _lowercase_tokens(x_train)\n",
  "x_dev = _lowercase_tokens(x_dev)\n",
  "x_test = _lowercase_tokens(x_test)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"id": "031a3670-3be7-4146-97b4-0dacd4f9ae58",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "from gensim.test.utils import common_texts\n",
  "from gensim.models import Word2Vec\n",
  "\n",
  "word2vec = gensim.downloader.load('word2vec-google-news-300')\n",
  "\n",
  "\n",
  "def _document_vectors(token_lists, vectors, dim=300):\n",
  "    \"\"\"Embed each token list as the mean of its in-vocabulary word vectors.\n",
  "\n",
  "    Documents without any in-vocabulary token get a zero vector. Returns one\n",
  "    (len(token_lists), dim) float32 array, so batches can be turned into tensors\n",
  "    without torch's slow list-of-ndarrays conversion.\n",
  "    \"\"\"\n",
  "    matrix = np.zeros((len(token_lists), dim), dtype=np.float32)\n",
  "    for row, tokens in enumerate(token_lists):\n",
  "        known = [vectors[word] for word in tokens if word in vectors]\n",
  "        if known:\n",
  "            matrix[row] = np.mean(known, axis=0)\n",
  "    return matrix\n",
  "\n",
  "\n",
  "x_train = _document_vectors(x_train, word2vec)\n",
  "x_dev = _document_vectors(x_dev, word2vec)\n",
  "x_test = _document_vectors(x_test, word2vec)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "b7defd18-e281-4cf6-9941-cee560749677",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"C:\\Users\\korne\\AppData\\Local\\Temp\\ipykernel_22024\\3484013121.py:10: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:210.)\n",
|
||
" X = torch.tensor(X)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
  "SEED = 42\n",
  "EPOCHS = 5\n",
  "BATCH_SIZE = 5\n",
  "LEARNING_RATE = 0.01\n",
  "\n",
  "torch.manual_seed(SEED)  # reproducible weight initialisation\n",
  "model = NeuralNetworkModel()\n",
  "criterion = torch.nn.BCELoss()  # the model already applies a sigmoid\n",
  "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
  "\n",
  "# Convert features and labels to tensors once rather than once per batch;\n",
  "# np.asarray turns a list of per-document vectors into a single 2-D array,\n",
  "# avoiding torch's slow list-of-ndarrays conversion.\n",
  "features = torch.as_tensor(np.asarray(x_train), dtype=torch.float32)\n",
  "labels = torch.as_tensor(y_train.to_numpy(dtype=np.float32)).reshape(-1, 1)\n",
  "\n",
  "# The number of epochs is its own hyper-parameter, independent of batch size.\n",
  "for epoch in range(EPOCHS):\n",
  "    model.train()\n",
  "    for start in range(0, len(labels), BATCH_SIZE):\n",
  "        batch_x = features[start:start + BATCH_SIZE]\n",
  "        batch_y = labels[start:start + BATCH_SIZE]\n",
  "        optimizer.zero_grad()\n",
  "        loss = criterion(model(batch_x), batch_y)\n",
  "        loss.backward()\n",
  "        optimizer.step()"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "92c69ddd-fe58-477f-b2c2-06324a983bcc",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "THRESHOLD = 0.5\n",
  "\n",
  "\n",
  "def _predict_labels(net, features, batch_size=BATCH_SIZE, threshold=THRESHOLD):\n",
  "    \"\"\"Return a flat list of bool predictions (probability > threshold), one per row.\"\"\"\n",
  "    inputs = torch.as_tensor(np.asarray(features), dtype=torch.float32)\n",
  "    predictions = []\n",
  "    with torch.no_grad():\n",
  "        for start in range(0, len(inputs), batch_size):\n",
  "            probabilities = net(inputs[start:start + batch_size])\n",
  "            # reshape(-1) drops the trailing size-1 column so each document\n",
  "            # yields a scalar label rather than a one-element list.\n",
  "            predictions += (probabilities > threshold).reshape(-1).tolist()\n",
  "    return predictions\n",
  "\n",
  "\n",
  "model.eval()\n",
  "# Dev and test share one code path: same threshold, and each set's labels\n",
  "# come from its own model outputs.\n",
  "y_dev = _predict_labels(model, x_dev)\n",
  "y_test = _predict_labels(model, x_test)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "caff921c-d0ab-4fce-a17f-6610266b404d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "# ravel() guarantees 1-D arrays even when predictions are nested per row\n",
  "# (e.g. [[True], [False]]); otherwise each written line would be '[1]'\n",
  "# instead of '1'.\n",
  "y_dev = np.asarray(y_dev, dtype=np.int32).ravel()\n",
  "y_test = np.asarray(y_test, dtype=np.int32).ravel()"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "73076eb2-810f-4f85-aa3f-05ee884c413b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "# Write one predicted label per line.\n",
  "with open('./dev-0/out.tsv', 'wt') as out_file:\n",
  "    out_file.writelines(str(label) + '\\n' for label in y_dev)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "ddda251c-cafa-40f8-a020-48310a9f23b6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
  "# Write one predicted label per line.\n",
  "with open('./test-A/out.tsv', 'wt') as out_file:\n",
  "    out_file.writelines(str(label) + '\\n' for label in y_test)"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "5730562a-0200-4c8f-8f73-992fa2b36133",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[NbConvertApp] Converting notebook run.ipynb to script\n",
|
||
"[NbConvertApp] Writing 3816 bytes to run.py\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
  "# Export this notebook as a plain Python script (the recorded output shows run.py).\n",
  "# NOTE(review): this shell line is itself exported into run.py, where it needs\n",
  "# IPython to execute - confirm that is intended.\n",
  "!jupyter nbconvert --to script run.ipynb"
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "07a09298-204c-4905-90a8-5dcb87877368",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.12"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|