2 lines
8.4 KiB
Plaintext
2 lines
8.4 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple trigram neural language model\n",
    "\n",
    "Trains a feed-forward model that predicts a word from the two words preceding it (data: `train/in.tsv.xz`), then writes top-k next-word predictions, computed from the last two words of each left context, for `dev-0` and `test-A`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iaKSJMGa3242"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import lzma\n",
    "import numpy as np\n",
    "import regex as re\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "from torchtext.vocab import build_vocab_from_iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OddXJEKo3244"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "%cd /content/drive/MyDrive/america"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L2A391Dh3247"
   },
   "outputs": [],
   "source": [
    "vocab_size = 15000\n",
    "embed_size = 200\n",
    "hidden_size = 100\n",
    "batch_size = 3000\n",
    "learning_rate = 0.0001\n",
    "seed = 42\n",
    "train_path = 'train/in.tsv.xz'\n",
    "model_path = 'model1.bin'\n",
    "\n",
    "# Seed weight initialisation so training runs are reproducible.\n",
    "torch.manual_seed(seed)\n",
    "# Fall back to CPU so the notebook also runs on a runtime without a GPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KTGHWX-73244"
   },
   "outputs": [],
   "source": [
    "def get_line(line: str) -> str:\n",
    "    \"\"\"Return the left (column 6) and right (column 7) contexts of a TSV record joined by a space.\n",
    "\n",
    "    Literal backslash-n sequences stored in the data are replaced by spaces.\n",
    "    NOTE(review): presumably the word that sat between the two contexts is not in\n",
    "    in.tsv, so one trigram per record straddles that gap - confirm against the task.\n",
    "    \"\"\"\n",
    "    parts = line.split('\\t')\n",
    "    prefix = parts[6].replace(r'\\n', ' ')\n",
    "    suffix = parts[7].replace(r'\\n', ' ')\n",
    "    return prefix + ' ' + suffix\n",
    "\n",
    "\n",
    "def read_words(line):\n",
    "    \"\"\"Yield the whitespace-separated words of one TSV record.\"\"\"\n",
    "    yield from get_line(line).split()\n",
    "\n",
    "\n",
    "def get_words_from_file(path):\n",
    "    \"\"\"Yield, for each line of the xz-compressed TSV at ``path``, a generator of its words.\"\"\"\n",
    "    with lzma.open(path, mode='rt', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            yield read_words(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K_JhskDE3247"
   },
   "outputs": [],
   "source": [
    "class SimpleTrigramNeuralLanguageModel(nn.Module):\n",
    "    \"\"\"Feed-forward trigram LM: embeds two context words and scores every vocabulary word.\n",
    "\n",
    "    ``forward`` returns softmax probabilities over the vocabulary; ``logits`` returns\n",
    "    the unnormalised scores, which training feeds to a numerically stable loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary_size, embedding_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.embedding_size = embedding_size\n",
    "        self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n",
    "        self.lin1 = nn.Linear(2 * embedding_size, hidden_size)\n",
    "        self.rel = nn.ReLU()\n",
    "        self.lin2 = nn.Linear(hidden_size, vocabulary_size)\n",
    "        # Explicit dim: nn.Softmax without dim is deprecated (it picked dim=1 for 2-D input).\n",
    "        self.sm = nn.Softmax(dim=1)\n",
    "\n",
    "    def logits(self, x):\n",
    "        \"\"\"Return unnormalised scores of shape (n, vocabulary_size).\n",
    "\n",
    "        ``x`` holds word indices; it is regrouped into pairs of context words, so it\n",
    "        must contain an even number of indices (e.g. shape (n, 2) or (2,)).\n",
    "        \"\"\"\n",
    "        x = self.embedding(x).view((-1, 2 * self.embedding_size))\n",
    "        x = self.lin1(x)\n",
    "        x = self.rel(x)\n",
    "        return self.lin2(x)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return next-word probabilities of shape (n, vocabulary_size).\"\"\"\n",
    "        return self.sm(self.logits(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_context(gen):\n",
    "    \"\"\"Yield every run of three consecutive items of ``gen`` as a NumPy array.\n",
    "\n",
    "    ``gen`` is consumed lazily (the three ``tee`` copies advance in lockstep), so the\n",
    "    token stream is never materialised in memory. Fewer than three items yield nothing.\n",
    "    \"\"\"\n",
    "    first, second, third = itertools.tee(gen, 3)\n",
    "    next(second, None)\n",
    "    next(third, None)\n",
    "    next(third, None)\n",
    "    for window in zip(first, second, third):\n",
    "        yield np.asarray(window)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JG29cM8w3247"
   },
   "outputs": [],
   "source": [
    "class Trigrams(IterableDataset):\n",
    "    \"\"\"Streams (w1, w2, w3) vocabulary-index triples from an xz-compressed TSV file.\n",
    "\n",
    "    Windows are built per record, so no trigram joins the end of one line to the\n",
    "    start of the next. Pass an existing ``vocab`` to reuse it; otherwise one is\n",
    "    built here, which costs a full extra pass over ``text_file``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text_file, vocabulary_size, vocab=None):\n",
    "        if vocab is None:\n",
    "            vocab = build_vocab_from_iterator(\n",
    "                get_words_from_file(text_file),\n",
    "                max_tokens=vocabulary_size,\n",
    "                specials=['<unk>'])\n",
    "            vocab.set_default_index(vocab['<unk>'])\n",
    "        self.vocab = vocab\n",
    "        self.vocabulary_size = vocabulary_size\n",
    "        self.text_file = text_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        return itertools.chain.from_iterable(\n",
    "            get_context(self.vocab[t] for t in words)\n",
    "            for words in get_words_from_file(self.text_file))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LG8VwtS-3246"
   },
   "outputs": [],
   "source": [
    "def train_model(lr, dataset=None, epochs=1):\n",
    "    \"\"\"Train a fresh model, save its weights to ``model_path`` and return it.\n",
    "\n",
    "    ``dataset`` defaults to the notebook-level ``train_dataset``. Model sizes,\n",
    "    ``batch_size`` and ``device`` come from the configuration cell.\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        dataset = train_dataset\n",
    "    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
    "    data = DataLoader(dataset, batch_size=batch_size)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    # Same value as NLLLoss(log(softmax(logits))), but computed with log-sum-exp,\n",
    "    # so an underflowed probability cannot turn into log(0) = -inf / NaN.\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    model.train()\n",
    "    step = 0\n",
    "    for _ in range(epochs):\n",
    "        for batch in data:\n",
    "            x = batch[:, :2].to(device)\n",
    "            y = batch[:, 2].to(device)\n",
    "            optimizer.zero_grad()\n",
    "            loss = criterion(model.logits(x), y)\n",
    "            if step % 100 == 0:\n",
    "                print(step, loss.item())\n",
    "            step += 1\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 10)\n",
    "            optimizer.step()\n",
    "\n",
    "    torch.save(model.state_dict(), model_path)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prediction(words, model, top, vocabulary=None) -> str:\n",
    "    \"\"\"Return the ``top`` most likely next words after ``words`` as space-separated ``word:prob`` pairs.\n",
    "\n",
    "    Only the last two words are used; shorter contexts are left-padded with ``<unk>``.\n",
    "    The final pair has an empty word: it carries the probability of ``<unk>`` when\n",
    "    ``<unk>`` is among the top guesses, otherwise that of the lowest-ranked guess.\n",
    "    ``vocabulary`` defaults to the notebook-level ``vocab``.\n",
    "    \"\"\"\n",
    "    if vocabulary is None:\n",
    "        vocabulary = vocab\n",
    "    context = ['<unk>', '<unk>'] + list(words)\n",
    "    ixs = torch.tensor(vocabulary.lookup_indices(context[-2:])).to(device)\n",
    "    # Inference only: skip building an autograd graph for every call.\n",
    "    with torch.no_grad():\n",
    "        out = model(ixs)\n",
    "    top_values, top_indices = torch.topk(out[0], top)\n",
    "    top_probs = top_values.tolist()\n",
    "    top_words = vocabulary.lookup_tokens(top_indices.tolist())\n",
    "    if '<unk>' in top_words:\n",
    "        unk_index = top_words.index('<unk>')\n",
    "        unk_prob = top_probs.pop(unk_index)\n",
    "        top_words.pop(unk_index)\n",
    "        top_words.append('')\n",
    "        top_probs.append(unk_prob)\n",
    "    else:\n",
    "        top_words[-1] = ''\n",
    "    return ' '.join(f'{word}:{prob}' for word, prob in zip(top_words, top_probs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_outputs(folder_name, model, top):\n",
    "    \"\"\"Write one prediction line per record of ``{folder_name}/in.tsv.xz``.\n",
    "\n",
    "    Each line is ``prediction`` applied to the last two words of the record's left\n",
    "    context (column 6); output goes to ``{folder_name}/out-top={top}.tsv``.\n",
    "    \"\"\"\n",
    "    input_file_path = f'{folder_name}/in.tsv.xz'\n",
    "    output_file_path = f'{folder_name}/out-top={top}.tsv'\n",
    "    with lzma.open(input_file_path, mode='rt', encoding='utf-8') as input_file:\n",
    "        with open(output_file_path, 'w', encoding='utf-8', newline='\\n') as output_file:\n",
    "            for line in input_file:\n",
    "                separated = line.split('\\t')\n",
    "                prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
    "                output_line = prediction(prefix, model, top)\n",
    "                output_file.write(output_line + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run\n",
    "\n",
    "The vocabulary is built once here and shared by the training dataset and by `prediction`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PFKOeLem3248"
   },
   "outputs": [],
   "source": [
    "vocab = build_vocab_from_iterator(\n",
    "    get_words_from_file(train_path),\n",
    "    max_tokens=vocab_size,\n",
    "    specials=['<unk>']\n",
    ")\n",
    "\n",
    "vocab.set_default_index(vocab['<unk>'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8WpWyk9a3249"
   },
   "outputs": [],
   "source": [
    "train_dataset = Trigrams(train_path, vocab_size, vocab=vocab)\n",
    "trained_model = train_model(lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "for top in [100, 200, 300]:\n",
    "    save_outputs('dev-0', model, top)\n",
    "    save_outputs('test-A', model, top)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
|