329 lines
9.3 KiB
Plaintext
329 lines
9.3 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import lzma\n",
|
||
|
"from itertools import islice\n",
|
||
|
"import re\n",
|
||
|
"import sys\n",
|
||
|
"from torchtext.vocab import build_vocab_from_iterator\n",
|
||
|
"from torch import nn\n",
|
||
|
"from torch.utils.data import IterableDataset, DataLoader\n",
|
||
|
"import itertools\n",
|
||
|
"import matplotlib.pyplot as plt"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Parameters"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Maximum vocabulary size handed to build_vocab_from_iterator (the '<unk>' special is part of it).
VOCAB_SIZE = 2_000
# Dimension of each word embedding; the model concatenates two of them.
EMBED_SIZE = 500
# Seed for torch's RNG so weight initialisation (and hence training) is reproducible.
SEED = 42

torch.manual_seed(SEED)
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Functions"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
def get_words_from_line(line):
    """Yield the normalised word tokens of one tab-separated line.

    The last two tab-separated fields are joined with a space. Escaped
    newlines (one or more backslashes followed by ``n``) become spaces, and
    every character that is not an ASCII letter or a space is deleted. The
    remaining text is split on whitespace and yielded token by token.

    Lines with fewer than two fields contribute whatever text they have
    (an empty line yields nothing) instead of raising IndexError.
    """
    # Strip only the line terminator: a bare rstrip() would also remove a
    # trailing tab when the last field is empty, shifting which fields are
    # treated as the two context columns.
    fields = line.rstrip("\r\n").split("\t")
    text = " ".join(fields[-2:])
    text = re.sub(r"\\+n", " ", text)
    text = re.sub('[^A-Za-z ]+', '', text)
    yield from text.split()
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
def get_word_lines_from_file(file_name):
    """Lazily read an xz-compressed UTF-8 text file line by line.

    For every line, yields the generator produced by ``get_words_from_line``
    for that line, so callers receive one word stream per line. The file
    stays open until this generator is exhausted or closed.
    """
    with lzma.open(file_name, mode="rt", encoding='utf8') as compressed:
        yield from map(get_words_from_line, compressed)
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
def look_ahead_iterator(gen):
    """Turn a token stream into ``((w[i-2], w[i]), w[i-1])`` examples.

    For each incoming item ``w[i]``, an example is yielded only when both
    of the two items preceding it are not ``None``. The pair of outer
    neighbours is the input and the middle item is the target.
    """
    prev2, prev1 = None, None
    for current in gen:
        if not (prev2 is None or prev1 is None):
            yield (prev2, current), prev1
        prev2, prev1 = prev1, current
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Create Vocab"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Build the vocabulary from the training contexts. Per the torchtext docs,
# max_tokens counts the specials too, so '<unk>' is one of the VOCAB_SIZE entries.
vocab = build_vocab_from_iterator(
    get_word_lines_from_file("train/in.tsv.xz"),
    max_tokens=VOCAB_SIZE,
    specials=['<unk>'])
# Configure the out-of-vocabulary fallback where the vocab is created instead of
# relying on a later cell to do it as a side effect.
vocab.set_default_index(vocab['<unk>'])
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Trigram dataset and model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
class Trigrams(IterableDataset):
    """Iterable dataset of ``((w[i-2], w[i]), w[i-1])`` vocabulary-index examples.

    Words come from ``get_word_lines_from_file(text_file)``. All lines are
    chained into one stream before windowing, so examples can span line
    boundaries. NOTE(review): get_words_from_line joins the left and right
    context without the gap word, so examples straddling that join are not
    real trigrams — confirm this is acceptable.

    Args:
        text_file: path of the xz-compressed training file.
        vocabulary_size: stored as ``self.vocabulary_size``; iteration does
            not use it.
        vocabulary: vocab to map words to indices; defaults to the
            notebook-level ``vocab``. Its default index is set to '<unk>'.
    """

    def __init__(self, text_file, vocabulary_size, vocabulary=None):
        self.vocab = vocab if vocabulary is None else vocabulary
        # Unknown words map to '<unk>' instead of raising during iteration.
        self.vocab.set_default_index(self.vocab['<unk>'])
        # Honour the argument instead of silently reading the global constant.
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        words = itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))
        return look_ahead_iterator(self.vocab[t] for t in words)


