{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import lzma\n", "from itertools import islice\n", "import re\n", "import sys\n", "from torchtext.vocab import build_vocab_from_iterator\n", "from torch import nn\n", "from torch.utils.data import IterableDataset, DataLoader\n", "import itertools\n", "import matplotlib.pyplot as plt" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "VOCAB_SIZE = 2_000\n", "EMBED_SIZE = 500" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_words_from_line(line):\n", " line = line.rstrip()\n", " line = line.split(\"\\t\")\n", " text = line[-2] + \" \" + line[-1]\n", " text = re.sub(r\"\\\\+n\", \" \", text)\n", " text = re.sub('[^A-Za-z ]+', '', text)\n", " for t in text.split():\n", " yield t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_word_lines_from_file(file_name):\n", " with lzma.open(file_name, encoding='utf8', mode=\"rt\") as fh:\n", " for line in fh:\n", " yield get_words_from_line(line)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def look_ahead_iterator(gen):\n", " first = None\n", " second = None\n", " for item in gen:\n", " if first is not None and second is not None:\n", " yield ((first, item), second)\n", " first = second\n", " second = item" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Create Vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(\"train/in.tsv.xz\"),\n", " max_tokens = VOCAB_SIZE,\n", " specials = [''])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Trigram class" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Trigrams(IterableDataset):\n", " def __init__(self, text_file, vocabulary_size):\n", " self.vocab = vocab\n", " self.vocab.set_default_index(self.vocab[''])\n", " self.vocabulary_size = VOCAB_SIZE\n", " self.text_file = text_file\n", "\n", " def __iter__(self):\n", " return look_ahead_iterator(\n", " (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n", "\n", "train_dataset = Trigrams(\"train/in.tsv.xz\", VOCAB_SIZE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TrigramNNModel(nn.Module):\n", " def __init__(self, VOCAB_SIZE, EMBED_SIZE):\n", " super(TrigramNNModel, self).__init__()\n", " self.embeddings = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)\n", " self.hidden_layer = nn.Linear(EMBED_SIZE*2, 1200)\n", " self.output_layer = nn.Linear(1200, VOCAB_SIZE)\n", " self.softmax = nn.Softmax()\n", "\n", " def forward(self, x):\n", " emb_2 = self.embeddings(x[0])\n", " emb_1 = self.embeddings(x[1])\n", " x = torch.cat([emb_2, emb_1], dim=1)\n", " x = self.hidden_layer(x)\n", " x = self.output_layer(x)\n", " x = self.softmax(x)\n", " return x\n", "\n", "model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE)\n", "\n", "vocab.set_default_index(vocab[''])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, 
"source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = 'cpu'\n", "model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n", "data = DataLoader(train_dataset, batch_size=1_000)\n", "optimizer = torch.optim.Adam(model.parameters())\n", "criterion = torch.nn.NLLLoss()\n", "\n", "loss_track = []\n", "last_loss = 1_000\n", "trigger_count = 0\n", "\n", "model.train()\n", "step = 0\n", "for x, y in data:\n", " x[0] = x[0].to(device)\n", " x[1] = x[1].to(device)\n", " y = y.to(device)\n", " optimizer.zero_grad()\n", " ypredicted = model(x)\n", " loss = criterion(torch.log(ypredicted), y)\n", " if step % 100 == 0:\n", " print(step, loss)\n", " step += 1\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if loss > last_loss:\n", " trigger_count += 1 \n", " print(trigger_count, 'LOSS DIFF:', loss, last_loss)\n", "\n", " if trigger_count >= 500:\n", " break\n", "\n", " loss_track.append(loss)\n", " last_loss = loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin')\n", "vocab_unique = set(vocab.get_stoi().keys())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output = []\n", "pattern = re.compile('[^A-Za-z]+')\n", "\n", "with lzma.open(\"dev-0/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n", " for line in file:\n", " line = line.split(\"\\t\")\n", " first_word = pattern.sub(' ', line[-2]).split()[-1]\n", " second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n", "\n", " first_word = re.sub('[^A-Za-z]+', '', first_word)\n", " second_word = re.sub('[^A-Za-z]+', '', second_word)\n", "\n", " first_word = \"\" if first_word not in vocab_unique else first_word\n", " second_word = \"\" if second_word not in vocab_unique else second_word\n", "\n", " input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n", " out = model(input_tokens)\n", "\n", " top = torch.topk(out[0], 10)\n", " top_indices = top.indices.tolist()\n", " top_probs = top.values.tolist()\n", " unk_bonus = 1 - sum(top_probs)\n", " top_words = vocab.lookup_tokens(top_indices)\n", " top_zipped = list(zip(top_words, top_probs))\n", "\n", " res = \" \".join([f\"{w}:{p:.4f}\" if w != \"\" else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n", " res += \"\\n\"\n", " output.append(res)\n", "\n", "with open(f\"dev-0/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n", " file.writelines(output)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output = []\n", "pattern = re.compile('[^A-Za-z]+')\n", "\n", "with lzma.open(\"test-A/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n", " for line in file:\n", " line = line.split(\"\\t\")\n", " first_word = pattern.sub(' ', line[-2]).split()[-1]\n", " second_word = pattern.sub(' ', line[-1]).split(maxsplit=1)[0]\n", "\n", " first_word = re.sub('[^A-Za-z]+', '', first_word)\n", " second_word = re.sub('[^A-Za-z]+', '', second_word)\n", "\n", " first_word = \"\" if first_word not in vocab_unique else first_word\n", " second_word = \"\" if second_word not in vocab_unique else second_word\n", "\n", " input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n", " out = model(input_tokens)\n", "\n", " top = torch.topk(out[0], 10)\n", " top_indices = top.indices.tolist()\n", " top_probs = 
, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output = []\n", "pattern = re.compile('[^A-Za-z]+')\n", "\n", "model.eval()\n", "with lzma.open(\"test-A/in.tsv.xz\", encoding='utf8', mode=\"rt\") as file:\n", "    for line in file:\n", "        fields = line.split(\"\\t\")\n", "        # Last word of the left context and first word of the right context.\n", "        first_word = pattern.sub(' ', fields[-2]).split()[-1]\n", "        second_word = pattern.sub(' ', fields[-1]).split(maxsplit=1)[0]\n", "\n", "        first_word = '<unk>' if first_word not in vocab_unique else first_word\n", "        second_word = '<unk>' if second_word not in vocab_unique else second_word\n", "\n", "        input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)\n", "        with torch.no_grad():\n", "            out = model(input_tokens)\n", "\n", "        top = torch.topk(out[0], 10)\n", "        top_indices = top.indices.tolist()\n", "        top_probs = top.values.tolist()\n", "        # Probability mass outside the top 10 goes to the catch-all entry.\n", "        unk_bonus = 1 - sum(top_probs)\n", "        top_words = vocab.lookup_tokens(top_indices)\n", "        top_zipped = list(zip(top_words, top_probs))\n", "\n", "        res = \" \".join([f\"{w}:{p:.4f}\" if w != '<unk>' else f\":{(p + unk_bonus):.4f}\" for w, p in top_zipped])\n", "        res += \"\\n\"\n", "        output.append(res)\n", "\n", "with open(f\"test-A/out-EMBED_SIZE={EMBED_SIZE}.tsv\", mode=\"w\") as file:\n", "    file.writelines(output)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }