{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "dfeb7061",
"metadata": {},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torch.utils.data import DataLoader\n",
"import torch\n",
"from torch import nn\n",
"import pandas as pd\n",
"import nltk\n",
"import regex as re\n",
"import csv\n",
"import itertools\n",
"from nltk import word_tokenize\n",
"from os.path import exists\n",
"\n",
"\n",
"def clean(text):\n",
" text = str(text).strip().lower()\n",
" text = re.sub(\"|>|<|\\.|\\\\\\\\|\\\"|”|-|,|\\*|:|\\/\", \"\", text)\n",
" text = text.replace('\\\\\\\\n', \" \").replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
" text = text.replace(\"'\", \"\")\n",
" return text\n",
"\n",
"def get_words_from_line(line, specials = True):\n",
" line = line.rstrip()\n",
" if specials:\n",
" yield '<s>'\n",
" for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
" yield m.group(0).lower()\n",
" if specials:\n",
" yield '</s>'\n",
"\n",
"def get_word_lines_from_data(d):\n",
" for line in d:\n",
" yield get_words_from_line(line)\n",
"\n",
"\n",
"class SimpleBigramNeuralLanguageModel(torch.nn.Module):\n",
" \n",
" def __init__(self, vocabulary_size, embedding_size):\n",
" super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Embedding(vocabulary_size, embedding_size),\n",
" nn.Linear(embedding_size, vocabulary_size),\n",
" nn.Softmax()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
"def look_ahead_iterator(gen):\n",
" w1 = None\n",
" for item in gen:\n",
" if w1 is not None:\n",
" yield (w1, item)\n",
" w1 = item\n",
" \n",
"class Bigrams(torch.utils.data.IterableDataset):\n",
" def __init__(self, data, vocabulary_size):\n",
" self.vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_data(data),\n",
" max_tokens = vocabulary_size,\n",
" specials = ['<unk>'])\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.data = data\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_data(self.data))))\n",
"\n",
"\n",
"# ładowanie danych treningowych\n",
"in_file = 'train/in.tsv.xz'\n",
"out_file = 'train/expected.tsv'\n",
"\n",
"X_train = pd.read_csv(in_file, sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=200000, on_bad_lines=\"skip\", encoding=\"UTF-8\")\n",
"Y_train = pd.read_csv(out_file, sep='\\t', header=None, quoting=csv.QUOTE_NONE, nrows=200000, on_bad_lines=\"skip\", encoding=\"UTF-8\")\n",
"\n",
"X_train = X_train[[6, 7]]\n",
"X_train = pd.concat([X_train, Y_train], axis=1)\n",
"X_train = X_train[6] + X_train[0] + X_train[7]\n",
"X_train = X_train.apply(clean)\n",
"vocab_size = 30000\n",
"embed_size = 150\n",
"Dataset = Bigrams(X_train, vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1cc73f1e",
"metadata": {},
"outputs": [],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
"\n",
"if(not exists('nn_model2.bin')):\n",
" data = DataLoader(Dataset, batch_size=8000)\n",
" optimizer = torch.optim.Adam(model.parameters())\n",
" criterion = torch.nn.NLLLoss()\n",
"\n",
" model.train()\n",
" step = 0\n",
" for i in range(2):\n",
" print(f\" Epoka {i}--------------------------------------------------------\")\n",
" for x, y in data:\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 100 == 0:\n",
" print(step, loss)\n",
" step += 1\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" torch.save(model.state_dict(), 'nn_model2.bin')\n",
"else:\n",
" model.load_state_dict(torch.load('nn_model2.bin')) \n",
"\n",
"\n",
"vocab = Dataset.vocab\n",
"\n",
"\n",
"# nltk.download('punkt')\n",
"def predict_word(ws):\n",
" ixs = torch.tensor(vocab.forward(ws)).to(device)\n",
" out = model(ixs)\n",
" top = torch.topk(out[0], 8)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" pred_str = \"\"\n",
" for word, prob in list(zip(top_words, top_probs)):\n",
" pred_str += f\"{word}:{prob} \"\n",
"# pred_str += f':0.01'\n",
" return pred_str\n",
"\n",
"\n",
"def word_gap_prediction(file):\n",
" X_test = pd.read_csv(f'{file}/in.tsv.xz', sep='\\t', header=None, quoting=csv.QUOTE_NONE, on_bad_lines='skip', encoding=\"UTF-8\")[6]\n",
" X_test = X_test.apply(clean)\n",
" with open(f'{file}/out.tsv', \"w+\", encoding=\"UTF-8\") as f:\n",
" for row in X_test:\n",
" result = {}\n",
" before = None\n",
" for before in get_words_from_line(clean(str(row)), False):\n",
" pass\n",
" before = [before]\n",
" if(len(before) < 1):\n",
" pred_str = \"a:0.2 the:0.2 to:0.2 of:0.1 and:0.1 of:0.1 :0.1\"\n",
" else:\n",
" pred_str = predict_word(before)\n",
" pred_str = pred_str.strip()\n",
" f.write(pred_str + \"\\n\")"
]
},
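{
"cell_type": "code",
"execution_count": null,
"id": "9a4e7f21",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (an added sketch, not part of the original pipeline):\n",
"# print the top next-word candidates the trained model proposes for an assumed\n",
"# example context word, 'the'.\n",
"print(predict_word(['the']))"
]
},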
{
"cell_type": "code",
"execution_count": null,
"id": "682d3528",
"metadata": {},
"outputs": [],
"source": [
"word_gap_prediction(\"dev-0/\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74b9f66c",
"metadata": {},
"outputs": [],
"source": [
"word_gap_prediction(\"test-A/\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}