challenging-america-word-ga.../simple_neural_network.ipynb

348 lines
8.9 KiB
Plaintext
Raw Normal View History

2023-04-27 21:39:28 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## IMPORTS"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"import regex as re\n",
"import sys\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"import lzma\n",
"from torch.utils.data import IterableDataset\n",
"import itertools\n",
"from torch import nn\n",
"import torch\n",
"import pickle\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Diagnostic output: is the MPS backend usable right now, and was this torch build compiled with it?\n",
"print(torch.backends.mps.is_available())\n",
"print(torch.backends.mps.is_built())"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## FUNCTIONS"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
2023-04-27 21:39:28 +02:00
"outputs": [],
"source": [
"def get_words_from_line(line):\n",
"    \"\"\"Yield the tokens of one line, framed by '<s>' and '</s>' markers.\n",
"\n",
"    Trailing whitespace is stripped and the rest is split on single spaces,\n",
"    so consecutive spaces produce empty-string tokens.\n",
"    \"\"\"\n",
"    tokens = line.rstrip().split(' ')\n",
"    yield '<s>'\n",
"    yield from tokens\n",
"    yield '</s>'\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
"    \"\"\"Lazily yield one token generator per line of an xz-compressed file.\n",
"\n",
"    The file is opened in binary mode and every line is decoded as UTF-8.\n",
"    Each line number divisible by 1000 is printed to stderr as a progress marker.\n",
"    \"\"\"\n",
"    with lzma.open(file_name, 'r') as compressed:\n",
"        for line_number, raw_line in enumerate(compressed, start=1):\n",
"            if line_number % 1000 == 0:\n",
"                print(line_number, file=sys.stderr)\n",
"            yield get_words_from_line(raw_line.decode('utf-8'))\n",
"\n",
"def look_ahead_iterator(gen):\n",
"    \"\"\"Yield (previous, current) pairs of consecutive items taken from gen.\n",
"\n",
"    None doubles as the 'no previous item yet' marker, so no pair is emitted\n",
"    while the previous item is None.\n",
"    \"\"\"\n",
"    previous = None\n",
"    for current in gen:\n",
"        if previous is None:\n",
"            # Nothing to pair with yet; just remember this item.\n",
"            previous = current\n",
"            continue\n",
"        yield (previous, current)\n",
"        previous = current\n",
"\n",
"def clean(text):\n",
"    \"\"\"Lower-case text, expand a few contractions and strip punctuation.\n",
"\n",
"    The substitutions run in the listed order on the lower-cased string:\n",
"    hyphen followed by a literal backslash-n is removed, a literal backslash-n\n",
"    becomes a space, remaining hyphens are removed, then the contractions are\n",
"    expanded. Finally every Unicode punctuation character is deleted.\n",
"    \"\"\"\n",
"    substitutions = (\n",
"        ('-\\\\n', ''),\n",
"        ('\\\\n', ' '),\n",
"        ('-', ''),\n",
"        ('\\'s', ' is'),\n",
"        ('\\'re', ' are'),\n",
"        ('\\'m', ' am'),\n",
"        ('\\'ve', ' have'),\n",
"        ('\\'ll', ' will'),\n",
"    )\n",
"    cleaned = str(text).lower()\n",
"    for old, new in substitutions:\n",
"        cleaned = cleaned.replace(old, new)\n",
"    return re.sub(r'\\p{P}', '', cleaned)\n",
"\n",
"def predict(word, model, vocab, top_k=300):\n",
"    \"\"\"Format the model's most probable next tokens after `word`.\n",
"\n",
"    Args:\n",
"        word: Context token; replaced by '<unk>' when it is not in `vocab`.\n",
"        model: Module mapping a tensor of token indices to one score row per index.\n",
"        vocab: Vocabulary supporting `in`, `forward` and `lookup_tokens`.\n",
"        top_k: Maximum number of entries to emit (capped at the row length).\n",
"\n",
"    Returns:\n",
"        Space-separated 'token:score' entries. The last entry always has an empty\n",
"        token: it carries the score of '<unk>' when '<unk>' is among the top\n",
"        entries, otherwise it overwrites the token of the lowest-ranked entry.\n",
"        NOTE(review): the empty token is presumably the catch-all slot of the\n",
"        expected output format - confirm against the task specification.\n",
"    \"\"\"\n",
"    # Explicit membership test instead of a bare except, so real errors still surface.\n",
"    if word not in vocab:\n",
"        word = '<unk>'\n",
"    # Use the device the model lives on rather than a notebook-level global.\n",
"    model_device = next(model.parameters()).device\n",
"    ixs = torch.tensor(vocab.forward([word])).to(model_device)\n",
"    # Inference only: skip building the autograd graph.\n",
"    with torch.no_grad():\n",
"        out = model(ixs)\n",
"    scores = out[0]\n",
"    # topk raises when k exceeds the number of scores, so cap it.\n",
"    top = torch.topk(scores, min(top_k, scores.shape[-1]))\n",
"    top_words = vocab.lookup_tokens(top.indices.tolist())\n",
"    prob_list = list(zip(top_words, top.values.tolist()))\n",
"    for index, (token, prob) in enumerate(prob_list):\n",
"        if token == '<unk>':\n",
"            # Move the '<unk>' score to the end under the empty token.\n",
"            del prob_list[index]\n",
"            prob_list.append(('', prob))\n",
"            break\n",
"    else:\n",
"        # '<unk>' was not among the top entries: reuse the last slot.\n",
"        prob_list[-1] = ('', prob_list[-1][1])\n",
"    return ' '.join(f'{token}:{prob}' for token, prob in prob_list)\n",
"\n",
"def predicition_for_file(model, vocab, folder, file):\n",
"    \"\"\"Write one prediction line per input row to '<folder>/out.tsv'.\n",
"\n",
"    Args:\n",
"        model: Trained language model handed on to `predict`.\n",
"        vocab: Vocabulary handed on to `predict`.\n",
"        folder: Directory holding the input file; the output is written there too.\n",
"        file: Name of the xz-compressed, tab-separated input file inside `folder`.\n",
"\n",
"    For every row the seventh tab-separated field is cleaned and its last word\n",
"    is used as the prediction context (presumably the text left of the gap -\n",
"    confirm against the data description).\n",
"    \"\"\"\n",
"    print('=' * 10, f' do prediction for {folder}/{file} ', '=' * 10)\n",
"    # Open the file that was actually requested instead of a hard-coded name.\n",
"    with lzma.open(f'{folder}/{file}', mode='rt', encoding='utf-8') as f:\n",
"        with open(f'{folder}/out.tsv', 'w', encoding='utf-8') as fid:\n",
"            for line in f:\n",
"                separated = line.split('\\t')\n",
"                context_words = clean(separated[6]).split()\n",
"                # An empty context would break [-1]; fall back to '<unk>' so that\n",
"                # every input row still produces exactly one output row.\n",
"                head = context_words[-1] if context_words else '<unk>'\n",
"                new_line = predict(head, model, vocab)\n",
"                fid.write(new_line + '\\n')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## CLASSES"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
2023-04-27 21:39:28 +02:00
"outputs": [],
"source": [
"class Bigrams(IterableDataset):\n",
"    \"\"\"Iterable dataset of consecutive token-index pairs read from a text file.\n",
"\n",
"    The vocabulary is built from the same file when the dataset is created,\n",
"    with '<unk>' as the only special token and as the default index.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, text_file, vocabulary_size):\n",
"        self.text_file = text_file\n",
"        self.vocabulary_size = vocabulary_size\n",
"        vocabulary = build_vocab_from_iterator(\n",
"            get_word_lines_from_file(text_file),\n",
"            max_tokens=vocabulary_size,\n",
"            specials=['<unk>'],\n",
"        )\n",
"        # Tokens outside the vocabulary resolve to the '<unk>' index.\n",
"        vocabulary.set_default_index(vocabulary['<unk>'])\n",
"        self.vocab = vocabulary\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Return a fresh iterator over (previous index, next index) pairs.\"\"\"\n",
"        token_lines = get_word_lines_from_file(self.text_file)\n",
"        all_tokens = itertools.chain.from_iterable(token_lines)\n",
"        token_indices = (self.vocab[token] for token in all_tokens)\n",
"        return look_ahead_iterator(token_indices)\n",
"\n",
"class SimpleBigramNeuralLanguageModel(nn.Module):\n",
"    \"\"\"Bigram language model: embedding -> linear layer -> softmax over the vocabulary.\n",
"\n",
"    Args:\n",
"        vocabulary_size: Number of tokens; size of the embedding table and of the output.\n",
"        embedding_size: Dimension of the token embeddings.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocabulary_size, embedding_size):\n",
"        super().__init__()\n",
"        self.model = nn.Sequential(\n",
"            nn.Embedding(vocabulary_size, embedding_size),\n",
"            nn.Linear(embedding_size, vocabulary_size),\n",
"            # Normalise over the vocabulary axis explicitly; the implicit choice of\n",
"            # dim is deprecated and picks the wrong axis for 3-D activations.\n",
"            nn.Softmax(dim=-1)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map a tensor of token indices to probabilities over the vocabulary (last axis).\"\"\"\n",
"        return self.model(x)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## PARAMETERS"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
2023-04-27 21:39:28 +02:00
"outputs": [],
"source": [
"vocab_size = 30000\n",
"embed_size = 1000\n",
"batch_size = 5000\n",
"# Prefer MPS, then CUDA, and fall back to the CPU so the notebook also runs\n",
"# on machines where the MPS backend is not available.\n",
"if torch.backends.mps.is_available():\n",
"    device = 'mps'\n",
"elif torch.cuda.is_available():\n",
"    device = 'cuda'\n",
"else:\n",
"    device = 'cpu'\n",
"path_to_training_file = './train/in.tsv.xz'\n",
"path_to_model_file = 'model_neural_network.bin'\n",
"folder_dev_0, file_dev_0 = 'dev-0', 'in.tsv.xz'\n",
"folder_test_a, file_test_a = 'test-A', 'in.tsv.xz'\n",
"path_to_vocabulary_file = 'vocabulary_neural_network.pickle'"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## VOCAB"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"# Build the vocabulary from the training file: at most vocab_size tokens, '<unk>' as the only special.\n",
"vocab = build_vocab_from_iterator(\n",
"    get_word_lines_from_file(path_to_training_file),\n",
"    max_tokens = vocab_size,\n",
"    specials = ['<unk>'])\n",
"# Same default index as the vocabulary built inside Bigrams, so that looking up\n",
"# an out-of-vocabulary token yields '<unk>' instead of raising.\n",
"vocab.set_default_index(vocab['<unk>'])\n",
"\n",
"# Persist the vocabulary so predictions can be made without re-reading the training data.\n",
"with open(path_to_vocabulary_file, 'wb') as handle:\n",
"    pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## TRAIN MODEL"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"train_dataset = Bigrams(path_to_training_file, vocab_size)\n",
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"# Single pass over the bigram stream: x holds the previous-token indices, y the next-token indices.\n",
"model.train()\n",
"step = 0\n",
"for x, y in data:\n",
"    x = x.to(device)\n",
"    y = y.to(device)\n",
"    optimizer.zero_grad()\n",
"    ypredicted = model(x)\n",
"    # The model outputs probabilities while NLLLoss expects log-probabilities.\n",
"    # Clamp before the log so a probability that underflows to 0 cannot turn\n",
"    # the loss into inf and the gradients into NaN.\n",
"    loss = criterion(torch.log(ypredicted.clamp_min(1e-12)), y)\n",
"    if step % 100 == 0:\n",
"        # .item() logs the plain number instead of the full tensor repr.\n",
"        print(step, loss.item())\n",
"    step += 1\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"# Only the weights are stored; the model class is needed again to load them.\n",
"torch.save(model.state_dict(), path_to_model_file)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## LOAD MODEL AND VOCAB"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"# NOTE(review): pickle.load can execute arbitrary code - only load vocabulary files written by this notebook.\n",
"with open(path_to_vocabulary_file, 'rb') as handle:\n",
"    vocab = pickle.load(handle)\n",
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
"# map_location keeps loading working when the weights were saved from a different device.\n",
"model.load_state_dict(torch.load(path_to_model_file, map_location=device))\n",
"# Switch to evaluation mode before predicting.\n",
"model.eval()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## CREATE OUTPUT FILES"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### DEV-0"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"# Run the loaded model over the dev-0 input; predictions are written to out.tsv inside that folder.\n",
"predicition_for_file(model, vocab, folder_dev_0, file_dev_0)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### TEST-A"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2023-04-27 21:40:39 +02:00
"execution_count": null,
"outputs": [],
2023-04-27 21:39:28 +02:00
"source": [
"# Run the loaded model over the test-A input; predictions are written to out.tsv inside that folder.\n",
"predicition_for_file(model, vocab, folder_test_a, file_test_a)"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}