challenging-america-word-ga.../neural-networks/pytorch_n_gram.ipynb
2023-04-28 01:49:33 +02:00

646 lines
30 KiB
Plaintext
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KYySXV60UbL4",
"outputId": "bb7d4752-ccc2-48cf-f5d4-540ffbcc1243"
},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f3f2c178990>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"torch.manual_seed(1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wle6wRL0UbL9"
},
"source": [
"### Data loading"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "63t_aV_IUbL-"
},
"outputs": [],
"source": [
"import pickle\n",
"import lzma\n",
"import regex as re\n",
"\n",
"\n",
"def load_pickle(filename):\n",
" with open(filename, \"rb\") as f:\n",
" return pickle.load(f)\n",
"\n",
"\n",
"def save_pickle(d):\n",
" with open(\"vocabulary.pkl\", \"wb\") as f:\n",
" pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
"\n",
"def clean_document(document: str) -> str:\n",
" document = document.lower().replace(\"\", \"'\")\n",
" document = re.sub(r\"'s|[\\-­]\\\\n\", \"\", document)\n",
" document = re.sub(\n",
" r\"(\\\\+n|[{}\\[\\]”&:•¦()*0-9;\\\"«»$\\-><^,®¬¿?¡!#+. \\t\\n])+\", \" \", document\n",
" )\n",
" for to_find, substitute in zip(\n",
" [\"i'm\", \"won't\", \"n't\", \"'ll\"], [\"i am\", \"will not\", \" not\", \" will\"]\n",
" ):\n",
" document = document.replace(to_find, substitute)\n",
" return document\n",
"\n",
"\n",
"def get_words_from_line(line, clean_text=True):\n",
" if clean_text:\n",
" line = clean_document(line) # .rstrip()\n",
" else:\n",
" line = line.strip()\n",
" yield \"<s>\"\n",
" for m in re.finditer(r\"[\\p{L}0-9\\*]+|\\p{P}+\", line):\n",
" yield m.group(0).lower()\n",
" yield \"</s>\"\n",
"\n",
"\n",
"def get_word_lines_from_file(file_name, clean_text=True, only_text=False):\n",
" with lzma.open(file_name, \"r\") as fh:\n",
" for i, line in enumerate(fh):\n",
" if only_text:\n",
" line = \"\\t\".join(line.decode(\"utf-8\").split(\"\\t\")[:-2])\n",
" else:\n",
" line = line.decode(\"utf-8\")\n",
" if i % 10000 == 0:\n",
" print(i)\n",
" yield get_words_from_line(line, clean_text)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N7HGIM40UbL-"
},
"source": [
"### Dataclasses"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "50ezMyyNUbL_",
"outputId": "88af632f-fa88-43ea-ea7e-a709a2428c11"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"10000\n",
"20000\n",
"30000\n",
"40000\n",
"50000\n",
"60000\n",
"70000\n",
"80000\n",
"90000\n",
"100000\n",
"110000\n",
"120000\n",
"130000\n",
"140000\n",
"150000\n",
"160000\n",
"170000\n",
"180000\n",
"190000\n",
"200000\n",
"210000\n",
"220000\n",
"230000\n",
"240000\n",
"250000\n",
"260000\n",
"270000\n",
"280000\n",
"290000\n",
"300000\n",
"310000\n",
"320000\n",
"330000\n",
"340000\n",
"350000\n",
"360000\n",
"370000\n",
"380000\n",
"390000\n",
"400000\n",
"410000\n",
"420000\n",
"430000\n"
]
}
],
"source": [
"from torch.utils.data import IterableDataset\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"import itertools\n",
"\n",
"\n",
"VOCAB_SIZE = 20000\n",
"\n",
"\n",
"def look_ahead_iterator(gen):\n",
" prev = None\n",
" for item in gen:\n",
" if prev is not None:\n",
" yield (prev, item)\n",
" prev = item\n",
"\n",
"\n",
"class Bigrams(IterableDataset):\n",
" def __init__(\n",
" self, text_file, vocabulary_size, vocab=None, only_text=False, clean_text=True\n",
" ):\n",
" self.vocab = (\n",
" build_vocab_from_iterator(\n",
" get_word_lines_from_file(text_file, clean_text, only_text),\n",
" max_tokens=vocabulary_size,\n",
" specials=[\"<unk>\"],\n",
" )\n",
" if vocab is None\n",
" else vocab\n",
" )\n",
" self.vocab.set_default_index(self.vocab[\"<unk>\"])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.text_file = text_file\n",
" self.clean_text = clean_text\n",
" self.only_text = only_text\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (\n",
" self.vocab[t]\n",
" for t in itertools.chain.from_iterable(\n",
" get_word_lines_from_file(\n",
" self.text_file, self.clean_text, self.only_text\n",
" )\n",
" )\n",
" )\n",
" )\n",
"\n",
"\n",
"vocab = None # torch.load('./vocab.pth')\n",
"\n",
"train_dataset = Bigrams(\"/content/train/in.tsv.xz\", VOCAB_SIZE, vocab, clean_text=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "SCZ87kVxUbL_"
},
"outputs": [],
"source": [
"# torch.save(train_dataset.vocab, \"vocab.pth\")\n",
"# torch.save(train_dataset.vocab, \"vocab_only_text.pth\")\n",
"# torch.save(train_dataset.vocab, \"vocab_only_text_clean.pth\")\n",
"torch.save(train_dataset.vocab, \"vocab_2.pth\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bv_Adw_lUbMA"
},
"source": [
"### Model definition"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "oaoLz3jPUbMB"
},
"outputs": [],
"source": [
"class SimpleBigramNeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocabulary_size, embedding_size):\n",
" super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Embedding(vocabulary_size, embedding_size),\n",
" nn.Linear(embedding_size, vocabulary_size),\n",
" nn.Softmax(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "YBDf3nEvUbMC"
},
"outputs": [],
"source": [
"EMBED_SIZE = 100\n",
"\n",
"model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "CRe3STJUUbMC",
"outputId": "ab4b234a-49c7-4878-f8e4-e21072efa9ee"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" input = module(input)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.0674, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"100 tensor(8.4352, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"200 tensor(7.6662, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"300 tensor(7.0716, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"400 tensor(6.6710, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"500 tensor(6.4540, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"600 tensor(5.9974, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"700 tensor(5.7973, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"800 tensor(5.8026, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000\n",
"900 tensor(5.7118, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1000 tensor(5.7471, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1100 tensor(5.6865, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1200 tensor(5.4205, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1300 tensor(5.4954, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1400 tensor(5.5415, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1500 tensor(5.3322, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1600 tensor(5.4665, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1700 tensor(5.4710, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"20000\n",
"1800 tensor(5.3953, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"1900 tensor(5.4881, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2000 tensor(5.4915, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2100 tensor(5.3621, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2200 tensor(5.2872, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2300 tensor(5.2590, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2400 tensor(5.3661, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2500 tensor(5.3305, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"30000\n",
"2600 tensor(5.3789, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2700 tensor(5.3548, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2800 tensor(5.4579, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2900 tensor(5.2660, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3000 tensor(5.3253, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3100 tensor(5.4020, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3200 tensor(5.2962, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3300 tensor(5.2570, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3400 tensor(5.2317, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"40000\n",
"3500 tensor(5.2410, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3600 tensor(5.2404, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3700 tensor(5.1738, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3800 tensor(5.2654, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3900 tensor(5.2595, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4000 tensor(5.2850, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4100 tensor(5.2995, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4200 tensor(5.2581, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4300 tensor(5.3323, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"50000\n",
"4400 tensor(5.2498, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4500 tensor(5.2674, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4600 tensor(5.3033, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4700 tensor(5.2066, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4800 tensor(5.2302, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4900 tensor(5.2617, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(5.2306, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5100 tensor(5.2781, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"60000\n",
"5200 tensor(5.1833, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5300 tensor(5.2166, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5400 tensor(5.0845, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5500 tensor(5.2272, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5600 tensor(5.3175, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5700 tensor(5.2425, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5800 tensor(5.2449, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5900 tensor(5.3225, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6000 tensor(5.2786, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"70000\n",
"6100 tensor(5.1489, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6200 tensor(5.1793, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6300 tensor(5.2194, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6400 tensor(5.1708, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6500 tensor(5.1394, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6600 tensor(5.1280, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6700 tensor(5.0869, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6800 tensor(5.3255, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6900 tensor(5.3426, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"80000\n",
"7000 tensor(5.1176, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7100 tensor(5.1991, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7200 tensor(5.1227, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7300 tensor(5.1744, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7400 tensor(5.2222, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7500 tensor(5.2110, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7600 tensor(5.1553, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7700 tensor(5.3283, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"90000\n",
"7800 tensor(5.2544, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7900 tensor(5.1871, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8000 tensor(5.2215, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8100 tensor(5.1744, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8200 tensor(5.1087, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8300 tensor(5.1639, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8400 tensor(5.1604, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8500 tensor(5.1612, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8600 tensor(5.2307, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"100000\n",
"8700 tensor(5.1648, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8800 tensor(5.1066, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8900 tensor(5.2405, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"9000 tensor(5.2184, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"9100 tensor(5.2677, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"9200 tensor(5.0773, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-c690ed9ba7ad>\u001b[0m in \u001b[0;36m<cell line: 11>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"device = \"cuda\"\n",
"model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
"data = DataLoader(train_dataset, batch_size=5000)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"model.train()\n",
"step = 0\n",
"for x, y in data:\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 100 == 0:\n",
" print(step, loss)\n",
" step += 1\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"torch.save(model.state_dict(), \"model_2.bin\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kS9NHTGeom3y",
"outputId": "ce83640e-d2fe-41e6-cd5d-38a5b0410323"
},
"outputs": [
{
"data": {
"text/plain": [
"[('the', 2, 0.15899169445037842),\n",
" ('\\\\', 1, 0.10546761751174927),\n",
" ('he', 28, 0.06849857419729233),\n",
" ('it', 15, 0.05329886078834534),\n",
" ('i', 26, 0.0421920120716095),\n",
" ('they', 50, 0.03895237296819687),\n",
" ('a', 8, 0.03352600708603859),\n",
" ('<unk>', 0, 0.031062396243214607),\n",
" ('we', 61, 0.02323235757648945),\n",
" ('she', 104, 0.02003088779747486)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ixs = torch.tensor(train_dataset.vocab.forward([\"when\"])).to(device)\n",
"out = model(ixs)\n",
"top = torch.topk(out[0], 10)\n",
"top_indices = top.indices.tolist()\n",
"top_probs = top.values.tolist()\n",
"top_words = train_dataset.vocab.lookup_tokens(top_indices)\n",
"list(zip(top_words, top_indices, top_probs))\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UTYHTCJvC_Nm",
"outputId": "576288fa-fdad-4c21-d924-f613eaf33063"
},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = \"cuda\"\n",
"model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
"model.load_state_dict(torch.load(\"model1.bin\"))\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8WexjGIAxaE4",
"outputId": "52252b81-3b98-42d3-b137-472af00dbb26"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training on /content/dev-0/in.tsv.xz\n",
"\rProgress: 0.01%\rProgress: 0.02%\rProgress: 0.03%\rProgress: 0.04%\rProgress: 0.05%\rProgress: 0.06%\rProgress: 0.07%\rProgress: 0.08%\rProgress: 0.09%\rProgress: 0.10%\rProgress: 0.10%\rProgress: 0.11%\rProgress: 0.12%\rProgress: 0.13%\rProgress: 0.14%\rProgress: 0.15%\rProgress: 0.16%\rProgress: 0.17%\rProgress: 0.18%\rProgress: 0.19%\rProgress: 0.20%\rProgress: 0.21%\rProgress: 0.22%\rProgress: 0.23%\rProgress: 0.24%\rProgress: 0.25%\rProgress: 0.26%\rProgress: 0.27%\rProgress: 0.28%\rProgress: 0.29%\rProgress: 0.29%\rProgress: 0.30%\rProgress: 0.31%\rProgress: 0.32%\rProgress: 0.33%\rProgress: 0.34%\rProgress: 0.35%\rProgress: 0.36%\rProgress: 0.37%\rProgress: 0.38%\rProgress: 0.39%\rProgress: 0.40%\rProgress: 0.41%\rProgress: 0.42%\rProgress: 0.43%\rProgress: 0.44%\rProgress: 0.45%\rProgress: 0.46%\rProgress: 0.47%\rProgress: 0.48%\rProgress: 0.48%\rProgress: 0.49%\rProgress: 0.50%\rProgress: 0.51%\rProgress: 0.52%\rProgress: 0.53%\rProgress: 0.54%\rProgress: 0.55%\rProgress: 0.56%\rProgress: 0.57%\rProgress: 0.58%\rProgress: 0.59%\rProgress: 0.60%\rProgress: 0.61%\rProgress: 0.62%\rProgress: 0.63%\rProgress: 0.64%\rProgress: 0.65%\rProgress: 0.66%\rProgress: 0.67%\rProgress: 0.67%\rProgress: 0.68%\rProgress: 0.69%\rProgress: 0.70%\rProgress: 0.71%\rProgress: 0.72%\rProgress: 0.73%\rProgress: 0.74%\rProgress: 0.75%\rProgress: 0.76%\rProgress: 0.77%\rProgress: 0.78%\rProgress: 0.79%\rProgress: 0.80%\rProgress: 0.81%\rProgress: 0.82%\rProgress: 0.83%\rProgress: 0.84%\rProgress: 0.85%\rProgress: 0.86%\rProgress: 0.87%\rProgress: 0.87%\rProgress: 0.88%\rProgress: 0.89%\rProgress: 0.90%\rProgress: 0.91%\rProgress: 0.92%\rProgress: 0.93%\rProgress: 0.94%\rProgress: 0.95%\rProgress: 0.96%\rProgress: 0.97%\rProgress: 0.98%\rProgress: 0.99%\rProgress: 1.00%\rProgress: 1.01%\rProgress: 1.02%\rProgress: 1.03%\rProgress: 1.04%\rProgress: 1.05%\rProgress: 1.06%\rProgress: 1.06%\rProgress: 1.07%\rProgress: 1.08%\rProgress: 1.09%\rProgress: 1.10%\rProgress: 1.11%\rProgress: 1.12%\rProgress: 1.13%\rProgress: 1.14%\rProgress: 1.15%\rProgress: 1.16%\rProgress: 1.17%\rProgress: 1.18%\rProgress: 1.19%\rProgress: 1.20%\rProgress: 1.21%\rProgress: 1.22%\rProgress: 1.23%\rProgress: 1.24%\rProgress: 1.25%\rProgress: 1.25%\rProgress: 1.26%\rProgress: 1.27%\rProgress: 1.28%\rProgress: 1.29%\rProgress: 1.30%\rProgress: 1.31%\rProgress: 1.32%\rProgress: 1.33%\rProgress: 1.34%\rProgress: 1.35%\rProgress: 1.36%\rProgress: 1.37%\rProgress: 1.38%\rProgress: 1.39%\rProgress: 1.40%\rProgress: 1.41%\rProgress: 1.42%\rProgress: 1.43%\rProgress: 1.44%\rProgress: 1.45%\rProgress: 1.45%\rProgress: 1.46%\rProgress: 1.47%\rProgress: 1.48%\rProgress: 1.49%\rProgress: 1.50%\rProgress: 1.51%\rProgress: 1.52%\rProgress: 1.53%\rProgress: 1.54%\rProgress: 1.55%\rProgress: 1.56%\rProgress: 1.57%\rProgress: 1.58%\rProgress: 1.59%\rProgress: 1.60%\rProgress: 1.61%\rProgress: 1.62%"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" input = module(input)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Progress: 100.00%\n",
"Training on /content/test-A/in.tsv.xz\n",
"Progress: 100.00%\n"
]
}
],
"source": [
"def predict_word(ixs, model, top_k=5):\n",
" out = model(ixs)\n",
" top = torch.topk(out[0], 10)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = train_dataset.vocab.lookup_tokens(top_indices)\n",
" return list(zip(top_words, top_indices, top_probs))\n",
"\n",
"\n",
"def get_one_word(text, context=\"left\"):\n",
" # print(\"Getting word from:\", text)\n",
" if context == \"left\":\n",
" context = -1\n",
" else:\n",
" context = 0\n",
" return text.rstrip().split(\" \")[context]\n",
"\n",
"\n",
"def inference_on_file(filename, model, lines_no=1):\n",
" results_path = \"/\".join(filename.split(\"/\")[:-1]) + \"/out.tsv\"\n",
" with lzma.open(filename, \"r\") as fp, open(results_path, \"w\") as out_file:\n",
" print(\"Training on\", filename)\n",
" for i, line in enumerate(fp):\n",
" # left, right = [ get_one_word(text_part, context)\n",
" # for context, text_part in zip(line.split('\\t')[:-2], ('left', 'right'))]\n",
" line = line.decode(\"utf-8\")\n",
" # print(line)\n",
" left = get_one_word(line.split(\"\\t\")[-2])\n",
" # print(\"Current word:\", left)\n",
" tensor = torch.tensor(train_dataset.vocab.forward([left])).to(device)\n",
" results = predict_word(tensor, model, 9)\n",
" prob_sum = sum([word[2] for word in results])\n",
" result_line = (\n",
" \" \".join([f\"{word[0]}:{word[2]}\" for word in results])\n",
" + f\" :{prob_sum}\\n\"\n",
" )\n",
" # print(result_line)\n",
" out_file.write(result_line)\n",
" print(f\"\\rProgress: {(((i+1) / lines_no) * 100):.2f}%\", end=\"\")\n",
" print()\n",
"\n",
"\n",
"model.eval()\n",
"\n",
"for filepath, lines_no in zip(\n",
" (\"/content/dev-0/in.tsv.xz\", \"/content/test-A/in.tsv.xz\"), (10519.0, 7414.0)\n",
"):\n",
" inference_on_file(filepath, model, lines_no)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "mj_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 0
}