simple nn

Adrian Charkiewicz 2023-05-28 02:55:39 +02:00
parent 1f742b4802
commit 39e6371df8
4 changed files with 18368 additions and 6 deletions

MOJ_7ipynb.ipynb Normal file (423 lines)

@@ -0,0 +1,423 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"machine_shape": "hm"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3dV_4SJ2xY_C",
"outputId": "c1039907-474a-427a-ee13-62a3b7b4a693"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount(\"/content/gdrive\")"
]
},
{
"cell_type": "code",
"source": [
"# %env DATA_DIR=/content/gdrive/MyDrive/data_gralinski\n",
"DATA_DIR=\"/content/gdrive/MyDrive/data_gralinski/\""
],
"metadata": {
"id": "VwdW1Qm3x9-N"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import itertools\n",
"import lzma\n",
"import regex as re\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import IterableDataset, DataLoader\n",
"from torchtext.vocab import build_vocab_from_iterator"
],
"metadata": {
"id": "irsty5KcyYkR"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def clean_line(line: str):\n",
" separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ')\n",
" suffix = separated[7].replace(r'\\n', ' ')\n",
" return prefix + ' ' + suffix"
],
"metadata": {
"id": "LXXtiKW3yY5J"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_words_from_line(line):\n",
" line = clean_line(line)\n",
" for word in line.split():\n",
" yield word"
],
"metadata": {
"id": "y9r0wmD3ycIi"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_word_lines_from_file(file_name):\n",
" with lzma.open(file_name, mode='rt', encoding='utf-8') as fid:\n",
" for line in fid:\n",
" yield get_words_from_line(line)"
],
"metadata": {
"id": "HE3YfiHkycKt"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def look_ahead_iterator(gen):\n",
" prev = None\n",
" for item in gen:\n",
" if prev is not None:\n",
" yield (prev, item)\n",
" prev = item"
],
"metadata": {
"id": "lvHvJV6XycNZ"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def prediction(word: str) -> str:\n",
" ixs = torch.tensor(vocab.forward([word])).to(device)\n",
" out = model(ixs)\n",
" top = torch.topk(out[0], 5)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" zipped = list(zip(top_words, top_probs))\n",
" for index, element in enumerate(zipped):\n",
" unk = None\n",
" if '<unk>' in element:\n",
" unk = zipped.pop(index)\n",
" zipped.append(('', unk[1]))\n",
" break\n",
" if unk is None:\n",
" zipped[-1] = ('', zipped[-1][1])\n",
" return ' '.join([f'{x[0]}:{x[1]}' for x in zipped])"
],
"metadata": {
"id": "sOKeZN9cycP-"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def create_outputs(folder_name):\n",
" print(f'Creating outputs in {folder_name}')\n",
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
" with open(f'{folder_name}/out2.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
" for line in fid:\n",
" separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ').split()[-1]\n",
" output_line = prediction(prefix)\n",
" f.write(output_line + '\\n')"
],
"metadata": {
"id": "MN_RftZNycSB"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class Bigrams(IterableDataset):\n",
" def __init__(self, text_file, vocabulary_size):\n",
" self.vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(text_file),\n",
" max_tokens=vocabulary_size,\n",
" specials=['<unk>'])\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.text_file = text_file\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))"
],
"metadata": {
"id": "n9wIsbLEycUd"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class SimpleBigramNeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocabulary_size, embedding_size):\n",
" super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Embedding(vocabulary_size, embedding_size),\n",
" nn.Linear(embedding_size, vocabulary_size),\n",
" nn.Softmax()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)"
],
"metadata": {
"id": "l490B5KFycXj"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"vocab_size = 15000\n",
"embed_size = 300\n",
"batch_size = 3000\n",
"device = 'cuda'\n",
"path_to_train = DATA_DIR+'train/in.tsv.xz'\n",
"path_to_model = 'model.bin'"
],
"metadata": {
"id": "mMC84-OzycZ5"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(path_to_train),\n",
" max_tokens=vocab_size,\n",
" specials=['<unk>']\n",
")\n",
"\n",
"vocab.set_default_index(vocab['<unk>'])"
],
"metadata": {
"id": "Fsvv3QJl7kWN"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_dataset = Bigrams(path_to_train, vocab_size)"
],
"metadata": {
"id": "UK73WsKnB8ZP"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"model.train()\n",
"step = 0\n",
"for x, y in data:\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 5000 == 0:\n",
" print(step, loss)\n",
" step += 1\n",
" loss.backward()\n",
" optimizer.step()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Q_Gz-stnqHPg",
"outputId": "4b6a3751-21da-48a9-afcb-802b21f7274b"
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" input = module(input)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 tensor(9.8217, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(5.3191, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(5.0986, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(5.2346, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"20000 tensor(5.4174, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"25000 tensor(5.1875, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"30000 tensor(5.1892, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"35000 tensor(5.0867, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"40000 tensor(5.1812, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"45000 tensor(5.1327, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"torch.save(model.state_dict(), path_to_model)"
],
"metadata": {
"id": "WTl82y44qHR_"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
"model.load_state_dict(torch.load(path_to_model))\n",
"model.eval()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jViNAbxIqHUr",
"outputId": "3c30bed7-8eaf-4e7b-eb89-89bdcf95abe7"
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"SimpleBigramNeuralLanguageModel(\n",
" (model): Sequential(\n",
" (0): Embedding(15000, 300)\n",
" (1): Linear(in_features=300, out_features=15000, bias=True)\n",
" (2): Softmax(dim=None)\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"create_outputs(DATA_DIR+'dev-0')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LkHFiPvoqHW8",
"outputId": "d4d12906-5f4c-4df6-a22e-95a4fbbd9850"
},
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Creating outputs in /content/gdrive/MyDrive/data_gralinski/dev-0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" input = module(input)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"create_outputs(DATA_DIR+'test-A')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QoiHDZ_ZqO8F",
"outputId": "0fc59359-ce83-4a43-d11b-8d43280df3a1"
},
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Creating outputs in /content/gdrive/MyDrive/data_gralinski/test-A\n"
]
}
]
}
]
}

dev-0/out.tsv Normal file (10519 lines; diff suppressed because it is too large)


@@ -1,10 +1,16 @@
-description: trigram model
+description: neural-network bigram
 tags:
   - neural-network
-  - trigram
+  - bigram
 params:
   epochs: 1
-  learning-rate: 0.0001
-  vocab_size: 40000
-  embed_size: 300
-  hidden_size: 256
+  learning-rate: 0.001
+unwanted-params:
+  - model-file
+  - vocab-file
+param-files:
+  - "*.yaml"
+  - config/*.yaml
+links:
+  - title: "repository"
+    url: "https://git.wmi.amu.edu.pl/s444354/challenging-america-word-gap-prediction.git"

test-A/out.tsv Normal file (7414 lines; diff suppressed because it is too large)