train_dataset = Trigrams("train/in.tsv.xz", VOCAB_SIZE)
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
class TrigramNNModel(nn.Module):
    """Feed-forward model predicting a word from the word before and the word after it.

    Architecture: shared embedding -> concat of the two embeddings ->
    Linear + ReLU hidden layer -> Linear to vocabulary logits -> softmax.
    ``forward`` returns probabilities (not log-probabilities); the training
    and prediction cells rely on that.

    Args:
        VOCAB_SIZE: number of vocabulary entries (embedding rows / output classes).
        EMBED_SIZE: embedding dimension.
        hidden_size: width of the hidden layer (default 1200).
    """

    def __init__(self, VOCAB_SIZE, EMBED_SIZE, hidden_size=1200):
        super().__init__()
        self.embeddings = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
        self.hidden_layer = nn.Linear(EMBED_SIZE * 2, hidden_size)
        # Without a nonlinearity the two Linear layers collapse into one linear map.
        # ReLU has no parameters, so state_dict keys are unchanged.
        self.activation = nn.ReLU()
        self.output_layer = nn.Linear(hidden_size, VOCAB_SIZE)
        # Explicit dim: the implicit-dim form is deprecated; dim=1 is the class axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return a (batch, VOCAB_SIZE) probability matrix.

        ``x[0]`` and ``x[1]`` are batches of vocabulary indices: in this
        notebook, the word before the target and the word after it.
        """
        emb_left = self.embeddings(x[0])
        emb_right = self.embeddings(x[1])
        hidden = torch.cat([emb_left, emb_right], dim=1)
        hidden = self.activation(self.hidden_layer(hidden))
        return self.softmax(self.output_layer(hidden))
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Training"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Train for a single pass over the streamed training data.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)
data = DataLoader(train_dataset, batch_size=1_000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

LOG_EVERY = 100
# Stop after this many batches whose loss exceeded the previous batch's loss
# (the counter is never reset).
PATIENCE = 500

# Plain floats: storing loss tensors would keep autograd graph objects alive.
loss_track = []
last_loss = float("inf")
trigger_count = 0

model.train()
for step, (x, y) in enumerate(data):
    x = [t.to(device) for t in x]
    y = y.to(device)
    optimizer.zero_grad()
    ypredicted = model(x)
    # The model returns probabilities; clamp before log so a probability that
    # underflowed to 0 cannot produce an -inf loss / NaN gradients.
    log_probs = torch.log(ypredicted.clamp_min(torch.finfo(ypredicted.dtype).tiny))
    loss = criterion(log_probs, y)
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    if step % LOG_EVERY == 0:
        print(step, loss_value)

    if loss_value > last_loss:
        trigger_count += 1
        print(trigger_count, 'LOSS DIFF:', loss_value, last_loss)
        if trigger_count >= PATIENCE:
            break

    loss_track.append(loss_value)
    last_loss = loss_value
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Persist the trained weights; the file name records the embedding size used.
torch.save(
    obj=model.state_dict(),
    f=f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin',
)
# Every token string known to the vocabulary (including '<unk>'); the
# prediction cells use it to map out-of-vocabulary words to '<unk>'.
vocab_unique = set(vocab.get_stoi())
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Predict the gap word for every dev-0 line. Each output line is
# "word:prob ... :rest", where the wildcard ":rest" carries all probability
# mass not assigned to a listed word.
output = []

model.eval()
with torch.no_grad(), lzma.open("dev-0/in.tsv.xz", encoding='utf8', mode="rt") as file:
    for line in file:
        # Strip only the terminator so an empty last field keeps its position.
        fields = line.rstrip("\r\n").split("\t")
        left_context = fields[-2] if len(fields) >= 2 else ""
        right_context = fields[-1]

        # Normalise exactly like get_words_from_line does for training
        # (escaped newlines -> space, other non-letters deleted), so query
        # words match the vocabulary built from the training data.
        left_words = re.sub('[^A-Za-z ]+', '', re.sub(r"\\+n", " ", left_context)).split()
        right_words = re.sub('[^A-Za-z ]+', '', re.sub(r"\\+n", " ", right_context)).split()

        # An empty context side falls back to <unk> instead of raising IndexError.
        first_word = left_words[-1] if left_words else "<unk>"
        second_word = right_words[0] if right_words else "<unk>"

        first_word = "<unk>" if first_word not in vocab_unique else first_word
        second_word = "<unk>" if second_word not in vocab_unique else second_word

        input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)
        out = model(input_tokens)

        top = torch.topk(out[0], 10)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        top_words = vocab.lookup_tokens(top_indices)

        # Mass outside the listed words always goes to the wildcard, and <unk>
        # (not a real word) is folded into it. Clamp absorbs float rounding.
        rest = max(1 - sum(top_probs), 0.0)
        entries = []
        for w, p in zip(top_words, top_probs):
            if w == "<unk>":
                rest += p
            else:
                entries.append(f"{w}:{p:.4f}")
        entries.append(f":{rest:.4f}")
        output.append(" ".join(entries) + "\n")

with open(f"dev-0/out-EMBED_SIZE={EMBED_SIZE}.tsv", mode="w", encoding="utf8") as file:
    file.writelines(output)
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
# Predict the gap word for every test-A line. Each output line is
# "word:prob ... :rest", where the wildcard ":rest" carries all probability
# mass not assigned to a listed word.
output = []

model.eval()
with torch.no_grad(), lzma.open("test-A/in.tsv.xz", encoding='utf8', mode="rt") as file:
    for line in file:
        # Strip only the terminator so an empty last field keeps its position.
        fields = line.rstrip("\r\n").split("\t")
        left_context = fields[-2] if len(fields) >= 2 else ""
        right_context = fields[-1]

        # Normalise exactly like get_words_from_line does for training
        # (escaped newlines -> space, other non-letters deleted), so query
        # words match the vocabulary built from the training data.
        left_words = re.sub('[^A-Za-z ]+', '', re.sub(r"\\+n", " ", left_context)).split()
        right_words = re.sub('[^A-Za-z ]+', '', re.sub(r"\\+n", " ", right_context)).split()

        # An empty context side falls back to <unk> instead of raising IndexError.
        first_word = left_words[-1] if left_words else "<unk>"
        second_word = right_words[0] if right_words else "<unk>"

        first_word = "<unk>" if first_word not in vocab_unique else first_word
        second_word = "<unk>" if second_word not in vocab_unique else second_word

        input_tokens = torch.tensor([vocab.forward([first_word]), vocab.forward([second_word])]).to(device)
        out = model(input_tokens)

        top = torch.topk(out[0], 10)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        top_words = vocab.lookup_tokens(top_indices)

        # Mass outside the listed words always goes to the wildcard, and <unk>
        # (not a real word) is folded into it. Clamp absorbs float rounding.
        rest = max(1 - sum(top_probs), 0.0)
        entries = []
        for w, p in zip(top_words, top_probs):
            if w == "<unk>":
                rest += p
            else:
                entries.append(f"{w}:{p:.4f}")
        entries.append(f":{rest:.4f}")
        output.append(" ".join(entries) + "\n")

with open(f"test-A/out-EMBED_SIZE={EMBED_SIZE}.tsv", mode="w", encoding="utf8") as file:
    file.writelines(output)
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.2"
|
||
|
},
|
||
|
"orig_nbformat": 4
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|