diff --git a/RNN.ipynb b/RNN.ipynb new file mode 100644 index 0000000..043606d --- /dev/null +++ b/RNN.ipynb @@ -0,0 +1,534 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "ae9d73b0-9e7a-4259-aa04-2d3176864d40", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn, optim\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from collections import Counter\n", + "import regex as re\n", + "import itertools\n", + "from itertools import islice" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "ae22808c-8957-4d38-94bc-8f9cfc5f8b99", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA Available: True\n", + "CUDA Device Name: NVIDIA GeForce RTX 3050\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "cuda_available = torch.cuda.is_available()\n", + "print(f\"CUDA Available: {cuda_available}\")\n", + "if cuda_available:\n", + " print(f\"CUDA Device Name: {torch.cuda.get_device_name(0)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "41daea76-75a5-4098-b5ae-b770d3aa9e1b", + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cuda'" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fa76fb6d-c5cf-4711-a65e-8ec004e3b6fc", + "metadata": {}, + "outputs": [], + "source": [ + "train_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/train/train.txt\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e40859e9-88e4-4ff5-a78c-bb11b3822fd3", + "metadata": {}, + "outputs": [], + "source": [ + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(\n", + " self,\n", + " sequence_length,\n", + " train_path,\n", + " max_vocab_size=20000\n", + " ):\n", + " self.sequence_length = sequence_length\n", + " self.train_path = train_path\n", + " self.max_vocab_size = max_vocab_size\n", + "\n", + " self.words = self.load()\n", + " self.uniq_words = self.get_uniq_words()\n", + "\n", + " self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n", + " self.index_to_word[len(self.index_to_word)] = ''\n", + " self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n", + " self.word_to_index[''] = len(self.word_to_index)\n", + "\n", + " self.words_indexes = [self.word_to_index.get(w, self.word_to_index['']) for w in self.words]\n", + "\n", + " def load(self):\n", + " with open(self.train_path, 'r', encoding='utf-8') as f_in:\n", + " text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n", + " text = ' '.join(text).lower()\n", + " text = text.replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ').replace('\\\\\\\\t', ' ')\n", + " text = re.sub(r'\\n', ' ', text)\n", + " text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n", + " text = re.sub(r'\\s+', ' ', text)\n", + " text = re.sub(r'\\p{P}', '', text)\n", + " text = text.split(' ')\n", + " return text\n", + "\n", + " def get_uniq_words(self):\n", + " word_counts = Counter(self.words)\n", + " most_common_words = word_counts.most_common(self.max_vocab_size)\n", + " return [word for word, _ in most_common_words]\n", + "\n", + " def __len__(self):\n", + " return len(self.words_indexes) - self.sequence_length\n", + "\n", + " def __getitem__(self, index):\n", + " # Get the sequence\n", + " sequence = self.words_indexes[index:index+self.sequence_length]\n", + " # Split the sequence into x and y\n", + " x = 
sequence[:2] + sequence[-2:]\n", + " y = sequence[len(sequence) // 2]\n", + " return torch.tensor(x), torch.tensor(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "bf0efaba-86a2-4368-a31d-de7d08a759a0", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = Dataset(5, train_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "7aa7bd72-5978-484e-b541-36f737f22b0d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([ 14, 110, 3, 28]), tensor(208))" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[420]" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "2a13298c-e0dd-4181-9093-7cec414b5b79", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['at', 'last', 'to', 'tho']" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[train_dataset.index_to_word[x] for x in [ 14, 110, 3, 28]]" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "192c4d6d-3fc1-4687-9ce4-b1a8cbea7d82", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['come']" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[train_dataset.index_to_word[208]]" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "3f0cd5b3-3937-4ad8-a9f8-766d27ad9d70", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([ 218, 104, 8207, 3121]), tensor(20000))" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[21237]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b1302c90-d77e-49e4-8b9d-9a8aeca675b0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class Model(nn.Module):\n", + " def __init__(self, vocab_size, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):\n", + " super(Model, self).__init__()\n", + " self.lstm_size = lstm_size\n", + " self.embedding_dim = embedding_dim\n", + " self.num_layers = num_layers\n", + " self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + " self.embedding = nn.Embedding(\n", + " num_embeddings=vocab_size,\n", + " embedding_dim=self.embedding_dim,\n", + " )\n", + " self.lstm = nn.LSTM(\n", + " input_size=self.embedding_dim,\n", + " hidden_size=self.lstm_size,\n", + " num_layers=self.num_layers,\n", + " dropout=dropout,\n", + " )\n", + " self.fc1 = nn.Linear(self.lstm_size, 256) \n", + " self.fc2 = nn.Linear(256, vocab_size)\n", + " self.softmax = nn.Softmax(dim=1)\n", + " \n", + " def forward(self, x, prev_state=None):\n", + " x = x.to(self.device)\n", + " embed = self.embedding(x)\n", + " embed = embed.transpose(0, 1)\n", + " \n", + " if prev_state is None:\n", + " prev_state = self.init_state(x.size(0))\n", + " \n", + " output, state = self.lstm(embed, prev_state)\n", + " logits = self.fc1(output[-1])\n", + " logits = self.fc2(logits)\n", + " probabilities = self.softmax(logits)\n", + " return probabilities\n", + "\n", + " def init_state(self, batch_size):\n", + " return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device),\n", + " torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "id": 
"93a29618-3283-4ad5-881f-48c84839ceeb", + "metadata": {}, + "outputs": [], + "source": [ + "def train(dataset, model, max_epochs, batch_size):\n", + " model.train()\n", + "\n", + " dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)\n", + " criterion = nn.CrossEntropyLoss()\n", + " optimizer = optim.Adam(model.parameters())\n", + "\n", + " for epoch in range(max_epochs):\n", + " for batch, (x, y) in enumerate(dataloader):\n", + " optimizer.zero_grad()\n", + " x = x.to(device, non_blocking=True)\n", + " y = y.to(device, non_blocking=True)\n", + "\n", + " y_pred = model(x)\n", + " loss = criterion(torch.log(y_pred), y)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if batch % 500 == 0:\n", + " print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "2315e67d-a315-44b5-bddf-5ab4bed1e727", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 0, '/': 16679, 'loss': 9.917818069458008}\n", + "{'epoch': 0, 'update in batch': 500, '/': 16679, 'loss': 6.078440189361572}\n", + "{'epoch': 0, 'update in batch': 1000, '/': 16679, 'loss': 5.651369571685791}\n", + "{'epoch': 0, 'update in batch': 1500, '/': 16679, 'loss': 5.4341654777526855}\n", + "{'epoch': 0, 'update in batch': 2000, '/': 16679, 'loss': 5.383695602416992}\n", + "{'epoch': 0, 'update in batch': 2500, '/': 16679, 'loss': 5.225739479064941}\n", + "{'epoch': 0, 'update in batch': 3000, '/': 16679, 'loss': 5.282474517822266}\n", + "{'epoch': 0, 'update in batch': 3500, '/': 16679, 'loss': 5.092397689819336}\n", + "{'epoch': 0, 'update in batch': 4000, '/': 16679, 'loss': 4.940906047821045}\n", + "{'epoch': 0, 'update in batch': 4500, '/': 16679, 'loss': 4.908115863800049}\n", + "{'epoch': 0, 'update in batch': 5000, '/': 16679, 'loss': 5.092423439025879}\n", + "{'epoch': 0, 'update in batch': 5500, '/': 16679, 'loss': 4.979565620422363}\n", + "{'epoch': 0, 'update in batch': 6000, '/': 16679, 'loss': 4.8268022537231445}\n", + "{'epoch': 0, 'update in batch': 6500, '/': 16679, 'loss': 4.7172017097473145}\n", + "{'epoch': 0, 'update in batch': 7000, '/': 16679, 'loss': 4.781315326690674}\n", + "{'epoch': 0, 'update in batch': 7500, '/': 16679, 'loss': 5.0033040046691895}\n", + "{'epoch': 0, 'update in batch': 8000, '/': 16679, 'loss': 4.663774013519287}\n", + "{'epoch': 0, 'update in batch': 8500, '/': 16679, 'loss': 4.710158348083496}\n", + "{'epoch': 0, 'update in batch': 9000, '/': 16679, 'loss': 4.817586898803711}\n", + "{'epoch': 0, 'update in batch': 9500, '/': 16679, 'loss': 4.655371189117432}\n", + "{'epoch': 0, 'update in batch': 10000, '/': 16679, 'loss': 4.679412841796875}\n", + "{'epoch': 0, 'update in batch': 10500, '/': 16679, 'loss': 4.544621467590332}\n", + "{'epoch': 0, 'update in batch': 11000, '/': 16679, 'loss': 4.816493511199951}\n", + "{'epoch': 0, 'update in batch': 11500, '/': 16679, 'loss': 4.627770900726318}\n", + "{'epoch': 0, 'update in batch': 12000, '/': 16679, 'loss': 4.525866985321045}\n", + "{'epoch': 0, 'update in batch': 12500, '/': 16679, 'loss': 4.739295959472656}\n", + "{'epoch': 0, 'update in batch': 13000, '/': 16679, 'loss': 4.6095709800720215}\n", + "{'epoch': 0, 'update in batch': 13500, '/': 16679, 'loss': 4.7243266105651855}\n", + "{'epoch': 0, 'update in batch': 14000, '/': 16679, 'loss': 4.557321071624756}\n", + "{'epoch': 0, 'update in batch': 14500, '/': 16679, 
+ { + "cell_type": "code", + "execution_count": 107, + "id": "8acf3dc2-f3fe-4a2a-bdf9-82a18acb1bd1", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), 'model.pth')" + ] + },
+ { + "cell_type": "code", + "execution_count": 23, + "id": "5e60d5b3-019d-4d63-b794-59e1356bc45e", + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(20001).to(device)" + ] + },
+ { + "cell_type": "code", + "execution_count": 24, + "id": "7e55b0b2-cdda-4c37-8979-0400f9973461", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<All keys matched successfully>" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.load_state_dict(torch.load('model.pth'))" + ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "id": "e842b192-8e10-438c-b8ee-781a4a7a875c", + "metadata": {}, + "outputs": [], + "source": [ + "def clean(text):\n", + " text = text.replace('-\\\\n', '').replace('\\\\n', ' ').replace('\\\\t', ' ')\n", + " text = re.sub(r'\\n', ' ', text)\n", + " text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n", + " text = re.sub(r'\\s+', ' ', text)\n", + " text = re.sub(r'\\p{P}', '', text)\n", + " text = text.strip()\n", + " return text" + ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "id": "f20f8fdc-194e-415a-8343-6f590abe1166", + "metadata": {}, + "outputs": [], + "source": [ + "def get_words(words, model, dataset, n=20):\n", + " ixs = [dataset.word_to_index.get(word, dataset.word_to_index['']) for word in words]\n", + " ixs = torch.tensor(ixs).unsqueeze(0).to(model.device)\n", + "\n", + " # switch off dropout and gradient tracking for prediction\n", + " model.eval()\n", + " with torch.no_grad():\n", + " out = model(ixs)\n", + " top = torch.topk(out[0], n)\n", + " top_indices = top.indices.tolist()\n", + " top_probs = top.values.tolist()\n", + " top_words = [dataset.index_to_word[idx] for idx in top_indices]\n", + " return list(zip(top_words, top_probs))" + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "id": "22ebafa5-d21f-4208-9aad-a4c4d90134c4", + "metadata": {}, + "outputs": [], + "source": [ + "def f_out(left, right, model, dataset):\n", + " left = clean(left)\n", + " right = clean(right)\n", + " words = left.split(' ')[-2:] + right.split(' ')[:2]\n", + " words = get_words(words, model, dataset)\n", + "\n", + " probs_sum = 0\n", + " output = ''\n", + " for word, prob in words:\n", + " if word == \"\":\n", + " continue\n", + " probs_sum += prob\n", + " output += f\"{word}:{prob} \"\n", + " output += f\":{1-probs_sum}\"\n", + "\n", + " return output" + ] + },
+ { + "cell_type": "code", + "execution_count": 6, + "id": "1dc64cee-a9a5-44d4-92da-82e1b7f8fdc4", + "metadata": {}, + "outputs": [], + "source": [ + "def create_out(input_path, model, dataset, output_path):\n", + " lines = []\n", + " with open(input_path, encoding='utf-8') as f:\n", + " for line in f:\n", + " columns = line.split('\\t')\n", + " left = columns[6]\n", + " right = columns[7]\n", + " lines.append((left, right))\n", + "\n", + " with open(output_path, 'w', encoding='utf-8') as output_file:\n", + " for left, right in lines:\n", + " result = f_out(left, right, model, dataset)\n", + " output_file.write(result + '\\n')" + ] + },
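+ { + "cell_type": "markdown", + "id": "f-out-format-note", + "metadata": {}, + "source": [ + "A quick illustrative call of `f_out` on a made-up pair of contexts (the two strings below are hypothetical, not taken from the data), just to preview the word:prob ... :rest line format that `create_out` writes for every row of in.tsv." + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "f-out-format-demo", + "metadata": {}, + "outputs": [], + "source": [ + "# hypothetical contexts, only to preview the output line format\n", + "example_left = \"he said that he would\"\n", + "example_right = \"to the meeting tomorrow\"\n", + "print(f_out(example_left, example_right, model, train_dataset))" + ] + },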
+ { + "cell_type": "code", + "execution_count": 19, + "id": "348a77c1-8ff1-40bb-a243-3b702c119c2c", + "metadata": {}, + "outputs": [], + "source": [ + "dev_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/in.tsv\"" + ] + },
+ { + "cell_type": "code", + "execution_count": 25, + "id": "9377c725-3309-4590-89d2-444057ae2b80", + "metadata": {}, + "outputs": [], + "source": [ + "create_out(dev_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/out.tsv')" + ] + },
+ { + "cell_type": "code", + "execution_count": 26, + "id": "50f47d4a-762f-48b2-9c19-f385d9822886", + "metadata": {}, + "outputs": [], + "source": [ + "test_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/in.tsv\"" + ] + },
+ { + "cell_type": "code", + "execution_count": 27, + "id": "18aa1059-88ed-4c32-af88-80a4de4be6c9", + "metadata": {}, + "outputs": [], + "source": [ + "create_out(test_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/out.tsv')" + ] + }
+ ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}