348 lines
8.9 KiB
Plaintext
348 lines
8.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## IMPORTS"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"import regex as re\n",
|
|
"import sys\n",
|
|
"from torchtext.vocab import build_vocab_from_iterator\n",
|
|
"import lzma\n",
|
|
"from torch.utils.data import IterableDataset\n",
|
|
"import itertools\n",
|
|
"from torch import nn\n",
|
|
"import torch\n",
|
|
"import pickle\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"\n",
|
|
"print(torch.backends.mps.is_available())\n",
|
|
"print(torch.backends.mps.is_built())"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## FUNCTIONS"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_words_from_line(line):\n",
|
|
" line = line.rstrip()\n",
|
|
" yield '<s>'\n",
|
|
" for t in line.split(' '):\n",
|
|
" yield t\n",
|
|
" yield '</s>'\n",
|
|
"\n",
|
|
"def get_word_lines_from_file(file_name):\n",
|
|
" n = 0\n",
|
|
" with lzma.open(file_name, 'r') as fh:\n",
|
|
" for line in fh:\n",
|
|
" n += 1\n",
|
|
" if n % 1000 == 0:\n",
|
|
" print(n ,file=sys.stderr)\n",
|
|
" yield get_words_from_line(line.decode('utf-8'))\n",
|
|
"\n",
|
|
"def look_ahead_iterator(gen):\n",
|
|
" prev = None\n",
|
|
" for item in gen:\n",
|
|
" if prev is not None:\n",
|
|
" yield (prev, item)\n",
|
|
" prev = item\n",
|
|
"\n",
|
|
"def clean(text):\n",
|
|
" text = str(text).lower().replace('-\\\\n', '').replace('\\\\n', ' ').replace('-', '').replace('\\'s', ' is').replace('\\'re', ' are').replace('\\'m', ' am').replace('\\'ve', ' have').replace('\\'ll', ' will')\n",
|
|
" text = re.sub(r'\\p{P}', '', text)\n",
|
|
" return text\n",
|
|
"\n",
|
|
"def predict(word, model, vocab):\n",
|
|
" try:\n",
|
|
" ixs = torch.tensor(vocab.forward([word])).to(device)\n",
|
|
" except:\n",
|
|
" ixs = torch.tensor(vocab.forward(['<unk>'])).to(device)\n",
|
|
" word = '<unk>'\n",
|
|
" out = model(ixs)\n",
|
|
" top = torch.topk(out[0], 300)\n",
|
|
" top_indices = top.indices.tolist()\n",
|
|
" top_probs = top.values.tolist()\n",
|
|
" top_words = vocab.lookup_tokens(top_indices)\n",
|
|
" prob_list = list(zip(top_words, top_probs))\n",
|
|
" for index, element in enumerate(prob_list):\n",
|
|
" unk = None\n",
|
|
" if '<unk>' in element:\n",
|
|
" unk = prob_list.pop(index)\n",
|
|
" prob_list.append(('', unk[1]))\n",
|
|
" break\n",
|
|
" if unk is None:\n",
|
|
" prob_list[-1] = ('', prob_list[-1][1])\n",
|
|
" return ' '.join([f'{x[0]}:{x[1]}' for x in prob_list])\n",
|
|
"\n",
|
|
"def predicition_for_file(model, vocab, folder, file):\n",
|
|
" print('=' * 10, f' do prediction for {folder}/{file} ', '=' * 10)\n",
|
|
" with lzma.open(f'{folder}/in.tsv.xz', mode='rt', encoding='utf-8') as f:\n",
|
|
" with open(f'{folder}/out.tsv', 'w', encoding='utf-8') as fid:\n",
|
|
"\n",
|
|
" for line in f:\n",
|
|
" separated = line.split('\\t')\n",
|
|
" before = clean(separated[6]).split()[-1]\n",
|
|
" new_line = predict(before, model, vocab)\n",
|
|
" fid.write(new_line + '\\n')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## CLASSES"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"class Bigrams(IterableDataset):\n",
|
|
" def __init__(self, text_file, vocabulary_size):\n",
|
|
" self.vocab = build_vocab_from_iterator(\n",
|
|
" get_word_lines_from_file(text_file),\n",
|
|
" max_tokens = vocabulary_size,\n",
|
|
" specials = ['<unk>'])\n",
|
|
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
|
|
" self.vocabulary_size = vocabulary_size\n",
|
|
" self.text_file = text_file\n",
|
|
"\n",
|
|
" def __iter__(self):\n",
|
|
" return look_ahead_iterator(\n",
|
|
" (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
|
|
"\n",
|
|
"class SimpleBigramNeuralLanguageModel(nn.Module):\n",
|
|
" def __init__(self, vocabulary_size, embedding_size):\n",
|
|
" super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
|
|
" self.model = nn.Sequential(\n",
|
|
" nn.Embedding(vocabulary_size, embedding_size),\n",
|
|
" nn.Linear(embedding_size, vocabulary_size),\n",
|
|
" nn.Softmax()\n",
|
|
" )\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" return self.model(x)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## PARAMETERS"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"vocab_size = 30000\n",
|
|
"embed_size = 1000\n",
|
|
"batch_size = 5000\n",
|
|
"device = 'mps'\n",
|
|
"path_to_training_file = './train/in.tsv.xz'\n",
|
|
"path_to_model_file = 'model_neural_network.bin'\n",
|
|
"folder_dev_0, file_dev_0 = 'dev-0', 'in.tsv.xz'\n",
|
|
"folder_test_a, file_test_a = 'test-A', 'in.tsv.xz'\n",
|
|
"path_to_vocabulary_file = 'vocabulary_neural_network.pickle'"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## VOCAB"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"vocab = build_vocab_from_iterator(\n",
|
|
" get_word_lines_from_file(path_to_training_file),\n",
|
|
" max_tokens = vocab_size,\n",
|
|
" specials = ['<unk>'])\n",
|
|
"\n",
|
|
"with open(path_to_vocabulary_file, 'wb') as handle:\n",
|
|
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## TRAIN MODEL"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"train_dataset = Bigrams(path_to_training_file, vocab_size)\n",
|
|
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
|
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
|
"optimizer = torch.optim.Adam(model.parameters())\n",
|
|
"criterion = torch.nn.NLLLoss()\n",
|
|
"\n",
|
|
"model.train()\n",
|
|
"step = 0\n",
|
|
"for x, y in data:\n",
|
|
" x = x.to(device)\n",
|
|
" y = y.to(device)\n",
|
|
" optimizer.zero_grad()\n",
|
|
" ypredicted = model(x)\n",
|
|
" loss = criterion(torch.log(ypredicted), y)\n",
|
|
" if step % 100 == 0:\n",
|
|
" print(step, loss)\n",
|
|
" step += 1\n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
"torch.save(model.state_dict(), path_to_model_file)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## LOAD MODEL AND VOCAB"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"with open(path_to_vocabulary_file, 'rb') as handle:\n",
|
|
" vocab = pickle.load(handle)\n",
|
|
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
|
"model.load_state_dict(torch.load(path_to_model_file))\n",
|
|
"model.eval()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## CREATE OUTPUTS FILES"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### DEV-0"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"predicition_for_file(model, vocab, folder_dev_0, file_dev_0)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### TEST-A"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"predicition_for_file(model, vocab, folder_test_a, file_test_a)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|