ipynb file
This commit is contained in:
parent
5ebc5d3f12
commit
b7ee0fb834
534
RNN.ipynb
Normal file
534
RNN.ipynb
Normal file
@ -0,0 +1,534 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "ae9d73b0-9e7a-4259-aa04-2d3176864d40",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"from torch import nn, optim\n",
|
||||||
|
"from torch.utils.data import DataLoader\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from collections import Counter\n",
|
||||||
|
"import regex as re\n",
|
||||||
|
"import itertools\n",
|
||||||
|
"from itertools import islice"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"id": "ae22808c-8957-4d38-94bc-8f9cfc5f8b99",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CUDA Available: True\n",
|
||||||
|
"CUDA Device Name: NVIDIA GeForce RTX 3050\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
|
"cuda_available = torch.cuda.is_available()\n",
|
||||||
|
"print(f\"CUDA Available: {cuda_available}\")\n",
|
||||||
|
"if cuda_available:\n",
|
||||||
|
" print(f\"CUDA Device Name: {torch.cuda.get_device_name(0)}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 22,
|
||||||
|
"id": "41daea76-75a5-4098-b5ae-b770d3aa9e1b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"device = 'cuda'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "fa76fb6d-c5cf-4711-a65e-8ec004e3b6fc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/train/train.txt\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "e40859e9-88e4-4ff5-a78c-bb11b3822fd3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class Dataset(torch.utils.data.Dataset):\n",
|
||||||
|
" def __init__(\n",
|
||||||
|
" self,\n",
|
||||||
|
" sequence_length,\n",
|
||||||
|
" train_path,\n",
|
||||||
|
" max_vocab_size=20000\n",
|
||||||
|
" ):\n",
|
||||||
|
" self.sequence_length = sequence_length\n",
|
||||||
|
" self.train_path = train_path\n",
|
||||||
|
" self.max_vocab_size = max_vocab_size\n",
|
||||||
|
"\n",
|
||||||
|
" self.words = self.load()\n",
|
||||||
|
" self.uniq_words = self.get_uniq_words()\n",
|
||||||
|
"\n",
|
||||||
|
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
|
||||||
|
" self.index_to_word[len(self.index_to_word)] = '<UNK>'\n",
|
||||||
|
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
|
||||||
|
" self.word_to_index['<UNK>'] = len(self.word_to_index)\n",
|
||||||
|
"\n",
|
||||||
|
" self.words_indexes = [self.word_to_index.get(w, self.word_to_index['<UNK>']) for w in self.words]\n",
|
||||||
|
"\n",
|
||||||
|
" def load(self):\n",
|
||||||
|
" with open(self.train_path, 'r', encoding='utf-8') as f_in:\n",
|
||||||
|
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
|
||||||
|
" text = ' '.join(text).lower()\n",
|
||||||
|
" text = text.replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ').replace('\\\\\\\\t', ' ')\n",
|
||||||
|
" text = re.sub(r'\\n', ' ', text)\n",
|
||||||
|
" text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
|
||||||
|
" text = re.sub(r'\\s+', ' ', text)\n",
|
||||||
|
" text = re.sub(r'\\p{P}', '', text)\n",
|
||||||
|
" text = text.split(' ')\n",
|
||||||
|
" return text\n",
|
||||||
|
"\n",
|
||||||
|
" def get_uniq_words(self):\n",
|
||||||
|
" word_counts = Counter(self.words)\n",
|
||||||
|
" most_common_words = word_counts.most_common(self.max_vocab_size)\n",
|
||||||
|
" return [word for word, _ in most_common_words]\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.words_indexes) - self.sequence_length\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, index):\n",
|
||||||
|
" # Get the sequence\n",
|
||||||
|
" sequence = self.words_indexes[index:index+self.sequence_length]\n",
|
||||||
|
" # Split the sequence into x and y\n",
|
||||||
|
" x = sequence[:2] + sequence[-2:]\n",
|
||||||
|
" y = sequence[len(sequence) // 2]\n",
|
||||||
|
" return torch.tensor(x), torch.tensor(y)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "bf0efaba-86a2-4368-a31d-de7d08a759a0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_dataset = Dataset(5, train_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 51,
|
||||||
|
"id": "7aa7bd72-5978-484e-b541-36f737f22b0d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(tensor([ 14, 110, 3, 28]), tensor(208))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 51,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"train_dataset[420]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 53,
|
||||||
|
"id": "2a13298c-e0dd-4181-9093-7cec414b5b79",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['at', 'last', 'to', 'tho']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 53,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"[train_dataset.index_to_word[x] for x in [ 14, 110, 3, 28]]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 54,
|
||||||
|
"id": "192c4d6d-3fc1-4687-9ce4-b1a8cbea7d82",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['come']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 54,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"[train_dataset.index_to_word[208]]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 52,
|
||||||
|
"id": "3f0cd5b3-3937-4ad8-a9f8-766d27ad9d70",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(tensor([ 218, 104, 8207, 3121]), tensor(20000))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 52,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"train_dataset[21237]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "b1302c90-d77e-49e4-8b9d-9a8aeca675b0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"\n",
|
||||||
|
"class Model(nn.Module):\n",
|
||||||
|
" def __init__(self, vocab_size, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):\n",
|
||||||
|
" super(Model, self).__init__()\n",
|
||||||
|
" self.lstm_size = lstm_size\n",
|
||||||
|
" self.embedding_dim = embedding_dim\n",
|
||||||
|
" self.num_layers = num_layers\n",
|
||||||
|
" self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"\n",
|
||||||
|
" self.embedding = nn.Embedding(\n",
|
||||||
|
" num_embeddings=vocab_size,\n",
|
||||||
|
" embedding_dim=self.embedding_dim,\n",
|
||||||
|
" )\n",
|
||||||
|
" self.lstm = nn.LSTM(\n",
|
||||||
|
" input_size=self.embedding_dim,\n",
|
||||||
|
" hidden_size=self.lstm_size,\n",
|
||||||
|
" num_layers=self.num_layers,\n",
|
||||||
|
" dropout=dropout,\n",
|
||||||
|
" )\n",
|
||||||
|
" self.fc1 = nn.Linear(self.lstm_size, 256) \n",
|
||||||
|
" self.fc2 = nn.Linear(256, vocab_size)\n",
|
||||||
|
" self.softmax = nn.Softmax(dim=1)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, x, prev_state=None):\n",
|
||||||
|
" x = x.to(self.device)\n",
|
||||||
|
" embed = self.embedding(x)\n",
|
||||||
|
" embed = embed.transpose(0, 1)\n",
|
||||||
|
" \n",
|
||||||
|
" if prev_state is None:\n",
|
||||||
|
" prev_state = self.init_state(x.size(0))\n",
|
||||||
|
" \n",
|
||||||
|
" output, state = self.lstm(embed, prev_state)\n",
|
||||||
|
" logits = self.fc1(output[-1])\n",
|
||||||
|
" logits = self.fc2(logits)\n",
|
||||||
|
" probabilities = self.softmax(logits)\n",
|
||||||
|
" return probabilities\n",
|
||||||
|
"\n",
|
||||||
|
" def init_state(self, batch_size):\n",
|
||||||
|
" return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device),\n",
|
||||||
|
" torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 105,
|
||||||
|
"id": "93a29618-3283-4ad5-881f-48c84839ceeb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def train(dataset, model, max_epochs, batch_size):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
"\n",
|
||||||
|
" dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)\n",
|
||||||
|
" criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
" optimizer = optim.Adam(model.parameters())\n",
|
||||||
|
"\n",
|
||||||
|
" for epoch in range(max_epochs):\n",
|
||||||
|
" for batch, (x, y) in enumerate(dataloader):\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" x = x.to(device, non_blocking=True)\n",
|
||||||
|
" y = y.to(device, non_blocking=True)\n",
|
||||||
|
"\n",
|
||||||
|
" y_pred = model(x)\n",
|
||||||
|
" loss = criterion(torch.log(y_pred), y)\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" if batch % 500 == 0:\n",
|
||||||
|
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 106,
|
||||||
|
"id": "2315e67d-a315-44b5-bddf-5ab4bed1e727",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{'epoch': 0, 'update in batch': 0, '/': 16679, 'loss': 9.917818069458008}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 500, '/': 16679, 'loss': 6.078440189361572}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 1000, '/': 16679, 'loss': 5.651369571685791}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 1500, '/': 16679, 'loss': 5.4341654777526855}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 2000, '/': 16679, 'loss': 5.383695602416992}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 2500, '/': 16679, 'loss': 5.225739479064941}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 3000, '/': 16679, 'loss': 5.282474517822266}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 3500, '/': 16679, 'loss': 5.092397689819336}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 4000, '/': 16679, 'loss': 4.940906047821045}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 4500, '/': 16679, 'loss': 4.908115863800049}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 5000, '/': 16679, 'loss': 5.092423439025879}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 5500, '/': 16679, 'loss': 4.979565620422363}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 6000, '/': 16679, 'loss': 4.8268022537231445}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 6500, '/': 16679, 'loss': 4.7172017097473145}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 7000, '/': 16679, 'loss': 4.781315326690674}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 7500, '/': 16679, 'loss': 5.0033040046691895}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 8000, '/': 16679, 'loss': 4.663774013519287}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 8500, '/': 16679, 'loss': 4.710158348083496}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 9000, '/': 16679, 'loss': 4.817586898803711}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 9500, '/': 16679, 'loss': 4.655371189117432}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 10000, '/': 16679, 'loss': 4.679412841796875}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 10500, '/': 16679, 'loss': 4.544621467590332}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 11000, '/': 16679, 'loss': 4.816493511199951}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 11500, '/': 16679, 'loss': 4.627770900726318}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 12000, '/': 16679, 'loss': 4.525866985321045}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 12500, '/': 16679, 'loss': 4.739295959472656}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 13000, '/': 16679, 'loss': 4.6095709800720215}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 13500, '/': 16679, 'loss': 4.7243266105651855}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 14000, '/': 16679, 'loss': 4.557321071624756}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 14500, '/': 16679, 'loss': 4.830319404602051}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 15000, '/': 16679, 'loss': 4.536618709564209}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 15500, '/': 16679, 'loss': 4.605734825134277}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 16000, '/': 16679, 'loss': 4.605676651000977}\n",
|
||||||
|
"{'epoch': 0, 'update in batch': 16500, '/': 16679, 'loss': 4.614283084869385}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model = Model(vocab_size = len(train_dataset.uniq_words) + 1).to(device)\n",
|
||||||
|
"train(train_dataset, model, 1, 8192)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 107,
|
||||||
|
"id": "8acf3dc2-f3fe-4a2a-bdf9-82a18acb1bd1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"torch.save(model.state_dict(), 'model.pth')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"id": "5e60d5b3-019d-4d63-b794-59e1356bc45e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = Model(20001).to(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"id": "7e55b0b2-cdda-4c37-8979-0400f9973461",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<All keys matched successfully>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model.load_state_dict(torch.load('model.pth'))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "e842b192-8e10-438c-b8ee-781a4a7a875c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def clean(text):\n",
|
||||||
|
" text = text.replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ').replace('\\\\\\\\t', ' ')\n",
|
||||||
|
" text = re.sub(r'\\n', ' ', text)\n",
|
||||||
|
" text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
|
||||||
|
" text = re.sub(r'\\s+', ' ', text)\n",
|
||||||
|
" text = re.sub(r'\\p{P}', '', text)\n",
|
||||||
|
" text = text.strip()\n",
|
||||||
|
" return text"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "f20f8fdc-194e-415a-8343-6f590abe1166",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_words(words, model, dataset, n=20):\n",
|
||||||
|
" ixs = [dataset.word_to_index.get(word, dataset.word_to_index['<UNK>']) for word in words]\n",
|
||||||
|
" ixs = torch.tensor(ixs).unsqueeze(0).to(model.device)\n",
|
||||||
|
"\n",
|
||||||
|
" out = model(ixs)\n",
|
||||||
|
" top = torch.topk(out[0], n)\n",
|
||||||
|
" top_indices = top.indices.tolist()\n",
|
||||||
|
" top_probs = top.values.tolist()\n",
|
||||||
|
" top_words = [dataset.index_to_word[idx] for idx in top_indices]\n",
|
||||||
|
" return list(zip(top_words, top_probs))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "22ebafa5-d21f-4208-9aad-a4c4d90134c4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def f_out(left, right, model, dataset):\n",
|
||||||
|
" left = clean(left)\n",
|
||||||
|
" right = clean(right)\n",
|
||||||
|
" words = left.split(' ')[-2:] + right.split(' ')[:2]\n",
|
||||||
|
" words = get_words(words, model, dataset)\n",
|
||||||
|
"\n",
|
||||||
|
" probs_sum = 0\n",
|
||||||
|
" output = ''\n",
|
||||||
|
" for word, prob in words:\n",
|
||||||
|
" if word == \"<UNK>\":\n",
|
||||||
|
" continue\n",
|
||||||
|
" probs_sum += prob\n",
|
||||||
|
" output += f\"{word}:{prob} \"\n",
|
||||||
|
" output += f\":{1-probs_sum}\"\n",
|
||||||
|
"\n",
|
||||||
|
" return output"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "1dc64cee-a9a5-44d4-92da-82e1b7f8fdc4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def create_out(input_path, model, dataset, output_path):\n",
|
||||||
|
" lines = []\n",
|
||||||
|
" with open(input_path, encoding='utf-8') as f:\n",
|
||||||
|
" for line in f:\n",
|
||||||
|
" columns = line.split('\\t')\n",
|
||||||
|
" left = columns[6]\n",
|
||||||
|
" right = columns[7]\n",
|
||||||
|
" lines.append((left, right))\n",
|
||||||
|
"\n",
|
||||||
|
" with open(output_path, 'w', encoding='utf-8') as output_file:\n",
|
||||||
|
" for left, right in lines:\n",
|
||||||
|
" result = f_out(left, right, model, dataset)\n",
|
||||||
|
" output_file.write(result + '\\n')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"id": "348a77c1-8ff1-40bb-a243-3b702c119c2c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dev_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/in.tsv\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"id": "9377c725-3309-4590-89d2-444057ae2b80",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"create_out(dev_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/out.tsv')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"id": "50f47d4a-762f-48b2-9c19-f385d9822886",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"test_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/in.tsv\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"id": "18aa1059-88ed-4c32-af88-80a4de4be6c9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"create_out(test_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/out.tsv')"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user