challenging-america-word-ga.../zad8_trigrams_nn.ipynb

329 lines
9.3 KiB
Plaintext
Raw Normal View History

2023-05-10 23:23:51 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import lzma\n",
"from itertools import islice\n",
"import re\n",
"import sys\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torch import nn\n",
"from torch.utils.data import IterableDataset, DataLoader\n",
"import itertools\n",
"import matplotlib.pyplot as plt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Passed as max_tokens when the vocabulary is built and used as the number of\n",
"# embedding rows / output classes of the model.\n",
"VOCAB_SIZE = 2_000\n",
"# Width of each word-embedding vector.\n",
"EMBED_SIZE = 500"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_words_from_line(line):\n",
"    \"\"\"Yield the cleaned words found in the last two tab-separated columns of `line`.\n",
"\n",
"    Trailing whitespace is stripped, the final two columns are joined with a\n",
"    space, each run of backslashes followed by `n` is turned into a space, and\n",
"    every character other than ASCII letters and spaces is removed.\n",
"    \"\"\"\n",
"    columns = line.rstrip().split(\"\\t\")\n",
"    joined = columns[-2] + \" \" + columns[-1]\n",
"    unescaped = re.sub(r\"\\\\+n\", \" \", joined)\n",
"    letters_only = re.sub('[^A-Za-z ]+', '', unescaped)\n",
"    yield from letters_only.split()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_word_lines_from_file(file_name):\n",
"    \"\"\"Lazily yield one word generator per line of the xz-compressed text file `file_name`.\n",
"\n",
"    The file is opened in UTF-8 text mode and stays open while the returned\n",
"    generator is being consumed.\n",
"    \"\"\"\n",
"    with lzma.open(file_name, mode=\"rt\", encoding='utf8') as xz_text:\n",
"        yield from map(get_words_from_line, xz_text)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def look_ahead_iterator(gen):\n",
"    \"\"\"Turn a stream of items into ((previous, next), current) examples.\n",
"\n",
"    For every three consecutive items a, b, c the pair ((a, c), b) is yielded,\n",
"    i.e. the two neighbours as input and the middle item as target. Nothing is\n",
"    yielded while either remembered item is still None.\n",
"    \"\"\"\n",
"    left = None\n",
"    middle = None\n",
"    for right in gen:\n",
"        window_ready = not (left is None or middle is None)\n",
"        if window_ready:\n",
"            yield ((left, right), middle)\n",
"        # Slide the window one item forward.\n",
"        left, middle = middle, right"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the vocabulary from the training corpus, capped at VOCAB_SIZE entries,\n",
"# with '<unk>' registered as a special token.\n",
"vocab = build_vocab_from_iterator(\n",
"    get_word_lines_from_file(\"train/in.tsv.xz\"),\n",
"    specials=['<unk>'],\n",
"    max_tokens=VOCAB_SIZE,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Trigram dataset and model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Trigrams(IterableDataset):\n",
"    \"\"\"Streams ((left_id, right_id), middle_id) examples from an xz-compressed TSV file.\n",
"\n",
"    Words are mapped to ids with the module-level `vocab` built earlier in the\n",
"    notebook; the file is re-read from the start on every iteration.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, text_file, vocabulary_size):\n",
"        \"\"\"Remember the source file and vocabulary size.\n",
"\n",
"        Args:\n",
"            text_file: path of the .xz TSV file to stream words from.\n",
"            vocabulary_size: vocabulary size recorded on the instance.\n",
"        \"\"\"\n",
"        self.vocab = vocab\n",
"        # Words missing from the vocabulary resolve to the '<unk>' index.\n",
"        self.vocab.set_default_index(self.vocab['<unk>'])\n",
"        # Store the size passed by the caller rather than the module-level constant.\n",
"        self.vocabulary_size = vocabulary_size\n",
"        self.text_file = text_file\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Return a fresh iterator of ((left_id, right_id), middle_id) tuples.\"\"\"\n",
"        # Flatten the per-line word generators into one stream of word ids.\n",
"        word_ids = (\n",
"            self.vocab[word]\n",
"            for word in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))\n",
"        )\n",
"        return look_ahead_iterator(word_ids)\n",
"\n",
"\n",
"train_dataset = Trigrams(\"train/in.tsv.xz\", VOCAB_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TrigramNNModel(nn.Module):\n",
"    \"\"\"Feed-forward model that scores every vocabulary entry given two context word ids.\"\"\"\n",
"\n",
"    def __init__(self, VOCAB_SIZE, EMBED_SIZE, hidden_size=1200):\n",
"        \"\"\"Create the embedding table and the two linear layers.\n",
"\n",
"        Args:\n",
"            VOCAB_SIZE: number of embedding rows and of output classes.\n",
"            EMBED_SIZE: width of each embedding vector.\n",
"            hidden_size: width of the hidden linear layer.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.embeddings = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)\n",
"        # The two context embeddings are concatenated, hence EMBED_SIZE * 2 inputs.\n",
"        self.hidden_layer = nn.Linear(EMBED_SIZE * 2, hidden_size)\n",
"        # NOTE(review): forward() applies no activation between hidden_layer and\n",
"        # output_layer, so the two compose into a single linear map. Adding one\n",
"        # would change the outputs of already-saved checkpoints.\n",
"        self.output_layer = nn.Linear(hidden_size, VOCAB_SIZE)\n",
"        # Explicit dim: relying on the implicit choice is deprecated and warns on every call.\n",
"        self.softmax = nn.Softmax(dim=1)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return softmax probabilities over the output classes.\n",
"\n",
"        x[0] and x[1] are index tensors; their embeddings are concatenated along\n",
"        dim 1, so each is expected to hold one id per batch element.\n",
"        \"\"\"\n",
"        emb_2 = self.embeddings(x[0])\n",
"        emb_1 = self.embeddings(x[1])\n",
"        x = torch.cat([emb_2, emb_1], dim=1)\n",
"        x = self.hidden_layer(x)\n",
"        x = self.output_layer(x)\n",
"        return self.softmax(x)\n",
"\n",
"\n",
"model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE)\n",
"\n",
"vocab.set_default_index(vocab['<unk>'])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 1_000\n",
"LOG_EVERY = 100  # print the loss every this many steps\n",
"MAX_LOSS_INCREASES = 500  # stop once the loss has gone up this many times in total\n",
"\n",
"device = 'cpu'\n",
"model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
"data = DataLoader(train_dataset, batch_size=BATCH_SIZE)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"loss_track = []  # one plain Python float per optimisation step\n",
"last_loss = 1_000\n",
"trigger_count = 0  # steps whose loss exceeded the previous step's loss (not necessarily consecutive)\n",
"\n",
"model.train()\n",
"step = 0\n",
"for x, y in data:\n",
"    x[0] = x[0].to(device)\n",
"    x[1] = x[1].to(device)\n",
"    y = y.to(device)\n",
"    optimizer.zero_grad()\n",
"    ypredicted = model(x)\n",
"    # The model returns probabilities while NLLLoss expects log-probabilities.\n",
"    # NOTE(review): log of a probability that underflowed to 0 is -inf.\n",
"    loss = criterion(torch.log(ypredicted), y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    # Convert to a float so the bookkeeping below does not keep every step's\n",
"    # autograd graph alive through the stored tensors.\n",
"    loss_value = loss.item()\n",
"    if step % LOG_EVERY == 0:\n",
"        print(step, loss_value)\n",
"    step += 1\n",
"\n",
"    if loss_value > last_loss:\n",
"        trigger_count += 1\n",
"        print(trigger_count, 'LOSS DIFF:', loss_value, last_loss)\n",
"\n",
"    if trigger_count >= MAX_LOSS_INCREASES:\n",
"        break\n",
"\n",
"    loss_track.append(loss_value)\n",
"    last_loss = loss_value"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained weights; the embedding size is encoded in the file name.\n",
"checkpoint_path = f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin'\n",
"torch.save(model.state_dict(), checkpoint_path)\n",
"\n",
"# Every token string known to the vocabulary, for fast membership checks.\n",
"vocab_unique = set(vocab.get_stoi())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = []\n",
"pattern = re.compile('[^A-Za-z]+')\n",
"\n",
"with lzma.open(\"dev-0/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n",
" for line in file:\n",
" line = line.split(\"\\t\")\n",
" first_word = pattern.sub(' ', line[-2]).split()[-1]\n",
" second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n",
"\n",
" first_word = re.sub('[^A-Za-z]+', '', first_word)\n",
" second_word = re.sub('[^A-Za-z]+', '', second_word)\n",
"\n",
" first_word = \"<unk>\" if first_word not in vocab_unique else first_word\n",
" second_word = \"<unk>\" if second_word not in vocab_unique else second_word\n",
"\n",
" input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n",
" out = model(input_tokens)\n",
"\n",
" top = torch.topk(out[0], 10)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" unk_bonus = 1 - sum(top_probs)\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" top_zipped = list(zip(top_words, top_probs))\n",
"\n",
" res = \" \".join([f\"{w}:{p:.4f}\" if w != \"<unk>\" else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n",
" res += \"\\n\"\n",
" output.append(res)\n",
"\n",
"with open(f\"dev-0/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n",
" file.writelines(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = []\n",
"pattern = re.compile('[^A-Za-z]+')\n",
"\n",
"with lzma.open(\"test-A/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n",
" for line in file:\n",
" line = line.split(\"\\t\")\n",
" first_word = pattern.sub(' ', line[-2]).split()[-1]\n",
" second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n",
"\n",
" first_word = re.sub('[^A-Za-z]+', '', first_word)\n",
" second_word = re.sub('[^A-Za-z]+', '', second_word)\n",
"\n",
" first_word = \"<unk>\" if first_word not in vocab_unique else first_word\n",
" second_word = \"<unk>\" if second_word not in vocab_unique else second_word\n",
"\n",
" input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n",
" out = model(input_tokens)\n",
"\n",
" top = torch.topk(out[0], 10)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" unk_bonus = 1 - sum(top_probs)\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" top_zipped = list(zip(top_words, top_probs))\n",
"\n",
" res = \" \".join([f\"{w}:{p:.4f}\" if w != \"<unk>\" else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n",
" res += \"\\n\"\n",
" output.append(res)\n",
"\n",
"with open(f\"test-A/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n",
" file.writelines(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}