{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "V100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard" }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "LYTCs2MjhLuZ" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "unzqnLN9isoP", "outputId": "b44d1087-3600-4fc2-9998-cf6520e9e743" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ] }, { "cell_type": "code", "source": [ "%cd drive/MyDrive/moj7" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hRG7HFaFi6aV", "outputId": "c498eecc-d661-4842-8ae5-91819e38b7cd" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "/content/drive/MyDrive/moj7\n" ] } ] }, { "cell_type": "code", "source": [ "!ls" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "T5XQ2uY5jH4U", "outputId": "1ad2d4a8-a575-4021-cbc0-3875f956f874" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "config.txt\t in-header.tsv\tout-header.tsv\t test-A\n", "dev-0\t\t model1.bin\tprocessed_train.txt train\n", "filename.pickle model2.bin\tsimplepredict.py train_new.txt\n" ] } ] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import regex as re\n", "import csv\n", "\n", "def clean_text(text):\n", " text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')\n", " text = re.sub(r'\\p{P}', '', text)\n", " text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n", "\n", " return text" ], "metadata": { "id": "6_8pn-p3hO2a" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", "\n", "train_data = train_data[[6, 7]]\n", "train_data = pd.concat([train_data, train_labels], axis=1)\n", "\n", "train_data['text'] = train_data[6] + train_data[0] + train_data[7]\n", "train_data = train_data[['text']]\n", "\n", "with open('processed_train.txt', 'w', encoding='utf-8') as file:\n", " for _, row in train_data.iterrows():\n", " text = clean_text(str(row['text']))\n", " file.write(text + '\\n')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3WU8aYOghO4x", "outputId": "54b2531c-541d-4b8d-92f9-20bcd52d843f" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n" ] } ] }, { "cell_type": "code", "source": [ "import itertools\n", "import lzma\n", "import numpy as np\n", "import regex as re\n", "import torch\n", "import pandas as pd\n", "from torch import nn\n", "from torch.utils.data import IterableDataset, DataLoader\n", "import csv\n", "from itertools import islice, chain\n", "from torchtext.vocab import build_vocab_from_iterator" ], "metadata": { "id": "tw9MDSzpisGN" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "M-aI-gI7hO7V" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "device='cuda'" ], "metadata": { "id": "tVHkGBzLhO9u" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", "train_data = train_data[[6, 7]]\n", "train_data = pd.concat([train_data, train_labels], axis=1)\n", "train_data['text'] = train_data[6] + train_data[0] + train_data[7]\n", "train_data = train_data[['text']]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ph3ibZmlhPAI", "outputId": "c4524bf5-d7f9-4c7f-ed89-7f6451725ea2" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":1: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":1: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":2: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n", ":2: FutureWarning: The warn_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n", "\n", "\n", " train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n" ] } ] }, { "cell_type": "code", "source": [ "train_data" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "uASpVNQXhPC1", "outputId": "45126fc2-5ff5-4be3-f114-c5fa7da9189c" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text\n", "0 came fiom the last place to this\\nplace, and t...\n", "1 MB. BOOT'S POLITICAL OBEED\\nAttempt to imagine...\n", "2 \"Thera were in 1771 only aeventy-nine\\n*ub*erl...\n", "3 A gixnl man y nitereRtiiiv dii-clos-\\nur«s reg...\n", "4 Tin: 188UB TV THF BBABBT QABJE\\nMr. Schiffs *t...\n", "... ...\n", "432017 Sam Clendenin bad a fancy for Ui«\\nscience of ...\n", "432018 Wita.htt halting the party ware dilven to the ...\n", "432019 It was the last thing that either of\\nthem exp...\n", "432020 settlement with the department.\\nIt is also sh...\n", "432021 Flour quotations—low extras at 1 R0®2 50;\\ncit...\n", "\n", "[432022 rows x 1 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
text
0came fiom the last place to this\\nplace, and t...
1MB. BOOT'S POLITICAL OBEED\\nAttempt to imagine...
2\"Thera were in 1771 only aeventy-nine\\n*ub*erl...
3A gixnl man y nitereRtiiiv dii-clos-\\nur«s reg...
4Tin: 188UB TV THF BBABBT QABJE\\nMr. Schiffs *t...
......
432017Sam Clendenin bad a fancy for Ui«\\nscience of ...
432018Wita.htt halting the party ware dilven to the ...
432019It was the last thing that either of\\nthem exp...
432020settlement with the department.\\nIt is also sh...
432021Flour quotations—low extras at 1 R0®2 50;\\ncit...
\n", "

432022 rows × 1 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "with open('train_new.txt', 'w', encoding='utf-8') as file:\n", " for _, row in train_data.iterrows():\n", " text = clean_text(str(row['text']))\n", " file.write(text + '\\n')\n", "\n" ], "metadata": { "id": "_28Jf3EyhPFu" }, "execution_count": 11, "outputs": [] }, { "cell_type": "code", "source": [ "class SimpleTrigramNeuralLanguageModel(nn.Module):\n", " def __init__(self, vocabulary_size, embedding_size, hidden_size):\n", " super(SimpleTrigramNeuralLanguageModel, self).__init__()\n", " self.embedding = nn.Embedding(vocabulary_size * 2, embedding_size)\n", " self.linear1 = nn.Linear(embedding_size, hidden_size)\n", " self.linear2 = nn.Linear(hidden_size, vocabulary_size * 2)\n", "\n", " def forward(self, x):\n", " x = self.embedding(x)\n", " x = self.linear1(x)\n", " x = self.linear2(x)\n", " x = torch.softmax(x, dim=1)\n", " return x" ], "metadata": { "id": "HdaLacIRhPIS" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "vocab_size = 38000\n", "embed_size = 300\n", "hidden_size = 256" ], "metadata": { "id": "k-qcQuVYhPK7" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [ "def words_line(line):\n", " line = line.rstrip()\n", " yield ''\n", " for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n", " yield m.group(0).lower()\n", " yield ''\n", "\n", "def file_words(file_name):\n", " with open(file_name, 'r', encoding='utf-8') as fh:\n", " for line in fh:\n", " yield words_line(line)" ], "metadata": { "id": "w9yhw6n0hPNV" }, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "source": [ "def iterator_look(gen):\n", " first_prev = None\n", " sec_prev = None\n", " for item in gen:\n", " if first_prev and sec_prev:\n", " yield (sec_prev+ first_prev, item)\n", " sec_prev = first_prev\n", " first_prev = item" ], "metadata": { "id": "suwoA5QFhPP9" }, "execution_count": 15, "outputs": [] }, { "cell_type": "code", "source": [ "class Trigrams(IterableDataset):\n", " def __init__(self, text_file, vocabulary_size):\n", " self.vocab = build_vocab_from_iterator(\n", " file_words(text_file),\n", " max_tokens = vocabulary_size,\n", " specials = ['']\n", " )\n", " self.vocab.set_default_index(self.vocab[''])\n", " self.vocabulary_size = vocabulary_size\n", " self.text_file = text_file\n", "\n", " def __iter__(self):\n", " return iterator_look((self.vocab[t] for t in chain.from_iterable(file_words(self.text_file))))" ], "metadata": { "id": "9ZZllfdxhPSd" }, "execution_count": 16, "outputs": [] }, { "cell_type": "code", "source": [ "def training(xx):\n", " train_dataset_new = Trigrams('train_new.txt', vocab_size)\n", " model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n", " optimizer = torch.optim.Adam(model.parameters())\n", " criterion = torch.nn.NLLLoss()\n", " data = DataLoader(train_dataset_new, batch_size=800)\n", " step = 0\n", " for epoch in range(1):\n", " model.train()\n", " for x, y in data:\n", " x = x.to(device)\n", " y = y.to(device)\n", " optimizer.zero_grad()\n", " outputs = model(x)\n", " loss = criterion(torch.log(outputs), y)\n", " if step % 100 == 0:\n", " print(step, loss)\n", " step += 1\n", " loss.backward()\n", " optimizer.step()\n", " torch.save(model.state_dict(), 'model2.bin')" ], "metadata": { "id": "QjZ9Rl7-kUYC" }, "execution_count": 17, "outputs": [] }, { "cell_type": "code", "source": [ "training(xx=0.0001)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HOSUqszakUac", "outputId": "ec9f6d23-3014-4787-e2d7-22520974a7df" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0 tensor(11.2670, device='cuda:0', grad_fn=)\n", "100 tensor(8.0867, device='cuda:0', grad_fn=)\n", "200 tensor(6.8976, device='cuda:0', grad_fn=)\n", "300 tensor(6.6515, device='cuda:0', grad_fn=)\n", "400 tensor(6.6224, device='cuda:0', grad_fn=)\n", "500 tensor(6.7443, device='cuda:0', grad_fn=)\n", "600 tensor(6.7064, device='cuda:0', grad_fn=)\n", "700 tensor(6.8224, device='cuda:0', grad_fn=)\n", "800 tensor(6.8516, device='cuda:0', grad_fn=)\n", "900 tensor(6.6103, device='cuda:0', grad_fn=)\n", "1000 tensor(6.5455, device='cuda:0', grad_fn=)\n", "1100 tensor(6.8369, device='cuda:0', grad_fn=)\n", "1200 tensor(6.5587, device='cuda:0', grad_fn=)\n", "1300 tensor(6.2804, device='cuda:0', grad_fn=)\n", "1400 tensor(6.5476, device='cuda:0', grad_fn=)\n", "1500 tensor(6.7563, device='cuda:0', grad_fn=)\n", "1600 tensor(6.5324, device='cuda:0', grad_fn=)\n", "1700 tensor(6.6478, device='cuda:0', grad_fn=)\n", "1800 tensor(6.4025, device='cuda:0', grad_fn=)\n", "1900 tensor(6.4470, device='cuda:0', grad_fn=)\n", "2000 tensor(6.8199, device='cuda:0', grad_fn=)\n", "2100 tensor(6.2291, device='cuda:0', grad_fn=)\n", "2200 tensor(6.4627, device='cuda:0', grad_fn=)\n", "2300 tensor(6.5401, device='cuda:0', grad_fn=)\n", "2400 tensor(6.4382, device='cuda:0', grad_fn=)\n", "2500 tensor(6.4881, device='cuda:0', grad_fn=)\n", "2600 tensor(6.2683, device='cuda:0', grad_fn=)\n", "2700 tensor(6.5393, device='cuda:0', grad_fn=)\n", "2800 tensor(6.8077, device='cuda:0', grad_fn=)\n", "2900 tensor(6.6460, device='cuda:0', grad_fn=)\n", "3000 tensor(6.4482, device='cuda:0', grad_fn=)\n", "3100 tensor(6.6288, device='cuda:0', grad_fn=)\n", "3200 tensor(6.4752, device='cuda:0', grad_fn=)\n", "3300 tensor(6.3716, device='cuda:0', grad_fn=)\n", "3400 tensor(6.4713, device='cuda:0', grad_fn=)\n", "3500 tensor(6.4488, device='cuda:0', grad_fn=)\n", "3600 tensor(6.5300, device='cuda:0', grad_fn=)\n", "3700 tensor(6.3824, device='cuda:0', grad_fn=)\n", "3800 tensor(6.6311, device='cuda:0', grad_fn=)\n", "3900 tensor(6.3778, device='cuda:0', grad_fn=)\n", "4000 tensor(6.4160, device='cuda:0', grad_fn=)\n", "4100 tensor(6.5501, device='cuda:0', grad_fn=)\n", "4200 tensor(6.6891, device='cuda:0', grad_fn=)\n", "4300 tensor(6.4745, device='cuda:0', grad_fn=)\n", "4400 tensor(6.7940, device='cuda:0', grad_fn=)\n", "4500 tensor(6.2111, device='cuda:0', grad_fn=)\n", "4600 tensor(6.7691, device='cuda:0', grad_fn=)\n", "4700 tensor(6.2466, device='cuda:0', grad_fn=)\n", "4800 tensor(6.5852, device='cuda:0', grad_fn=)\n", "4900 tensor(6.1048, device='cuda:0', grad_fn=)\n", "5000 tensor(6.5077, device='cuda:0', grad_fn=)\n", "5100 tensor(6.6974, device='cuda:0', grad_fn=)\n", "5200 tensor(6.4872, device='cuda:0', grad_fn=)\n", "5300 tensor(6.4792, device='cuda:0', grad_fn=)\n", "5400 tensor(6.4319, device='cuda:0', grad_fn=)\n", "5500 tensor(6.4370, device='cuda:0', grad_fn=)\n", "5600 tensor(6.5948, device='cuda:0', grad_fn=)\n", "5700 tensor(6.5184, device='cuda:0', grad_fn=)\n", "5800 tensor(6.4193, device='cuda:0', grad_fn=)\n", "5900 tensor(6.4801, device='cuda:0', grad_fn=)\n", "6000 tensor(6.4735, device='cuda:0', grad_fn=)\n", "6100 tensor(6.4440, device='cuda:0', grad_fn=)\n", "6200 tensor(6.3385, device='cuda:0', grad_fn=)\n", "6300 tensor(6.2252, device='cuda:0', grad_fn=)\n", "6400 tensor(6.2866, device='cuda:0', grad_fn=)\n", "6500 tensor(6.8166, device='cuda:0', grad_fn=)\n", "6600 tensor(6.4074, device='cuda:0', grad_fn=)\n", "6700 tensor(6.6818, device='cuda:0', grad_fn=)\n", "6800 tensor(5.9832, device='cuda:0', grad_fn=)\n", "6900 tensor(6.1267, device='cuda:0', grad_fn=)\n", "7000 tensor(6.6872, device='cuda:0', grad_fn=)\n", "7100 tensor(6.4554, device='cuda:0', grad_fn=)\n", "7200 tensor(6.5397, device='cuda:0', grad_fn=)\n", "7300 tensor(6.3267, device='cuda:0', grad_fn=)\n", "7400 tensor(6.4830, device='cuda:0', grad_fn=)\n", "7500 tensor(6.5805, device='cuda:0', grad_fn=)\n", "7600 tensor(6.1212, device='cuda:0', grad_fn=)\n", "7700 tensor(6.2900, device='cuda:0', grad_fn=)\n", "7800 tensor(6.1379, device='cuda:0', grad_fn=)\n", "7900 tensor(6.1837, device='cuda:0', grad_fn=)\n", "8000 tensor(6.5634, device='cuda:0', grad_fn=)\n", "8100 tensor(6.5012, device='cuda:0', grad_fn=)\n", "8200 tensor(6.3135, device='cuda:0', grad_fn=)\n", "8300 tensor(6.6141, device='cuda:0', grad_fn=)\n", "8400 tensor(6.4679, device='cuda:0', grad_fn=)\n", "8500 tensor(6.2488, device='cuda:0', grad_fn=)\n", "8600 tensor(6.3222, device='cuda:0', grad_fn=)\n", "8700 tensor(6.4057, device='cuda:0', grad_fn=)\n", "8800 tensor(6.2209, device='cuda:0', grad_fn=)\n", "8900 tensor(6.6274, device='cuda:0', grad_fn=)\n", "9000 tensor(6.4992, device='cuda:0', grad_fn=)\n", "9100 tensor(6.5748, device='cuda:0', grad_fn=)\n", "9200 tensor(6.2457, device='cuda:0', grad_fn=)\n", "9300 tensor(6.4364, device='cuda:0', grad_fn=)\n", "9400 tensor(6.4908, device='cuda:0', grad_fn=)\n", "9500 tensor(6.5462, device='cuda:0', grad_fn=)\n", "9600 tensor(6.3248, device='cuda:0', grad_fn=)\n", "9700 tensor(6.3758, device='cuda:0', grad_fn=)\n", "9800 tensor(6.1925, device='cuda:0', grad_fn=)\n", "9900 tensor(6.5854, device='cuda:0', grad_fn=)\n", "10000 tensor(6.5270, device='cuda:0', grad_fn=)\n", "10100 tensor(6.3718, device='cuda:0', grad_fn=)\n", "10200 tensor(6.6314, device='cuda:0', grad_fn=)\n", "10300 tensor(6.3025, device='cuda:0', grad_fn=)\n", "10400 tensor(6.2880, device='cuda:0', grad_fn=)\n", "10500 tensor(6.6817, device='cuda:0', grad_fn=)\n", "10600 tensor(6.4151, device='cuda:0', grad_fn=)\n", "10700 tensor(6.5276, device='cuda:0', grad_fn=)\n", "10800 tensor(6.6714, device='cuda:0', grad_fn=)\n", "10900 tensor(6.4049, device='cuda:0', grad_fn=)\n", "11000 tensor(6.2844, device='cuda:0', grad_fn=)\n", "11100 tensor(6.3522, device='cuda:0', grad_fn=)\n", "11200 tensor(6.5579, device='cuda:0', grad_fn=)\n", "11300 tensor(6.6415, device='cuda:0', grad_fn=)\n", "11400 tensor(6.2489, device='cuda:0', grad_fn=)\n", "11500 tensor(6.1745, device='cuda:0', grad_fn=)\n", "11600 tensor(6.5829, device='cuda:0', grad_fn=)\n", "11700 tensor(6.4514, device='cuda:0', grad_fn=)\n", "11800 tensor(6.4100, device='cuda:0', grad_fn=)\n", "11900 tensor(6.2816, device='cuda:0', grad_fn=)\n", "12000 tensor(6.4974, device='cuda:0', grad_fn=)\n", "12100 tensor(6.3546, device='cuda:0', grad_fn=)\n", "12200 tensor(6.4354, device='cuda:0', grad_fn=)\n", "12300 tensor(6.2498, device='cuda:0', grad_fn=)\n", "12400 tensor(6.2456, device='cuda:0', grad_fn=)\n", "12500 tensor(6.2744, device='cuda:0', grad_fn=)\n", "12600 tensor(6.3540, device='cuda:0', grad_fn=)\n", "12700 tensor(6.4590, device='cuda:0', grad_fn=)\n", "12800 tensor(6.3227, device='cuda:0', grad_fn=)\n", "12900 tensor(6.2072, device='cuda:0', grad_fn=)\n", "13000 tensor(6.1667, device='cuda:0', grad_fn=)\n", "13100 tensor(6.4865, device='cuda:0', grad_fn=)\n" ] } ] }, { "cell_type": "code", "source": [ "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n", "model.load_state_dict(torch.load('model2.bin'))\n", "model.eval()\n", "train_dataset_new = Trigrams('train_new.txt', vocab_size)\n", "\n", "def predict_words(words):\n", " ixs = torch.tensor(train_dataset_new.vocab.forward(['with'])).to(device)\n", " predictions = model(ixs)\n", " total_prob = 0.0\n", " prediction = ''\n", " top = torch.topk(predictions[0], 30)\n", " top_indices = top.indices.tolist()\n", " top_probs = top.values.tolist()\n", " top_words = train_dataset_new.vocab.lookup_tokens(top_indices)\n", " top_preds = list(zip(top_words, top_indices, top_probs))\n", "\n", " for word, _, prob in top_preds:\n", " if word != '':\n", " prediction += f'{word}:{prob} '\n", " total_prob += prob\n", " prediction += f':{1 - total_prob}'\n", " return prediction" ], "metadata": { "id": "5K9YlprQkUc8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n", "model.load_state_dict(torch.load('model2.bin'))\n", "model.eval() " ], "metadata": { "id": "MgaRdbD8kUfd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "with lzma.open(f'dev-0/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n", " with open(f'dev-0/out-HIDDEN-SIZE={hidden_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n", " for line in fid:\n", " separated = line.split('\\t')\n", " prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n", " output_line = predict_words(prefix)\n", " f.write(output_line + '\\n')" ], "metadata": { "id": "MoL-FV4rkgZB" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "with lzma.open(f'test-A/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n", " with open(f'test-A/out-HIDDEN-SIZE={hidden_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n", " for line in fid:\n", " separated = line.split('\\t')\n", " prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n", " output_line = predict_words(prefix)\n", " f.write(output_line + '\\n')" ], "metadata": { "id": "jHlOHc8Hkgbg" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "torch.save(model.state_dict(), 'model2.bin')" ], "metadata": { "id": "CcX31HX1kgd4" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "DhbNd_O8koQv" }, "execution_count": null, "outputs": [] } ] }