{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ae9d73b0-9e7a-4259-aa04-2d3176864d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import regex as re\n",
    "import itertools\n",
    "from itertools import islice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ae22808c-8957-4d38-94bc-8f9cfc5f8b99",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA Available: True\n",
      "CUDA Device Name: NVIDIA GeForce RTX 3050\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "cuda_available = torch.cuda.is_available()\n",
    "print(f\"CUDA Available: {cuda_available}\")\n",
    "if cuda_available:\n",
    "    print(f\"CUDA Device Name: {torch.cuda.get_device_name(0)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "41daea76-75a5-4098-b5ae-b770d3aa9e1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fall back to CPU when no GPU is available\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fa76fb6d-c5cf-4711-a65e-8ec004e3b6fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/train/train.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e40859e9-88e4-4ff5-a78c-bb11b3822fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        sequence_length,\n",
    "        train_path,\n",
    "        max_vocab_size=20000\n",
    "    ):\n",
    "        self.sequence_length = sequence_length\n",
    "        self.train_path = train_path\n",
    "        self.max_vocab_size = max_vocab_size\n",
    "\n",
    "        self.words = self.load()\n",
    "        self.uniq_words = self.get_uniq_words()\n",
    "\n",
    "        # vocabulary of the most frequent words plus an out-of-vocabulary marker\n",
    "        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
    "        self.index_to_word[len(self.index_to_word)] = '<UNK>'\n",
    "        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
    "        self.word_to_index['<UNK>'] = len(self.word_to_index)\n",
    "\n",
    "        self.words_indexes = [self.word_to_index.get(w, self.word_to_index['<UNK>']) for w in self.words]\n",
    "\n",
    "    def load(self):\n",
    "        with open(self.train_path, 'r', encoding='utf-8') as f_in:\n",
    "            text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
    "        text = ' '.join(text).lower()\n",
    "        # remove escaped line-break/tab markers and punctuation, collapse whitespace\n",
    "        text = text.replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ').replace('\\\\\\\\t', ' ')\n",
    "        text = re.sub(r'\\n', ' ', text)\n",
    "        text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
    "        text = re.sub(r'\\s+', ' ', text)\n",
    "        text = re.sub(r'\\p{P}', '', text)\n",
    "        text = text.split(' ')\n",
    "        return text\n",
    "\n",
    "    def get_uniq_words(self):\n",
    "        # keep only the max_vocab_size most frequent words\n",
    "        word_counts = Counter(self.words)\n",
    "        most_common_words = word_counts.most_common(self.max_vocab_size)\n",
    "        return [word for word, _ in most_common_words]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.words_indexes) - self.sequence_length\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Get the sequence\n",
    "        sequence = self.words_indexes[index:index+self.sequence_length]\n",
    "        # Split the sequence into x (two words on each side of the gap) and y (the middle word)\n",
    "        x = sequence[:2] + sequence[-2:]\n",
    "        y = sequence[len(sequence) // 2]\n",
    "        return torch.tensor(x), torch.tensor(y)"
   ]
  },
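  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-demo-getitem-window",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added illustration (not part of the original run): the indexing scheme used by\n",
    "# Dataset.__getitem__ on a toy 5-token window. The input keeps the two tokens on\n",
    "# each side of the gap and the target is the middle token.\n",
    "toy_sequence = ['the', 'cat', 'sat', 'on', 'mat']\n",
    "toy_x = toy_sequence[:2] + toy_sequence[-2:]   # ['the', 'cat', 'on', 'mat']\n",
    "toy_y = toy_sequence[len(toy_sequence) // 2]   # 'sat'\n",
    "print(toy_x, toy_y)"
   ]
  },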
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bf0efaba-86a2-4368-a31d-de7d08a759a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = Dataset(5, train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "7aa7bd72-5978-484e-b541-36f737f22b0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 14, 110, 3, 28]), tensor(208))"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[420]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "2a13298c-e0dd-4181-9093-7cec414b5b79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['at', 'last', 'to', 'tho']"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[train_dataset.index_to_word[x] for x in [ 14, 110, 3, 28]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "192c4d6d-3fc1-4687-9ce4-b1a8cbea7d82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['come']"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[train_dataset.index_to_word[208]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "3f0cd5b3-3937-4ad8-a9f8-766d27ad9d70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 218, 104, 8207, 3121]), tensor(20000))"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[21237]"
   ]
  },
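  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-note-unk-index",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check (sketch, not executed in the original run): the target index 20000\n",
    "# seen above is the out-of-vocabulary marker, since the vocabulary keeps the\n",
    "# 20000 most frequent tokens at indices 0..19999 and maps everything else to '<UNK>'.\n",
    "train_dataset.index_to_word[20000]"
   ]
  },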
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b1302c90-d77e-49e4-8b9d-9a8aeca675b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Model(nn.Module):\n",
    "    def __init__(self, vocab_size, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):\n",
    "        super(Model, self).__init__()\n",
    "        self.lstm_size = lstm_size\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.num_layers = num_layers\n",
    "        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "        self.embedding = nn.Embedding(\n",
    "            num_embeddings=vocab_size,\n",
    "            embedding_dim=self.embedding_dim,\n",
    "        )\n",
    "        self.lstm = nn.LSTM(\n",
    "            input_size=self.embedding_dim,\n",
    "            hidden_size=self.lstm_size,\n",
    "            num_layers=self.num_layers,\n",
    "            dropout=dropout,\n",
    "        )\n",
    "        self.fc1 = nn.Linear(self.lstm_size, 256)\n",
    "        self.fc2 = nn.Linear(256, vocab_size)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x, prev_state=None):\n",
    "        x = x.to(self.device)\n",
    "        embed = self.embedding(x)\n",
    "        # nn.LSTM defaults to batch_first=False, so reshape to (seq_len, batch, features)\n",
    "        embed = embed.transpose(0, 1)\n",
    "\n",
    "        if prev_state is None:\n",
    "            prev_state = self.init_state(x.size(0))\n",
    "\n",
    "        output, state = self.lstm(embed, prev_state)\n",
    "        # classify from the last time step and return a probability distribution over the vocabulary\n",
    "        logits = self.fc1(output[-1])\n",
    "        logits = self.fc2(logits)\n",
    "        probabilities = self.softmax(logits)\n",
    "        return probabilities\n",
    "\n",
    "    def init_state(self, batch_size):\n",
    "        return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device),\n",
    "                torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device))\n"
   ]
  },
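  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-note-shape-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sanity check (sketch, not executed in the original run): one forward pass\n",
    "# of an untrained Model on a single context window from the dataset. The output\n",
    "# should be a probability distribution over the vocabulary, shape (1, vocab_size).\n",
    "check_model = Model(vocab_size=len(train_dataset.uniq_words) + 1).to(device)\n",
    "sample_x, sample_y = train_dataset[0]\n",
    "with torch.no_grad():\n",
    "    sample_probs = check_model(sample_x.unsqueeze(0))\n",
    "print(sample_probs.shape, float(sample_probs.sum()))  # sum should be close to 1.0"
   ]
  },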
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "93a29618-3283-4ad5-881f-48c84839ceeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataset, model, max_epochs, batch_size):\n",
    "    model.train()\n",
    "\n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "    for epoch in range(max_epochs):\n",
    "        for batch, (x, y) in enumerate(dataloader):\n",
    "            optimizer.zero_grad()\n",
    "            x = x.to(device, non_blocking=True)\n",
    "            y = y.to(device, non_blocking=True)\n",
    "\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(torch.log(y_pred), y)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if batch % 500 == 0:\n",
    "                print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "2315e67d-a315-44b5-bddf-5ab4bed1e727",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'epoch': 0, 'update in batch': 0, '/': 16679, 'loss': 9.917818069458008}\n",
      "{'epoch': 0, 'update in batch': 500, '/': 16679, 'loss': 6.078440189361572}\n",
      "{'epoch': 0, 'update in batch': 1000, '/': 16679, 'loss': 5.651369571685791}\n",
      "{'epoch': 0, 'update in batch': 1500, '/': 16679, 'loss': 5.4341654777526855}\n",
      "{'epoch': 0, 'update in batch': 2000, '/': 16679, 'loss': 5.383695602416992}\n",
      "{'epoch': 0, 'update in batch': 2500, '/': 16679, 'loss': 5.225739479064941}\n",
      "{'epoch': 0, 'update in batch': 3000, '/': 16679, 'loss': 5.282474517822266}\n",
      "{'epoch': 0, 'update in batch': 3500, '/': 16679, 'loss': 5.092397689819336}\n",
      "{'epoch': 0, 'update in batch': 4000, '/': 16679, 'loss': 4.940906047821045}\n",
      "{'epoch': 0, 'update in batch': 4500, '/': 16679, 'loss': 4.908115863800049}\n",
      "{'epoch': 0, 'update in batch': 5000, '/': 16679, 'loss': 5.092423439025879}\n",
      "{'epoch': 0, 'update in batch': 5500, '/': 16679, 'loss': 4.979565620422363}\n",
      "{'epoch': 0, 'update in batch': 6000, '/': 16679, 'loss': 4.8268022537231445}\n",
      "{'epoch': 0, 'update in batch': 6500, '/': 16679, 'loss': 4.7172017097473145}\n",
      "{'epoch': 0, 'update in batch': 7000, '/': 16679, 'loss': 4.781315326690674}\n",
      "{'epoch': 0, 'update in batch': 7500, '/': 16679, 'loss': 5.0033040046691895}\n",
      "{'epoch': 0, 'update in batch': 8000, '/': 16679, 'loss': 4.663774013519287}\n",
      "{'epoch': 0, 'update in batch': 8500, '/': 16679, 'loss': 4.710158348083496}\n",
      "{'epoch': 0, 'update in batch': 9000, '/': 16679, 'loss': 4.817586898803711}\n",
      "{'epoch': 0, 'update in batch': 9500, '/': 16679, 'loss': 4.655371189117432}\n",
      "{'epoch': 0, 'update in batch': 10000, '/': 16679, 'loss': 4.679412841796875}\n",
      "{'epoch': 0, 'update in batch': 10500, '/': 16679, 'loss': 4.544621467590332}\n",
      "{'epoch': 0, 'update in batch': 11000, '/': 16679, 'loss': 4.816493511199951}\n",
      "{'epoch': 0, 'update in batch': 11500, '/': 16679, 'loss': 4.627770900726318}\n",
      "{'epoch': 0, 'update in batch': 12000, '/': 16679, 'loss': 4.525866985321045}\n",
      "{'epoch': 0, 'update in batch': 12500, '/': 16679, 'loss': 4.739295959472656}\n",
      "{'epoch': 0, 'update in batch': 13000, '/': 16679, 'loss': 4.6095709800720215}\n",
      "{'epoch': 0, 'update in batch': 13500, '/': 16679, 'loss': 4.7243266105651855}\n",
      "{'epoch': 0, 'update in batch': 14000, '/': 16679, 'loss': 4.557321071624756}\n",
      "{'epoch': 0, 'update in batch': 14500, '/': 16679, 'loss': 4.830319404602051}\n",
      "{'epoch': 0, 'update in batch': 15000, '/': 16679, 'loss': 4.536618709564209}\n",
      "{'epoch': 0, 'update in batch': 15500, '/': 16679, 'loss': 4.605734825134277}\n",
      "{'epoch': 0, 'update in batch': 16000, '/': 16679, 'loss': 4.605676651000977}\n",
      "{'epoch': 0, 'update in batch': 16500, '/': 16679, 'loss': 4.614283084869385}\n"
     ]
    }
   ],
   "source": [
    "model = Model(vocab_size = len(train_dataset.uniq_words) + 1).to(device)\n",
    "train(train_dataset, model, 1, 8192)"
   ]
  },
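  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-note-batch-count",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added cross-check (sketch, not executed in the original run): with batch_size=8192,\n",
    "# the 16679 batches reported above correspond to roughly 16679 * 8192, i.e. about\n",
    "# 137 million training windows; the ceiling division below should reproduce that count.\n",
    "len(train_dataset), -(-len(train_dataset) // 8192)"
   ]
  },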
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "8acf3dc2-f3fe-4a2a-bdf9-82a18acb1bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'model.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5e60d5b3-019d-4d63-b794-59e1356bc45e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(20001).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7e55b0b2-cdda-4c37-8979-0400f9973461",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('model.pth'))"
   ]
  },
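  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-note-eval-mode",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added suggestion (not part of the original run): switch to evaluation mode before\n",
    "# generating predictions, so the LSTM's inter-layer dropout is disabled at inference time.\n",
    "model.eval()"
   ]
  },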
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e842b192-8e10-438c-b8ee-781a4a7a875c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean(text):\n",
    "    text = text.replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ').replace('\\\\\\\\t', ' ')\n",
    "    text = re.sub(r'\\n', ' ', text)\n",
    "    text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
    "    text = re.sub(r'\\s+', ' ', text)\n",
    "    text = re.sub(r'\\p{P}', '', text)\n",
    "    text = text.strip()\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f20f8fdc-194e-415a-8343-6f590abe1166",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_words(words, model, dataset, n=20):\n",
    "    ixs = [dataset.word_to_index.get(word, dataset.word_to_index['<UNK>']) for word in words]\n",
    "    ixs = torch.tensor(ixs).unsqueeze(0).to(model.device)\n",
    "\n",
    "    out = model(ixs)\n",
    "    top = torch.topk(out[0], n)\n",
    "    top_indices = top.indices.tolist()\n",
    "    top_probs = top.values.tolist()\n",
    "    top_words = [dataset.index_to_word[idx] for idx in top_indices]\n",
    "    return list(zip(top_words, top_probs))"
   ]
  },
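  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-note-getwords-usage",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added usage example (sketch, not executed in the original run): top predictions for\n",
    "# the gap in the 2+2 context inspected earlier in the notebook.\n",
    "get_words(['at', 'last', 'to', 'tho'], model, train_dataset, n=5)"
   ]
  },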
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "22ebafa5-d21f-4208-9aad-a4c4d90134c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f_out(left, right, model, dataset):\n",
    "    left = clean(left)\n",
    "    right = clean(right)\n",
    "    words = left.split(' ')[-2:] + right.split(' ')[:2]\n",
    "    words = get_words(words, model, dataset)\n",
    "\n",
    "    # build an output line of 'word:probability' pairs followed by ':' and the leftover probability mass\n",
    "    probs_sum = 0\n",
    "    output = ''\n",
    "    for word, prob in words:\n",
    "        if word == \"<UNK>\":\n",
    "            continue\n",
    "        probs_sum += prob\n",
    "        output += f\"{word}:{prob} \"\n",
    "    output += f\":{1-probs_sum}\"\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1dc64cee-a9a5-44d4-92da-82e1b7f8fdc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_out(input_path, model, dataset, output_path):\n",
    "    lines = []\n",
    "    with open(input_path, encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            columns = line.split('\\t')\n",
    "            # columns 6 and 7 (0-indexed) hold the left and right context of the gap\n",
    "            left = columns[6]\n",
    "            right = columns[7]\n",
    "            lines.append((left, right))\n",
    "\n",
    "    with open(output_path, 'w', encoding='utf-8') as output_file:\n",
    "        for left, right in lines:\n",
    "            result = f_out(left, right, model, dataset)\n",
    "            output_file.write(result + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "348a77c1-8ff1-40bb-a243-3b702c119c2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/in.tsv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "9377c725-3309-4590-89d2-444057ae2b80",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_out(dev_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/out.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "50f47d4a-762f-48b2-9c19-f385d9822886",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_path = \"C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/in.tsv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "18aa1059-88ed-4c32-af88-80a4de4be6c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_out(test_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/out.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "426de26f-b72e-41dc-a63c-a7956d3b1655",
   "metadata": {},
   "outputs": [
    {
     "data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8MAAAAvCAYAAADHGxhSAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAACNdSURBVHhe7Z1vbBRHlsDfQMjAiVgmEcntWhgjiM1KURISCd3E/mCDFFb27if2fJLNLbqLdGcIfMhKlhb2+MStOclS8gECvpP2TuwC0lqbT3tYRySwP9jxKlKWTRRpbR8WYOTNZq3EA0gHzh/m6tWf7uqe7qrq6ZnxGL8fGjzTf6qrXr2qrtfvVXVm48aNBagBMi0ZyJ4GWAMF9i8D3xwF+Hq6JrJG1BibfnAfnltsh6mJj+SW6pL2+sud/7TUYv4zmVfhO71jsBlOw8zFn8GDQnzfYct/5nu/gRdfnbGm40qtyWul6x+RjkrXf9L2s6H1D9C8a7v4cesIfPzf58X3GiJJ/5IGrJvGpisw9+7fwmKFrhGm3Pqw2vsXm/4nkY/Su+xHdXDzj6Xpg6191UL7q1b7qlVWU/nLPb5Kw/379+Hu3bv8+xr+fxV44l8zsP6fM/JXEGUIf8sM4P97HeAh+7uW/X4ik+H71r+fgb96H7zPBpaWzhqWrr4fP/q1Miyddf8VPEbtV/uy3w/+3sA+a2Kuj58NbNs6tk8/R9+HeY/aF97vSoblT09DL5+p/Lb8q/16fjDP2dD+ovOlfPTj1b5wPdvqJwmZZwfgOTZY+PyD38st5QFvUDtbX5W/4kl7/Urlv1qE85/JHIRtR4Kyw5vrS0f+AN99trQ6riS1Kn9X/UtLUf2x3zuP/AY2Bdo/1mmw/sRx9+GlAwOwoejY5a1/lN1LmDf5qYYcq0259KMW9f/BxMvw8ZmnYOb6rNxCrFRqtX+tFZZDPrb2Re2PIKpoDMeBhtQTPwUovO17ggvs71fvMWP4n/hPzlf7MtxQfsD+ftMQNMaQwodiv/o8/HeRFk//P5lx/Tt/n74/TIZd8wn2d+kfAR4VCjwvD18v8HO+Ytd49J7MB9uG+VXpPzHvp73E8r6ObdPz9+3bfv6+/pDtvyKMfRfQmFz/E3beUT+Nr9h2ZYwjKl/qo8pny78N/Xz1wXSAlRflg6w9yf5j9cfTZXmE/cG8IXH1k5T1z3cCXD9VtafmYdJef7nzn5ai/G9ugWye3US37edGEj7h3LSNtYG82F0tCoWP4E8XnoKPLxw3Pm1c6fJPS6nlF+cdgQXohE2b5UakBup/8b9ZvbPBHH4+efcILO0ag23fc+tbVxurXf9LxbV/SYvQ5R89PvcHIgDJJ5pqta9aZbWXvxbIPPXqUwXllVXGEXoDw9vQs/vkbvEdw5i/3gfwjaw0ZRCua/ArEY2/pf8pCENuv79dgcbRg39hKX2f7W9jRhT7rqO2P/wV8Lx8o18vlD+8RnZLcRoI5u1JZniisYb5CaPyvubXzMi7WXwtHZTBE3eChhzPJxqqEfLgaaLRK7+r60fJNw5b/hEs/5Psr83AjMy/zEtAvvKajyLyZ9qHqLLjwwd1HVP9KDKZnXDowhB0wzD0HTgHUxHyRy9U05tvwdLwLvjTX8R+HnKxfQYW6o/C5nqAe1ePAOw9A3UwCwvsuM8WXpFhRmo7IvZhGoEQIZ2IcKGo6yt4PvYyIUqWrheHQUXmPyIMKhxG4oXQsPIp7l0Nhk2h56ixSf7IB0NtsIxb4Q24Db/wyqrnzyV9JDL/zw5Ay+vs+Dy7yX/I5A0/h6bnp2Fp21sA7/vHBfIHfhigS/lt+Q/UoSHMy1R/wfwxQjI0ydcmPzxXhcX5xwodXHz+ulX/itOPkl+0fuvE1l93M3yuhWWGj9Pr6M9PX+d14emOY/1XE13eLtjqD4mrf1v/o8pfqv649E+29qGIrH+H/Jv0TxEoH5Kg/Sh4WTe9E9Hvxl8/7pywDgTzl6z9uPQvLuWLIqrurfLV9S/yfF9/RPvuhHt6nYe26bLy0zPob0T+kLj+1ZZHJE39mHA536X92Oo3mH9GxDGm+0/c+MHPfzssvapkWFx+l/qJaysK0/648rvI11b/pvZVjvRdiNW/MvTvNv0ylV8Rpx8Kl/qPIq7OMb3iPkHuDKVvKx8SzB8jon3E4VK/sfVn0Z9gmPQMM8rmmfHSzn9zMux7hm37lu1D0JhZO+579bhnEw08Zvgg3JuqeUbxowy3R8wgwt+6VxI/yjDKbGXHsLTRiFKhthjCu5YZpoUGtp0fFSIiz3EUmEC+Zfld+5OCFwodh/JQRxnCcWD+gaWvn4PXLLD88X1paRYy+BYN4hoA63oNK6/JiF/D6q1wW/4oI+tfewvqbr0TuAlwmo4yZa+DuVsAdXvxZtMOC/ntkH1G7meI7XUyHAhg8+4f8+0qRAjPxQasPExRHVLc9UVHtYM1MJE+fvSOQBGbfwvrX/sFbM6zTlLljX30Tp53BODvn8sfheYuUT5FdtcYNG8bgZl36+CT4dMAu37hhbHa0leY8r84ewPqnn8F1j/fDEv/+ynb4ssfO8v6WV82c7f2QWPvzwPhtjZM+XcN84rLf1h+n1wNNjabfF3l53Xq/NiXeT5s+qfOqbvp75u5vqNIfnWsowd2cxD7mX6/XizfUvUPNu+HuvorkJ8CePjlDchKL7COqf6rCQ70k4Yhpm5flv4njf649k+m9qGIrX9D/l30L237MWG7Puoj1LcU6SOy9KXQAZf+J+7+gNj6lzTl8zxC8lysuyW5T2FKv+h8lD0bUHv6u/Ae3GN1iW1Tgd7JbIQeRPVPiGv/Hadf9V0qTYzcQN0Suqzy6FY/9v7NhO18U/ux1a9N/xVx8nEZP9TtxQG9ln9NP8txfzXhot+m9mOrf5f7d5r0bVjll7J/R9KMX2z6kab+bf2n6/gjSfuJax9xVKv/WMONxd8xgf+NMG7RKF3Lvj/6NftIAw8NWt0r+S0zXovYzYzJUOI2vGsxwxdDbTPvMcU4mgE3FQ6S2V0IzEnVDd9vmOH94G12LWYQc2M7NOcYwX3ozYzzvsaxZov8koC1f8/yqz1sMLKNHSu/mliz3y9/OIS8XGB9rdvP5PkruSECrMc1rGxfh/TdVD9IoTAFZ3vbob33bIxXGMMvt8O92V/KLRr50/BnHKgvss6E3Ww+WxCbs0/7A4B7V/2neKYOIA7T9eu372ON842im5yOMf8uNP0wMK9TgU+b65tmYeFDP938h2xAFT5efxK3MF004IpLXxGb/2eaIYt/p34LS9uOwV9vmoFFKX8Fdvb6jSk/m6wz5NjybyEu/1Hy03GWr0V+AC/4A82IBy2xSENUN+4efvAO3KsPhivrT0mj9qfRPzF4/q14EszqOZC2Q/1XA7zh4nzhF7uPlhaGmKZ9Gfqf8umPBUv7MNa/qf+06F/Z2k8cNv3/YsYrKw6KcF41lnWD5kVw6X9KvT+kLp+FJOnjsU17AeYu+8eisfzZR1e8B1jxehDfP7nIL75/FXWhtmN+Fm/OQnbTC/w34lY/5v7NhvX8mPZjk3/U/ihM7c9l/BDIf0g/y3J/jcFV/+Laj0v9u
1DJ9K3yK0f/bumfTdj0I1X92/pPx/FHkvaThGr2H3zOcOE/mEHaIH80M4OGfdc9keFFlDYww1GHG8vMkH3ySryxaaLAjMIMXnOM/UDjjxlTj8QuZ8JzUsNGbYH95h7pfRl4xAz3sMGIYd0FZujZvMdhHt2RXywoQxw/69jvB/9Q8B42FMmXfU/yYEF539Xnaya/rOa5LxfcK2ww4jGCYB2T7Tf/xvIky6aw1Y+Vncf4SnvYKS0LMddXjVV5IWJJkX9s7Pxp3Jv3xCJBPzgo9zBwziZsh83dcp80CLKwAzYEOqtp0VExCoXzcFN78m9MX2HKP0/7POTz+yC7+B48lJsV6K3jiy+p/GnhPs4Y8u9EXP65/G7AgzgDzkG+LvLL7joDPPQw6U0Bjc38TEimn8JSKPLBSon6h/qtD+JQ9vlbQU+Trf5thPXjpdCCXi7o84Y/3zQWrcMxlKV9xVEm/bFiax+l9j82/StD+zFiuz4OvOqbYT28Aps23YClTfvZd2QWlr7gX8rT/8SRtnw2EqR
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import Image\n",
    "Image(filename='C:/Users/Mauri/Desktop/Zrzut ekranu 2024-05-27 134020.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}