2 lines
8.4 KiB
Plaintext
2 lines
8.4 KiB
Plaintext
|
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":2401,"status":"ok","timestamp":1683752152796,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"iaKSJMGa3242"},"outputs":[],"source":["from collections import deque\n","import itertools\n","import lzma\n","\n","import numpy as np\n","import regex as re\n","import torch\n","from torch import nn\n","from torch.utils.data import IterableDataset, DataLoader\n","from torchtext.vocab import build_vocab_from_iterator"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683752152797,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"OddXJEKo3244"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","%cd /content/drive/MyDrive/america"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683752152797,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"KTGHWX-73244"},"outputs":[],"source":["def get_line(line: str) -> str:\n","    \"\"\"Join the prefix (column 6) and suffix (column 7) of one tab-separated record.\n","\n","    Escaped newline markers (a backslash followed by n) inside both columns are\n","    replaced with spaces so the result is a single line of text.\n","    Raises IndexError if the record has fewer than 8 columns.\n","    \"\"\"\n","    parts = line.split('\\t')\n","    prefix = parts[6].replace(r'\\n', ' ')\n","    suffix = parts[7].replace(r'\\n', ' ')\n","    return prefix + ' ' + suffix\n","\n","\n","def read_words(line):\n","    \"\"\"Yield the whitespace-separated words of one record's prefix and suffix.\"\"\"\n","    yield from get_line(line).split()\n","\n","\n","def get_words_from_file(path):\n","    \"\"\"Yield one word generator per line of the xz-compressed, UTF-8 file at `path`.\"\"\"\n","    with lzma.open(path, mode='rt', encoding='utf-8') as f:\n","        for line in f:\n","            yield read_words(line)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683752152800,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"K_JhskDE3247"},"outputs":[],"source":["class SimpleTrigramNeuralLanguageModel(nn.Module):\n","    \"\"\"Feed-forward model scoring every vocabulary word given two context word ids.\n","\n","    The two embeddings are concatenated, passed through one ReLU hidden layer,\n","    projected to `vocabulary_size` scores and normalised with softmax.\n","    \"\"\"\n","\n","    def __init__(self, vocabulary_size, embedding_size, hidden_size):\n","        super().__init__()\n","        self.embedding_size = embedding_size\n","        self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n","        self.lin1 = nn.Linear(2 * embedding_size, hidden_size)\n","        self.rel = nn.ReLU()\n","        self.lin2 = nn.Linear(hidden_size, vocabulary_size)\n","        # Explicit dim: nn.Softmax() without it is deprecated and warns. The input\n","        # reaching it is 2-D (N, vocabulary_size), for which the implicit choice\n","        # was already dim=1, so outputs are unchanged.\n","        # NOTE(review): this returns probabilities; if the training loop uses\n","        # nn.CrossEntropyLoss (which expects raw logits) softmax is applied twice.\n","        self.sm = nn.Softmax(dim=1)\n","\n","    def forward(self, x):\n","        \"\"\"Return softmax probabilities of shape (N, vocabulary_size).\n","\n","        Every 2 consecutive word ids in `x` (in flattened order) form one example,\n","        so N = x.numel() / 2; presumably `x` is (batch, 2) - TODO confirm against\n","        the training loop.\n","        \"\"\"\n","        x = self.embedding(x).view((-1, 2 * self.embedding_size))\n","        x = self.lin1(x)\n","        x = self.rel(x)\n","        x = self.lin2(x)\n","        return self.sm(x)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["def get_context(gen):\n","    \"\"\"Yield every run of three consecutive items of `gen` as a numpy array.\n","\n","    A fixed-size sliding window is used instead of copying `gen` into a list,\n","    so memory use does not grow with the length of the input stream.\n","    Nothing is yielded when `gen` has fewer than three items.\n","    \"\"\"\n","    window = deque(maxlen=3)\n","    for item in gen:\n","        window.append(item)  # maxlen=3 drops the oldest item automatically\n","        if len(window) == 3:\n","            yield np.asarray(list(window))"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683752152800,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"JG29cM8w3247"},"outputs":[],"source":["class Trigrams(IterableDataset):\n","    \"\"\"Iterable dataset of word-id trigrams streamed from an xz-compressed TSV file.\n","\n","    The vocabulary is built once in `__init__` (capped via\n","    `max_tokens=vocabulary_size`, with '<unk>' as the special and default index),\n","    so words outside the vocabulary map to the '<unk>' id.\n","    \"\"\"\n","\n","    def __init__(self, text_file, vocabulary_size):\n","        self.vocab = build_vocab_from_iterator(\n","            get_words_from_file(text_file),\n","            max_tokens=vocabulary_size,\n","            specials=['<unk>'])\n","        self.vocab.set_default_index(self.vocab['<unk>'])\n","        self.vocabulary_size = vocabulary_size\n","        self.text_file = text_file\n","\n","    def __iter__(self):\n","        # The file is re-read on every pass. All lines are chained into one token\n","        # stream, so trigrams can span the boundary between two records.\n","        return get_context(\n","            (self.vocab[t] for t in itertools.chain.from_iterable(get_words_from_file(self.text_file))))"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":11,"status":"ok","timestamp":1683752152799,"user":{"displayName":"Sebastian Wałęsa","userId":"16297137502741045838"},"user_tz":-120},"id":"LG8VwtS-3246"},"outputs":[],"source":["def train_model(lr):\n"," model = SimpleTrigramNeuralLanguageMod
|