{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "V100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "LYTCs2MjhLuZ"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "unzqnLN9isoP",
        "outputId": "b44d1087-3600-4fc2-9998-cf6520e9e743"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd drive/MyDrive/moj7"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hRG7HFaFi6aV",
        "outputId": "c498eecc-d661-4842-8ae5-91819e38b7cd"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/moj7\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!ls"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T5XQ2uY5jH4U",
        "outputId": "1ad2d4a8-a575-4021-cbc0-3875f956f874"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "config.txt\t in-header.tsv\tout-header.tsv\t test-A\n",
            "dev-0\t\t model1.bin\tprocessed_train.txt train\n",
            "filename.pickle model2.bin\tsimplepredict.py train_new.txt\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import regex as re\n",
        "import csv\n",
        "\n",
        "def clean_text(text):\n",
        "    # Lowercase, remove hyphenated line breaks and turn the literal \\n markers into spaces\n",
        "    text = text.lower().replace('-\\\\n', '').replace('\\\\n', ' ')\n",
        "    # Expand common contraction suffixes before punctuation (and the apostrophes) is stripped\n",
        "    text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
        "    text = re.sub(r'\\p{P}', '', text)\n",
        "\n",
        "    return text"
      ],
      "metadata": {
        "id": "6_8pn-p3hO2a"
      },
      "execution_count": 5,
      "outputs": []
    },
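    {
      "cell_type": "code",
      "source": [
        "# Illustration only (not part of the original pipeline): a quick sanity check of\n",
        "# clean_text on a made-up sentence. The corpus encodes line breaks as literal \\n\n",
        "# sequences, which is what the sample below imitates.\n",
        "sample = \"He's gone to the mar-\\\\nket, and he won't\\\\nreturn today.\"\n",
        "print(clean_text(sample))\n",
        "# -> 'he is gone to the market and he won not return today'"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },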
    {
      "cell_type": "code",
      "source": [
        "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
        "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
        "\n",
        "train_data = train_data[[6, 7]]\n",
        "train_data = pd.concat([train_data, train_labels], axis=1)\n",
        "\n",
        "# left context + expected gap word + right context, joined with spaces\n",
        "train_data['text'] = train_data[6] + ' ' + train_data[0] + ' ' + train_data[7]\n",
        "train_data = train_data[['text']]\n",
        "\n",
        "with open('processed_train.txt', 'w', encoding='utf-8') as file:\n",
        "    for _, row in train_data.iterrows():\n",
        "        text = clean_text(str(row['text']))\n",
        "        file.write(text + '\\n')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3WU8aYOghO4x",
        "outputId": "54b2531c-541d-4b8d-92f9-20bcd52d843f"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-6-c2ca5c6b11cc>:1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-6-c2ca5c6b11cc>:1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-6-c2ca5c6b11cc>:2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-6-c2ca5c6b11cc>:2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import itertools\n",
        "import lzma\n",
        "import numpy as np\n",
        "import regex as re\n",
        "import torch\n",
        "import pandas as pd\n",
        "from torch import nn\n",
        "from torch.utils.data import IterableDataset, DataLoader\n",
        "import csv\n",
        "from itertools import islice, chain\n",
        "from torchtext.vocab import build_vocab_from_iterator"
      ],
      "metadata": {
        "id": "tw9MDSzpisGN"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "M-aI-gI7hO7V"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "device = 'cuda'"
      ],
      "metadata": {
        "id": "tVHkGBzLhO9u"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
        "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
        "train_data = train_data[[6, 7]]\n",
        "train_data = pd.concat([train_data, train_labels], axis=1)\n",
        "# left context + expected gap word + right context, joined with spaces\n",
        "train_data['text'] = train_data[6] + ' ' + train_data[0] + ' ' + train_data[7]\n",
        "train_data = train_data[['text']]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ph3ibZmlhPAI",
        "outputId": "c4524bf5-d7f9-4c7f-ed89-7f6451725ea2"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-9-28a7685109f8>:1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-9-28a7685109f8>:1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-9-28a7685109f8>:2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
            "<ipython-input-9-28a7685109f8>:2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
            "\n",
            "\n",
            "  train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_data"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424
        },
        "id": "uASpVNQXhPC1",
        "outputId": "45126fc2-5ff5-4be3-f114-c5fa7da9189c"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "                                                     text\n",
              "0       came fiom the last place to this\\nplace, and t...\n",
              "1       MB. BOOT'S POLITICAL OBEED\\nAttempt to imagine...\n",
              "2       \"Thera were in 1771 only aeventy-nine\\n*ub*erl...\n",
              "3       A gixnl man y nitereRtiiiv dii-clos-\\nur«s reg...\n",
              "4       Tin: 188UB TV THF BBABBT QABJE\\nMr. Schiffs *t...\n",
              "...                                                   ...\n",
              "432017  Sam Clendenin bad a fancy for Ui«\\nscience of ...\n",
              "432018  Wita.htt halting the party ware dilven to the ...\n",
              "432019  It was the last thing that either of\\nthem exp...\n",
              "432020  settlement with the department.\\nIt is also sh...\n",
              "432021  Flour quotations—low extras at 1 R0®2 50;\\ncit...\n",
              "\n",
              "[432022 rows x 1 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "with open('train_new.txt', 'w', encoding='utf-8') as file:\n",
        "    for _, row in train_data.iterrows():\n",
        "        text = clean_text(str(row['text']))\n",
        "        file.write(text + '\\n')\n",
        "\n"
      ],
      "metadata": {
        "id": "_28Jf3EyhPFu"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SimpleTrigramNeuralLanguageModel(nn.Module):\n",
        "    def __init__(self, vocabulary_size, embedding_size, hidden_size):\n",
        "        super(SimpleTrigramNeuralLanguageModel, self).__init__()\n",
        "        # The two context tokens are combined into one index by summing their ids\n",
        "        # (see iterator_look a few cells below), so the index can reach\n",
        "        # 2 * vocabulary_size - 2; the embedding and output layers are sized accordingly.\n",
        "        self.embedding = nn.Embedding(vocabulary_size * 2, embedding_size)\n",
        "        self.linear1 = nn.Linear(embedding_size, hidden_size)\n",
        "        self.linear2 = nn.Linear(hidden_size, vocabulary_size * 2)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.embedding(x)\n",
        "        x = self.linear1(x)\n",
        "        x = self.linear2(x)\n",
        "        # softmax here, so the training loop applies torch.log before NLLLoss\n",
        "        x = torch.softmax(x, dim=1)\n",
        "        return x"
      ],
      "metadata": {
        "id": "HdaLacIRhPIS"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "vocab_size = 38000\n",
        "embed_size = 300\n",
        "hidden_size = 256"
      ],
      "metadata": {
        "id": "k-qcQuVYhPK7"
      },
      "execution_count": 13,
      "outputs": []
    },
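    {
      "cell_type": "code",
      "source": [
        "# Illustration only (not part of the original pipeline): a dummy CPU forward pass\n",
        "# to show the tensor shapes the model expects and returns with these hyperparameters.\n",
        "_demo_model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size)\n",
        "_demo_batch = torch.randint(0, vocab_size * 2, (4,))  # 4 summed context indices\n",
        "with torch.no_grad():\n",
        "    _demo_out = _demo_model(_demo_batch)\n",
        "print(_demo_batch.shape, _demo_out.shape)  # torch.Size([4]) torch.Size([4, 76000])\n",
        "print(float(_demo_out[0].sum()))  # each row is a probability distribution, so this is ~1.0"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },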
    {
      "cell_type": "code",
      "source": [
        "def words_line(line):\n",
        "    line = line.rstrip()\n",
        "    yield '<s>'\n",
        "    for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
        "        yield m.group(0).lower()\n",
        "    yield '</s>'\n",
        "\n",
        "def file_words(file_name):\n",
        "    with open(file_name, 'r', encoding='utf-8') as fh:\n",
        "        for line in fh:\n",
        "            yield words_line(line)"
      ],
      "metadata": {
        "id": "w9yhw6n0hPNV"
      },
      "execution_count": 14,
      "outputs": []
    },
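    {
      "cell_type": "code",
      "source": [
        "# Illustration only: how words_line tokenizes a single (made-up) cleaned line,\n",
        "# wrapping it in the <s> ... </s> sentence markers used when building the vocabulary.\n",
        "print(list(words_line('general motors sold 25000 cars')))\n",
        "# -> ['<s>', 'general', 'motors', 'sold', '25000', 'cars', '</s>']"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },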
    {
      "cell_type": "code",
      "source": [
        "def iterator_look(gen):\n",
        "    # Turns a stream of token ids into (context, target) pairs, where the context is\n",
        "    # the sum of the two preceding ids (a single index into the 2 * vocab embedding).\n",
        "    first_prev = None\n",
        "    sec_prev = None\n",
        "    for item in gen:\n",
        "        if first_prev and sec_prev:\n",
        "            yield (sec_prev + first_prev, item)\n",
        "        sec_prev = first_prev\n",
        "        first_prev = item"
      ],
      "metadata": {
        "id": "suwoA5QFhPP9"
      },
      "execution_count": 15,
      "outputs": []
    },
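    {
      "cell_type": "code",
      "source": [
        "# Illustration only: pairs produced from a short, made-up sequence of token ids.\n",
        "# Ids equal to 0 (<unk>) never appear in a context, because of the truthiness\n",
        "# check `if first_prev and sec_prev` above.\n",
        "print(list(iterator_look([5, 7, 2, 9])))\n",
        "# -> [(12, 2), (9, 9)]"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },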
    {
      "cell_type": "code",
      "source": [
        "class Trigrams(IterableDataset):\n",
        "    def __init__(self, text_file, vocabulary_size):\n",
        "        self.vocab = build_vocab_from_iterator(\n",
        "            file_words(text_file),\n",
        "            max_tokens = vocabulary_size,\n",
        "            specials = ['<unk>']\n",
        "        )\n",
        "        self.vocab.set_default_index(self.vocab['<unk>'])\n",
        "        self.vocabulary_size = vocabulary_size\n",
        "        self.text_file = text_file\n",
        "\n",
        "    def __iter__(self):\n",
        "        return iterator_look((self.vocab[t] for t in chain.from_iterable(file_words(self.text_file))))"
      ],
      "metadata": {
        "id": "9ZZllfdxhPSd"
      },
      "execution_count": 16,
      "outputs": []
    },
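    {
      "cell_type": "code",
      "source": [
        "# Illustration only: the same vocabulary mechanism on a tiny, made-up corpus, so the\n",
        "# full 38k-token vocabulary does not have to be rebuilt just to see the API.\n",
        "_toy_vocab = build_vocab_from_iterator(\n",
        "    [['<s>', 'the', 'cat', 'sat', '</s>'], ['<s>', 'the', 'dog', 'sat', '</s>']],\n",
        "    max_tokens=10,\n",
        "    specials=['<unk>']\n",
        ")\n",
        "_toy_vocab.set_default_index(_toy_vocab['<unk>'])\n",
        "print(_toy_vocab['the'], _toy_vocab['sat'], _toy_vocab['aardvark'])  # unseen words map to <unk> (0)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },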
    {
      "cell_type": "code",
      "source": [
        "def training(xx):\n",
        "    # xx is the learning rate (the call below passes 0.0001)\n",
        "    train_dataset_new = Trigrams('train_new.txt', vocab_size)\n",
        "    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=xx)\n",
        "    criterion = torch.nn.NLLLoss()\n",
        "    data = DataLoader(train_dataset_new, batch_size=800)\n",
        "    step = 0\n",
        "    for epoch in range(1):\n",
        "        model.train()\n",
        "        for x, y in data:\n",
        "            x = x.to(device)\n",
        "            y = y.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model(x)\n",
        "            # the model returns softmax probabilities, so take log before NLLLoss\n",
        "            loss = criterion(torch.log(outputs), y)\n",
        "            if step % 100 == 0:\n",
        "                print(step, loss)\n",
        "            step += 1\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "    torch.save(model.state_dict(), 'model2.bin')"
      ],
      "metadata": {
        "id": "QjZ9Rl7-kUYC"
      },
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "training(xx=0.0001)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HOSUqszakUac",
        "outputId": "ec9f6d23-3014-4787-e2d7-22520974a7df"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0 tensor(11.2670, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "100 tensor(8.0867, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "200 tensor(6.8976, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "300 tensor(6.6515, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "400 tensor(6.6224, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "500 tensor(6.7443, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "600 tensor(6.7064, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "700 tensor(6.8224, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "800 tensor(6.8516, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "900 tensor(6.6103, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1000 tensor(6.5455, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1100 tensor(6.8369, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1200 tensor(6.5587, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1300 tensor(6.2804, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1400 tensor(6.5476, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1500 tensor(6.7563, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1600 tensor(6.5324, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1700 tensor(6.6478, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1800 tensor(6.4025, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1900 tensor(6.4470, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2000 tensor(6.8199, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2100 tensor(6.2291, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2200 tensor(6.4627, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2300 tensor(6.5401, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2400 tensor(6.4382, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2500 tensor(6.4881, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2600 tensor(6.2683, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2700 tensor(6.5393, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2800 tensor(6.8077, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2900 tensor(6.6460, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3000 tensor(6.4482, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3100 tensor(6.6288, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3200 tensor(6.4752, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3300 tensor(6.3716, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3400 tensor(6.4713, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3500 tensor(6.4488, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3600 tensor(6.5300, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3700 tensor(6.3824, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3800 tensor(6.6311, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3900 tensor(6.3778, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4000 tensor(6.4160, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4100 tensor(6.5501, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4200 tensor(6.6891, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4300 tensor(6.4745, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4400 tensor(6.7940, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4500 tensor(6.2111, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4600 tensor(6.7691, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4700 tensor(6.2466, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4800 tensor(6.5852, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4900 tensor(6.1048, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5000 tensor(6.5077, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5100 tensor(6.6974, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5200 tensor(6.4872, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5300 tensor(6.4792, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5400 tensor(6.4319, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5500 tensor(6.4370, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5600 tensor(6.5948, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5700 tensor(6.5184, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5800 tensor(6.4193, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5900 tensor(6.4801, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6000 tensor(6.4735, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6100 tensor(6.4440, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6200 tensor(6.3385, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6300 tensor(6.2252, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6400 tensor(6.2866, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6500 tensor(6.8166, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6600 tensor(6.4074, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6700 tensor(6.6818, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6800 tensor(5.9832, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6900 tensor(6.1267, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7000 tensor(6.6872, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7100 tensor(6.4554, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7200 tensor(6.5397, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7300 tensor(6.3267, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7400 tensor(6.4830, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7500 tensor(6.5805, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7600 tensor(6.1212, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7700 tensor(6.2900, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7800 tensor(6.1379, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7900 tensor(6.1837, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8000 tensor(6.5634, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8100 tensor(6.5012, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8200 tensor(6.3135, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8300 tensor(6.6141, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8400 tensor(6.4679, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8500 tensor(6.2488, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8600 tensor(6.3222, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8700 tensor(6.4057, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8800 tensor(6.2209, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8900 tensor(6.6274, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9000 tensor(6.4992, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9100 tensor(6.5748, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9200 tensor(6.2457, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9300 tensor(6.4364, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9400 tensor(6.4908, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9500 tensor(6.5462, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9600 tensor(6.3248, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9700 tensor(6.3758, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9800 tensor(6.1925, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9900 tensor(6.5854, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10000 tensor(6.5270, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10100 tensor(6.3718, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10200 tensor(6.6314, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10300 tensor(6.3025, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10400 tensor(6.2880, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10500 tensor(6.6817, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10600 tensor(6.4151, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10700 tensor(6.5276, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10800 tensor(6.6714, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10900 tensor(6.4049, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11000 tensor(6.2844, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11100 tensor(6.3522, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11200 tensor(6.5579, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11300 tensor(6.6415, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11400 tensor(6.2489, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11500 tensor(6.1745, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11600 tensor(6.5829, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11700 tensor(6.4514, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11800 tensor(6.4100, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "11900 tensor(6.2816, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12000 tensor(6.4974, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12100 tensor(6.3546, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12200 tensor(6.4354, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12300 tensor(6.2498, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12400 tensor(6.2456, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12500 tensor(6.2744, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12600 tensor(6.3540, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12700 tensor(6.4590, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12800 tensor(6.3227, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "12900 tensor(6.2072, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "13000 tensor(6.1667, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "13100 tensor(6.4865, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
          ]
        }
      ]
    },
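    {
      "cell_type": "code",
      "source": [
        "# Illustration only: the printed values are per-batch cross-entropy (NLL) losses,\n",
        "# so e**loss is the corresponding perplexity over the 2 * vocab_size output space.\n",
        "import math\n",
        "print(math.exp(11.2670))  # ~78000, roughly uniform over the ~76000 outputs at step 0\n",
        "print(math.exp(6.4865))   # ~655, perplexity around the end of the logged run"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },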
    {
      "cell_type": "code",
      "source": [
        "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
        "model.load_state_dict(torch.load('model2.bin'))\n",
        "model.eval()\n",
        "train_dataset_new = Trigrams('train_new.txt', vocab_size)\n",
        "\n",
        "def predict_words(words):\n",
        "    # combine the last two context words into the summed index used during training\n",
        "    context = ['<s>', '<s>'] + [w.lower() for w in words]\n",
        "    ixs = torch.tensor([train_dataset_new.vocab[context[-2]] + train_dataset_new.vocab[context[-1]]]).to(device)\n",
        "    predictions = model(ixs)\n",
        "    total_prob = 0.0\n",
        "    prediction = ''\n",
        "    # only the first vocab_size outputs map back to vocabulary tokens\n",
        "    top = torch.topk(predictions[0][:vocab_size], 30)\n",
        "    top_indices = top.indices.tolist()\n",
        "    top_probs = top.values.tolist()\n",
        "    top_words = train_dataset_new.vocab.lookup_tokens(top_indices)\n",
        "    top_preds = list(zip(top_words, top_indices, top_probs))\n",
        "\n",
        "    for word, _, prob in top_preds:\n",
        "        if word != '<unk>':\n",
        "            prediction += f'{word}:{prob} '\n",
        "            total_prob += prob\n",
        "    prediction += f':{1 - total_prob}'\n",
        "    return prediction"
      ],
      "metadata": {
        "id": "5K9YlprQkUc8"
      },
      "execution_count": null,
      "outputs": []
    },
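    {
      "cell_type": "code",
      "source": [
        "# Illustration only: predict_words on an arbitrary two-word prefix. The output uses\n",
        "# the challenge's expected format: 'word:prob word:prob ... :remaining_probability'.\n",
        "print(predict_words(['said', 'the']))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },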
    {
      "cell_type": "code",
      "source": [
        "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
        "model.load_state_dict(torch.load('model2.bin'))\n",
        "model.eval()"
      ],
      "metadata": {
        "id": "MgaRdbD8kUfd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with lzma.open(f'dev-0/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
        "    with open(f'dev-0/out-HIDDEN-SIZE={hidden_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
        "        for line in fid:\n",
        "            separated = line.split('\\t')\n",
        "            prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
        "            output_line = predict_words(prefix)\n",
        "            f.write(output_line + '\\n')"
      ],
      "metadata": {
        "id": "MoL-FV4rkgZB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with lzma.open(f'test-A/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
        "    with open(f'test-A/out-HIDDEN-SIZE={hidden_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
        "        for line in fid:\n",
        "            separated = line.split('\\t')\n",
        "            prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
        "            output_line = predict_words(prefix)\n",
        "            f.write(output_line + '\\n')"
      ],
      "metadata": {
        "id": "jHlOHc8Hkgbg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "torch.save(model.state_dict(), 'model2.bin')"
      ],
      "metadata": {
        "id": "CcX31HX1kgd4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "DhbNd_O8koQv"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}