final solution

This commit is contained in:
parent eacc27a0d2
commit 96313d4915

1 .gitignore (vendored)
@@ -1,2 +1,3 @@
.ipynb_checkpoints/
processed_train.txt
model/
10519 dev-0/out.tsv (new file)
File diff suppressed because it is too large.
789 run.ipynb
@@ -3,22 +3,24 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "56fb2fdb",
"id": "03de852a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import regex as re\n",
"import torch\n",
"import csv\n",
"import torch\n",
"from torch import nn\n",
"from collections import Counter"
"from gensim.models import Word2Vec\n",
"from nltk.tokenize import word_tokenize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "74f269c9",
"id": "73497953",
"metadata": {},
"outputs": [],
"source": [
@@ -29,81 +31,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "c939f600",
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, sequence_length, file_path):\n",
" self.file_path = file_path\n",
" self.sequence_length = sequence_length\n",
" self.words = self.load()\n",
" self.uniq_words = self.get_uniq_words()\n",
"\n",
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
" \n",
" self.index_to_word[-1] = '<UNK>'\n",
" self.word_to_index['<UNK>'] = -1\n",
"\n",
" self.words_indexes = [self.word_to_index[w] if w in self.uniq_words else self.word_to_index['<UNK>'] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open(self.file_path, 'r') as f_in:\n",
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
" text = ' '.join(text).lower()\n",
" text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n",
" text = text.split(' ')\n",
" return text\n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words).most_common(250000)\n",
" word_counts = dict(word_counts)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
" return len(self.words_indexes) - self.sequence_length\n",
"\n",
" def __getitem__(self, index):\n",
" return (\n",
" torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n",
" torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n",
" )"
]
},
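Editor's note on the Dataset class removed above: its `words_indexes` comprehension tests `w in self.uniq_words` against a list of up to 250,000 words, a linear scan per token, which is what the KeyboardInterrupt traceback further down this diff was interrupting (and mapping '<UNK>' to -1 would crash nn.Embedding, which expects non-negative indices). A hedged sketch of the conventional fix, for reference only and not part of this commit; the lines belong inside __init__:

    # Sketch only: a set (or the word_to_index dict itself) gives O(1) lookups
    known = set(self.uniq_words)
    unk_index = self.word_to_index['<UNK>']
    self.words_indexes = [self.word_to_index[w] if w in known else unk_index for w in self.words]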
{
"cell_type": "code",
"execution_count": 4,
"id": "9f178921",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.lstm_size = 128\n",
" self.embedding_dim = 256\n",
" self.num_layers = 3\n",
"\n",
" self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embedding_dim)\n",
" self.lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2)\n",
" self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embedding(x)\n",
" output, state = self.lstm(embed, prev_state)\n",
" logits = self.fc(output)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device)\n",
" return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "39a27465",
"id": "4227ef55",
"metadata": {},
"outputs": [],
"source": [
@@ -117,8 +45,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2f25401d",
"execution_count": 4,
"id": "758cf94a",
"metadata": {},
"outputs": [],
"source": [
@@ -126,142 +54,735 @@
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_data = train_data[[6, 7]]\n",
"train_data = pd.concat([train_data, train_labels], axis=1)\n",
"train_data = pd.concat([train_data, train_labels], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "384922c7",
"metadata": {},
"outputs": [],
"source": [
"class TrainCorpus:\n",
" def __init__(self, data):\n",
" self.data = data\n",
" \n",
"train_data['text'] = train_data[6] + train_data[0] + train_data[7]\n",
"train_data = train_data[['text']]\n",
"\n",
"with open('processed_train.txt', 'w', encoding='utf-8') as file:\n",
" for _, row in train_data.iterrows():\n",
" text = clean_text(str(row['text']))\n",
" file.write(text + '\\n')"
" def __iter__(self):\n",
" for _, row in self.data.iterrows():\n",
" text = str(row[6]) + str(row[0]) + str(row[7])\n",
" text = clean_text(text)\n",
" yield word_tokenize(text)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3da5f19b",
"metadata": {},
"outputs": [],
"source": [
"train_sentences = TrainCorpus(train_data.head(100000))\n",
"w2v_model = Word2Vec(vector_size=100, min_count=10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bb55ce42",
"id": "183d43be",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mprocessed_train.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36mDataset.__init__\u001b[1;34m(self, sequence_length, file_path)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mword_to_index\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
"name": "stdout",
"output_type": "stream",
"text": [
"97122\n"
]
}
],
"source": [
"data = Dataset(5, 'processed_train.txt')"
"w2v_model.build_vocab(corpus_iterable=train_sentences)\n",
"\n",
"key_to_index = w2v_model.wv.key_to_index\n",
"index_to_key = w2v_model.wv.index_to_key\n",
"\n",
"index_to_key.append('<unk>')\n",
"key_to_index['<unk>'] = len(index_to_key) - 1\n",
"\n",
"vocab_size = len(index_to_key)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a6872ca",
"execution_count": 8,
"id": "e63dd9fe",
"metadata": {},
"outputs": [],
"source": [
"data_vocab_size = len(data.uniq_words)\n",
"print(data_vocab_size)"
"class TrainDataset(torch.utils.data.IterableDataset):\n",
" def __init__(self, data, index_to_key, key_to_index):\n",
" self.data = data\n",
" self.index_to_key = index_to_key\n",
" self.key_to_index = key_to_index\n",
" self.vocab_size = len(key_to_index)\n",
"\n",
" def __iter__(self):\n",
" for _, row in self.data.iterrows():\n",
" text = str(row[6]) + str(row[0]) + str(row[7])\n",
" text = clean_text(text)\n",
" tokens = word_tokenize(text)\n",
" for i in range(5, len(tokens), 1):\n",
" input_context = tokens[i-5:i]\n",
" target_context = tokens[i-4:i+1]\n",
" #gap_word = tokens[i]\n",
" \n",
" input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in input_context]\n",
" target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in target_context]\n",
" #word_index = self.key_to_index[gap_word] if gap_word in self.key_to_index else self.key_to_index['<unk>']\n",
" #word_embed = np.concatenate([np.zeros(word_index), np.ones(1), np.zeros(vocab_size - word_index - 1)])\n",
" \n",
" yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc097469",
"execution_count": 9,
"id": "7c60ddc1",
"metadata": {},
"outputs": [],
"source": [
"model = Model(vocab_size = data_vocab_size).to(device)"
"class Model(nn.Module):\n",
" def __init__(self, embed_size, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.embed_size = embed_size\n",
" self.vocab_size = vocab_size\n",
" self.gru_size = 128\n",
" self.num_layers = 2\n",
" \n",
" self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n",
" self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size, num_layers=self.num_layers, dropout=0.2)\n",
" self.fc = nn.Linear(self.gru_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embed(x)\n",
" output, state = self.gru(embed, prev_state)\n",
" logits = self.fc(output)\n",
" probs = torch.softmax(logits, dim=1)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)\n",
" return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9fcdfe7",
"execution_count": 10,
"id": "1c7b8fab",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"\n",
"def train(dataset, model, max_epochs, batch_size):\n",
" model.train()\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
" optimizer = Adam(model.parameters(), lr=0.001)\n",
"\n",
" for epoch in range(max_epochs):\n",
" for batch, (x, y) in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
" \n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" \n",
" y_pred, (state_h, state_c) = model(x)\n",
" y_pred, state_h = model(x)\n",
" loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() "
" if batch % 1000 == 0:\n",
" print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38d52d9a",
"execution_count": 11,
"id": "3531d21d",
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
"train_dataset = TrainDataset(train_data.head(100000), index_to_key, key_to_index)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f72f1f6d",
"metadata": {},
"outputs": [],
"source": [
"model = Model(100, vocab_size).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d608d9fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.47425365447998\n",
"epoch: 0, update in batch 1000/???, loss: 7.042205810546875\n",
"epoch: 0, update in batch 2000/???, loss: 7.235440731048584\n",
"epoch: 0, update in batch 3000/???, loss: 7.251269340515137\n",
"epoch: 0, update in batch 4000/???, loss: 6.944191932678223\n",
"epoch: 0, update in batch 5000/???, loss: 6.263372421264648\n",
"epoch: 0, update in batch 6000/???, loss: 6.181947231292725\n",
"epoch: 0, update in batch 7000/???, loss: 6.508013725280762\n",
"epoch: 0, update in batch 8000/???, loss: 6.658236026763916\n",
"epoch: 0, update in batch 9000/???, loss: 6.536279201507568\n",
"epoch: 0, update in batch 10000/???, loss: 6.802626609802246\n",
"epoch: 0, update in batch 11000/???, loss: 6.961029052734375\n",
"epoch: 0, update in batch 12000/???, loss: 7.713824272155762\n",
"epoch: 0, update in batch 13000/???, loss: 8.100411415100098\n",
"epoch: 0, update in batch 14000/???, loss: 6.457145690917969\n",
"epoch: 0, update in batch 15000/???, loss: 6.850286960601807\n",
"epoch: 0, update in batch 16000/???, loss: 6.794063568115234\n",
"epoch: 0, update in batch 17000/???, loss: 6.311314582824707\n",
"epoch: 0, update in batch 18000/???, loss: 6.611917018890381\n",
"epoch: 0, update in batch 19000/???, loss: 5.679810523986816\n",
"epoch: 0, update in batch 20000/???, loss: 7.110655307769775\n",
"epoch: 0, update in batch 21000/???, loss: 7.170722961425781\n",
"epoch: 0, update in batch 22000/???, loss: 6.584908485412598\n",
"epoch: 0, update in batch 23000/???, loss: 7.224095344543457\n",
"epoch: 0, update in batch 24000/???, loss: 5.827445983886719\n",
"epoch: 0, update in batch 25000/???, loss: 6.444586753845215\n",
"epoch: 0, update in batch 26000/???, loss: 6.149054527282715\n",
"epoch: 0, update in batch 27000/???, loss: 6.259871482849121\n",
"epoch: 0, update in batch 28000/???, loss: 5.789839744567871\n",
"epoch: 0, update in batch 29000/???, loss: 7.025563716888428\n",
"epoch: 0, update in batch 30000/???, loss: 7.265492916107178\n",
"epoch: 0, update in batch 31000/???, loss: 4.921586036682129\n",
"epoch: 0, update in batch 32000/???, loss: 6.467754364013672\n",
"epoch: 0, update in batch 33000/???, loss: 7.393715858459473\n",
"epoch: 0, update in batch 34000/???, loss: 6.9696760177612305\n",
"epoch: 0, update in batch 35000/???, loss: 7.276318550109863\n",
"epoch: 0, update in batch 36000/???, loss: 7.011231899261475\n",
"epoch: 0, update in batch 37000/???, loss: 7.029260158538818\n",
"epoch: 0, update in batch 38000/???, loss: 6.723126411437988\n",
"epoch: 0, update in batch 39000/???, loss: 6.828773498535156\n",
"epoch: 0, update in batch 40000/???, loss: 6.069770336151123\n",
"epoch: 0, update in batch 41000/???, loss: 6.651298522949219\n",
"epoch: 0, update in batch 42000/???, loss: 7.455380916595459\n",
"epoch: 0, update in batch 43000/???, loss: 5.594773769378662\n",
"epoch: 0, update in batch 44000/???, loss: 6.102865219116211\n",
"epoch: 0, update in batch 45000/???, loss: 6.04202127456665\n",
"epoch: 0, update in batch 46000/???, loss: 6.472177982330322\n",
"epoch: 0, update in batch 47000/???, loss: 5.870923042297363\n",
"epoch: 0, update in batch 48000/???, loss: 6.286317348480225\n",
"epoch: 0, update in batch 49000/???, loss: 7.157052516937256\n",
"epoch: 0, update in batch 50000/???, loss: 5.888463020324707\n",
"epoch: 0, update in batch 51000/???, loss: 5.609915733337402\n",
"epoch: 0, update in batch 52000/???, loss: 6.565190315246582\n",
"epoch: 0, update in batch 53000/???, loss: 6.4924468994140625\n",
"epoch: 0, update in batch 54000/???, loss: 6.856420040130615\n",
"epoch: 0, update in batch 55000/???, loss: 7.389428615570068\n",
"epoch: 0, update in batch 56000/???, loss: 5.927685260772705\n",
"epoch: 0, update in batch 57000/???, loss: 7.4227423667907715\n",
"epoch: 0, update in batch 58000/???, loss: 6.46466064453125\n",
"epoch: 0, update in batch 59000/???, loss: 6.586294651031494\n",
"epoch: 0, update in batch 60000/???, loss: 5.797083854675293\n",
"epoch: 0, update in batch 61000/???, loss: 4.825878143310547\n",
"epoch: 0, update in batch 62000/???, loss: 6.911933898925781\n",
"epoch: 0, update in batch 63000/???, loss: 7.684759616851807\n",
"epoch: 0, update in batch 64000/???, loss: 5.716580390930176\n",
"epoch: 0, update in batch 65000/???, loss: 6.1738667488098145\n",
"epoch: 0, update in batch 66000/???, loss: 6.219714164733887\n",
"epoch: 0, update in batch 67000/???, loss: 5.4024128913879395\n",
"epoch: 0, update in batch 68000/???, loss: 6.912312984466553\n",
"epoch: 0, update in batch 69000/???, loss: 6.703289031982422\n",
"epoch: 0, update in batch 70000/???, loss: 7.375630855560303\n",
"epoch: 0, update in batch 71000/???, loss: 5.757082462310791\n",
"epoch: 0, update in batch 72000/???, loss: 5.992405414581299\n",
"epoch: 0, update in batch 73000/???, loss: 6.706838130950928\n",
"epoch: 0, update in batch 74000/???, loss: 7.376870155334473\n",
"epoch: 0, update in batch 75000/???, loss: 6.676860809326172\n",
"epoch: 0, update in batch 76000/???, loss: 5.904101848602295\n",
"epoch: 0, update in batch 77000/???, loss: 6.776932716369629\n",
"epoch: 0, update in batch 78000/???, loss: 5.682181358337402\n",
"epoch: 0, update in batch 79000/???, loss: 6.211178302764893\n",
"epoch: 0, update in batch 80000/???, loss: 6.366950035095215\n",
"epoch: 0, update in batch 81000/???, loss: 5.25206184387207\n",
"epoch: 0, update in batch 82000/???, loss: 6.30997371673584\n",
"epoch: 0, update in batch 83000/???, loss: 6.351908206939697\n",
"epoch: 0, update in batch 84000/???, loss: 7.659114837646484\n",
"epoch: 0, update in batch 85000/???, loss: 6.5041704177856445\n",
"epoch: 0, update in batch 86000/???, loss: 6.770291328430176\n",
"epoch: 0, update in batch 87000/???, loss: 6.530011177062988\n",
"epoch: 0, update in batch 88000/???, loss: 6.317249298095703\n",
"epoch: 0, update in batch 89000/???, loss: 6.191559314727783\n",
"epoch: 0, update in batch 90000/???, loss: 5.79150390625\n",
"epoch: 0, update in batch 91000/???, loss: 6.356796741485596\n",
"epoch: 0, update in batch 92000/???, loss: 7.3577141761779785\n",
"epoch: 0, update in batch 93000/???, loss: 6.529308319091797\n",
"epoch: 0, update in batch 94000/???, loss: 7.740485191345215\n",
"epoch: 0, update in batch 95000/???, loss: 6.348109245300293\n",
"epoch: 0, update in batch 96000/???, loss: 6.032902717590332\n",
"epoch: 0, update in batch 97000/???, loss: 4.505112648010254\n",
"epoch: 0, update in batch 98000/???, loss: 6.946290493011475\n",
"epoch: 0, update in batch 99000/???, loss: 6.237973213195801\n",
"epoch: 0, update in batch 100000/???, loss: 6.963421821594238\n",
"epoch: 0, update in batch 101000/???, loss: 5.309017181396484\n",
"epoch: 0, update in batch 102000/???, loss: 6.242384910583496\n",
"epoch: 0, update in batch 103000/???, loss: 6.8203558921813965\n",
"epoch: 0, update in batch 104000/???, loss: 6.242025852203369\n",
"epoch: 0, update in batch 105000/???, loss: 6.765100002288818\n",
"epoch: 0, update in batch 106000/???, loss: 6.8838043212890625\n",
"epoch: 0, update in batch 107000/???, loss: 6.856662750244141\n",
"epoch: 0, update in batch 108000/???, loss: 6.379549503326416\n",
"epoch: 0, update in batch 109000/???, loss: 6.797707557678223\n",
"epoch: 0, update in batch 110000/???, loss: 7.2699689865112305\n",
"epoch: 0, update in batch 111000/???, loss: 7.040057182312012\n",
"epoch: 0, update in batch 112000/???, loss: 6.7861223220825195\n",
"epoch: 0, update in batch 113000/???, loss: 6.064489364624023\n",
"epoch: 0, update in batch 114000/???, loss: 6.095967769622803\n",
"epoch: 0, update in batch 115000/???, loss: 5.757347106933594\n",
"epoch: 0, update in batch 116000/???, loss: 6.529908657073975\n",
"epoch: 0, update in batch 117000/???, loss: 6.030801296234131\n",
"epoch: 0, update in batch 118000/???, loss: 6.179767608642578\n",
"epoch: 0, update in batch 119000/???, loss: 5.436234474182129\n",
"epoch: 0, update in batch 120000/???, loss: 7.342876434326172\n",
"epoch: 0, update in batch 121000/???, loss: 6.862719535827637\n",
"epoch: 0, update in batch 122000/???, loss: 6.491606712341309\n",
"epoch: 0, update in batch 123000/???, loss: 7.195406436920166\n",
"epoch: 0, update in batch 124000/???, loss: 5.481313228607178\n",
"epoch: 0, update in batch 125000/???, loss: 7.963885307312012\n",
"epoch: 0, update in batch 126000/???, loss: 6.479039669036865\n",
"epoch: 0, update in batch 127000/???, loss: 7.037934303283691\n",
"epoch: 0, update in batch 128000/???, loss: 5.903053283691406\n",
"epoch: 0, update in batch 129000/???, loss: 6.815878391265869\n",
"epoch: 0, update in batch 130000/???, loss: 6.497969150543213\n",
"epoch: 0, update in batch 131000/???, loss: 5.623625755310059\n",
"epoch: 0, update in batch 132000/???, loss: 7.118441104888916\n",
"epoch: 0, update in batch 133000/???, loss: 5.964345455169678\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 134000/???, loss: 6.112139701843262\n",
"epoch: 0, update in batch 135000/???, loss: 6.5865373611450195\n",
"epoch: 0, update in batch 136000/???, loss: 7.498536109924316\n",
"epoch: 0, update in batch 137000/???, loss: 7.124758243560791\n",
"epoch: 0, update in batch 138000/???, loss: 6.871796607971191\n",
"epoch: 0, update in batch 139000/???, loss: 5.8565263748168945\n",
"epoch: 0, update in batch 140000/???, loss: 5.723143577575684\n",
"epoch: 0, update in batch 141000/???, loss: 5.601426124572754\n",
"epoch: 0, update in batch 142000/???, loss: 5.495566368103027\n",
"epoch: 0, update in batch 143000/???, loss: 6.936192989349365\n",
"epoch: 0, update in batch 144000/???, loss: 6.1843671798706055\n",
"epoch: 0, update in batch 145000/???, loss: 6.886034965515137\n",
"epoch: 0, update in batch 146000/???, loss: 6.655320167541504\n",
"epoch: 0, update in batch 147000/???, loss: 6.46828556060791\n",
"epoch: 0, update in batch 148000/???, loss: 5.607057571411133\n",
"epoch: 0, update in batch 149000/???, loss: 7.182212829589844\n",
"epoch: 0, update in batch 150000/???, loss: 7.241323947906494\n",
"epoch: 0, update in batch 151000/???, loss: 7.308540344238281\n",
"epoch: 0, update in batch 152000/???, loss: 5.267911434173584\n",
"epoch: 0, update in batch 153000/???, loss: 5.895949363708496\n",
"epoch: 0, update in batch 154000/???, loss: 6.629178524017334\n",
"epoch: 0, update in batch 155000/???, loss: 4.9156012535095215\n",
"epoch: 0, update in batch 156000/???, loss: 7.181819915771484\n",
"epoch: 0, update in batch 157000/???, loss: 7.438391208648682\n",
"epoch: 0, update in batch 158000/???, loss: 6.406006813049316\n",
"epoch: 0, update in batch 159000/???, loss: 6.486207008361816\n",
"epoch: 0, update in batch 160000/???, loss: 7.041951656341553\n",
"epoch: 0, update in batch 161000/???, loss: 5.310082912445068\n",
"epoch: 0, update in batch 162000/???, loss: 6.9074387550354\n",
"epoch: 0, update in batch 163000/???, loss: 6.644919395446777\n",
"epoch: 0, update in batch 164000/???, loss: 6.011733055114746\n",
"epoch: 0, update in batch 165000/???, loss: 6.494180202484131\n",
"epoch: 0, update in batch 166000/???, loss: 5.390150547027588\n",
"epoch: 0, update in batch 167000/???, loss: 6.627297401428223\n",
"epoch: 0, update in batch 168000/???, loss: 6.9020209312438965\n",
"epoch: 0, update in batch 169000/???, loss: 7.317750453948975\n",
"epoch: 0, update in batch 170000/???, loss: 5.69993782043457\n",
"epoch: 0, update in batch 171000/???, loss: 6.658817291259766\n",
"epoch: 0, update in batch 172000/???, loss: 6.422945976257324\n",
"epoch: 0, update in batch 173000/???, loss: 5.822269439697266\n",
"epoch: 0, update in batch 174000/???, loss: 6.513391017913818\n",
"epoch: 0, update in batch 175000/???, loss: 5.886590957641602\n",
"epoch: 0, update in batch 176000/???, loss: 7.119387149810791\n",
"epoch: 0, update in batch 177000/???, loss: 6.933981418609619\n",
"epoch: 0, update in batch 178000/???, loss: 6.678143501281738\n",
"epoch: 0, update in batch 179000/???, loss: 6.890423774719238\n",
"epoch: 0, update in batch 180000/???, loss: 6.932961940765381\n",
"epoch: 0, update in batch 181000/???, loss: 6.650975704193115\n",
"epoch: 0, update in batch 182000/???, loss: 6.732748985290527\n",
"epoch: 0, update in batch 183000/???, loss: 6.064764976501465\n",
"epoch: 0, update in batch 184000/???, loss: 5.282295227050781\n",
"epoch: 0, update in batch 185000/???, loss: 6.569302558898926\n",
"epoch: 0, update in batch 186000/???, loss: 5.800485610961914\n",
"epoch: 0, update in batch 187000/???, loss: 6.175991058349609\n",
"epoch: 0, update in batch 188000/???, loss: 5.405575752258301\n",
"epoch: 0, update in batch 189000/???, loss: 6.191354751586914\n",
"epoch: 0, update in batch 190000/???, loss: 6.156663417816162\n",
"epoch: 0, update in batch 191000/???, loss: 6.937534332275391\n",
"epoch: 0, update in batch 192000/???, loss: 6.562686920166016\n",
"epoch: 0, update in batch 193000/???, loss: 6.639985084533691\n",
"epoch: 0, update in batch 194000/???, loss: 7.285438537597656\n",
"epoch: 0, update in batch 195000/???, loss: 6.528258323669434\n",
"epoch: 0, update in batch 196000/???, loss: 8.326434135437012\n",
"epoch: 0, update in batch 197000/???, loss: 6.781360626220703\n",
"epoch: 0, update in batch 198000/???, loss: 7.223299980163574\n",
"epoch: 0, update in batch 199000/???, loss: 6.411007881164551\n",
"epoch: 0, update in batch 200000/???, loss: 5.885635852813721\n",
"epoch: 0, update in batch 201000/???, loss: 5.706809043884277\n",
"epoch: 0, update in batch 202000/???, loss: 6.230217933654785\n",
"epoch: 0, update in batch 203000/???, loss: 7.056562900543213\n",
"epoch: 0, update in batch 204000/???, loss: 7.2273077964782715\n",
"epoch: 0, update in batch 205000/???, loss: 6.342462539672852\n",
"epoch: 0, update in batch 206000/???, loss: 6.556817054748535\n",
"epoch: 0, update in batch 207000/???, loss: 5.882349967956543\n",
"epoch: 0, update in batch 208000/???, loss: 6.755805015563965\n",
"epoch: 0, update in batch 209000/???, loss: 6.5045623779296875\n",
"epoch: 0, update in batch 210000/???, loss: 6.525590419769287\n",
"epoch: 0, update in batch 211000/???, loss: 6.49679708480835\n",
"epoch: 0, update in batch 212000/???, loss: 6.562323093414307\n",
"epoch: 0, update in batch 213000/???, loss: 5.227139472961426\n",
"epoch: 0, update in batch 214000/???, loss: 7.044825077056885\n",
"epoch: 0, update in batch 215000/???, loss: 6.002442359924316\n",
"epoch: 0, update in batch 216000/???, loss: 6.084803581237793\n",
"epoch: 0, update in batch 217000/???, loss: 7.425839900970459\n",
"epoch: 0, update in batch 218000/???, loss: 6.818853855133057\n",
"epoch: 0, update in batch 219000/???, loss: 7.0153374671936035\n",
"epoch: 0, update in batch 220000/???, loss: 6.219962120056152\n",
"epoch: 0, update in batch 221000/???, loss: 5.9975385665893555\n",
"epoch: 0, update in batch 222000/???, loss: 6.480047702789307\n",
"epoch: 0, update in batch 223000/???, loss: 6.405727386474609\n",
"epoch: 0, update in batch 224000/???, loss: 4.7763471603393555\n",
"epoch: 0, update in batch 225000/???, loss: 6.615710258483887\n",
"epoch: 0, update in batch 226000/???, loss: 6.385044574737549\n",
"epoch: 0, update in batch 227000/???, loss: 7.260453701019287\n",
"epoch: 0, update in batch 228000/???, loss: 6.9794135093688965\n",
"epoch: 0, update in batch 229000/???, loss: 6.235735893249512\n",
"epoch: 0, update in batch 230000/???, loss: 6.478426456451416\n",
"epoch: 0, update in batch 231000/???, loss: 6.181302547454834\n",
"epoch: 0, update in batch 232000/???, loss: 5.826043128967285\n",
"epoch: 0, update in batch 233000/???, loss: 5.9517011642456055\n",
"epoch: 0, update in batch 234000/???, loss: 8.0064697265625\n",
"epoch: 0, update in batch 235000/???, loss: 6.7822675704956055\n",
"epoch: 0, update in batch 236000/???, loss: 6.293349742889404\n",
"epoch: 0, update in batch 237000/???, loss: 6.442999362945557\n",
"epoch: 0, update in batch 238000/???, loss: 6.282561302185059\n",
"epoch: 0, update in batch 239000/???, loss: 7.166723728179932\n",
"epoch: 0, update in batch 240000/???, loss: 7.189905643463135\n",
"epoch: 0, update in batch 241000/???, loss: 8.462562561035156\n",
"epoch: 0, update in batch 242000/???, loss: 7.446291923522949\n",
"epoch: 0, update in batch 243000/???, loss: 6.382981777191162\n",
"epoch: 0, update in batch 244000/???, loss: 7.635994911193848\n",
"epoch: 0, update in batch 245000/???, loss: 6.635537147521973\n",
"epoch: 0, update in batch 246000/???, loss: 6.068560600280762\n",
"epoch: 0, update in batch 247000/???, loss: 6.193384170532227\n",
"epoch: 0, update in batch 248000/???, loss: 5.702363967895508\n",
"epoch: 0, update in batch 249000/???, loss: 6.09995174407959\n",
"epoch: 0, update in batch 250000/???, loss: 6.312221050262451\n",
"epoch: 0, update in batch 251000/???, loss: 5.853858470916748\n",
"epoch: 0, update in batch 252000/???, loss: 5.886989593505859\n",
"epoch: 0, update in batch 253000/???, loss: 5.801788330078125\n",
"epoch: 0, update in batch 254000/???, loss: 6.032407760620117\n",
"epoch: 0, update in batch 255000/???, loss: 7.480917453765869\n",
"epoch: 0, update in batch 256000/???, loss: 6.578718662261963\n",
"epoch: 0, update in batch 257000/???, loss: 6.344462871551514\n",
"epoch: 0, update in batch 258000/???, loss: 5.939858436584473\n",
"epoch: 0, update in batch 259000/???, loss: 5.181772232055664\n",
"epoch: 0, update in batch 260000/???, loss: 6.640598297119141\n",
"epoch: 0, update in batch 261000/???, loss: 7.189258575439453\n",
"epoch: 0, update in batch 262000/???, loss: 6.2269287109375\n",
"epoch: 0, update in batch 263000/???, loss: 5.8858795166015625\n",
"epoch: 0, update in batch 264000/???, loss: 6.333988666534424\n",
"epoch: 0, update in batch 265000/???, loss: 6.313681602478027\n",
"epoch: 0, update in batch 266000/???, loss: 5.485809803009033\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 267000/???, loss: 6.250800609588623\n",
"epoch: 0, update in batch 268000/???, loss: 6.676806449890137\n",
"epoch: 0, update in batch 269000/???, loss: 5.6487932205200195\n",
"epoch: 0, update in batch 270000/???, loss: 6.648938179016113\n",
"epoch: 0, update in batch 271000/???, loss: 6.26931095123291\n",
"epoch: 0, update in batch 272000/???, loss: 5.343636512756348\n",
"epoch: 0, update in batch 273000/???, loss: 7.051453590393066\n",
"epoch: 0, update in batch 274000/???, loss: 4.578436851501465\n",
"epoch: 0, update in batch 275000/???, loss: 5.400996685028076\n",
"epoch: 0, update in batch 276000/???, loss: 6.129047870635986\n",
"epoch: 0, update in batch 277000/???, loss: 7.549851894378662\n",
"epoch: 0, update in batch 278000/???, loss: 6.093559265136719\n",
"epoch: 0, update in batch 279000/???, loss: 5.6921467781066895\n",
"epoch: 0, update in batch 280000/???, loss: 5.789463996887207\n",
"epoch: 0, update in batch 281000/???, loss: 5.681942939758301\n",
"epoch: 0, update in batch 282000/???, loss: 6.750497341156006\n",
"epoch: 0, update in batch 283000/???, loss: 5.960292339324951\n",
"epoch: 0, update in batch 284000/???, loss: 6.160388469696045\n",
"epoch: 0, update in batch 285000/???, loss: 7.137685298919678\n",
"epoch: 0, update in batch 286000/???, loss: 7.7431464195251465\n",
"epoch: 0, update in batch 287000/???, loss: 5.229738712310791\n",
"epoch: 0, update in batch 288000/???, loss: 6.654232025146484\n",
"epoch: 0, update in batch 289000/???, loss: 6.229329586029053\n",
"epoch: 0, update in batch 290000/???, loss: 7.188180446624756\n",
"epoch: 0, update in batch 291000/???, loss: 6.244111061096191\n",
"epoch: 0, update in batch 292000/???, loss: 7.199154853820801\n",
"epoch: 0, update in batch 293000/???, loss: 7.1866865158081055\n",
"epoch: 0, update in batch 294000/???, loss: 6.574115753173828\n",
"epoch: 0, update in batch 295000/???, loss: 6.487138271331787\n",
"epoch: 0, update in batch 296000/???, loss: 5.813161849975586\n",
"epoch: 0, update in batch 297000/???, loss: 6.159414291381836\n",
"epoch: 0, update in batch 298000/???, loss: 7.256616115570068\n",
"epoch: 0, update in batch 299000/???, loss: 7.511231899261475\n",
"epoch: 0, update in batch 300000/???, loss: 6.148821830749512\n",
"epoch: 0, update in batch 301000/???, loss: 7.108969211578369\n",
"epoch: 0, update in batch 302000/???, loss: 6.528176307678223\n",
"epoch: 0, update in batch 303000/???, loss: 6.276839256286621\n",
"epoch: 0, update in batch 304000/???, loss: 6.484020233154297\n",
"epoch: 0, update in batch 305000/???, loss: 6.38557767868042\n",
"epoch: 0, update in batch 306000/???, loss: 7.068814754486084\n",
"epoch: 0, update in batch 307000/???, loss: 5.844017505645752\n",
"epoch: 0, update in batch 308000/???, loss: 4.25785493850708\n",
"epoch: 0, update in batch 309000/???, loss: 6.709985256195068\n",
"epoch: 0, update in batch 310000/???, loss: 6.543104648590088\n",
"epoch: 0, update in batch 311000/???, loss: 6.675828456878662\n",
"epoch: 0, update in batch 312000/???, loss: 5.82969856262207\n",
"epoch: 0, update in batch 313000/???, loss: 6.05246639251709\n",
"epoch: 0, update in batch 314000/???, loss: 7.2366156578063965\n",
"epoch: 0, update in batch 315000/???, loss: 5.039820194244385\n",
"epoch: 0, update in batch 316000/???, loss: 5.943173885345459\n",
"epoch: 0, update in batch 317000/???, loss: 6.2509002685546875\n",
"epoch: 0, update in batch 318000/???, loss: 6.451228141784668\n",
"epoch: 0, update in batch 319000/???, loss: 6.6381049156188965\n",
"epoch: 0, update in batch 320000/???, loss: 6.570329189300537\n",
"epoch: 0, update in batch 321000/???, loss: 5.376622200012207\n",
"epoch: 0, update in batch 322000/???, loss: 6.487462520599365\n",
"epoch: 0, update in batch 323000/???, loss: 6.676497459411621\n",
"epoch: 0, update in batch 324000/???, loss: 6.283420562744141\n",
"epoch: 0, update in batch 325000/???, loss: 6.164648532867432\n",
"epoch: 0, update in batch 326000/???, loss: 6.839153289794922\n",
"epoch: 0, update in batch 327000/???, loss: 6.435141086578369\n",
"epoch: 0, update in batch 328000/???, loss: 6.160590171813965\n",
"epoch: 0, update in batch 329000/???, loss: 5.876160621643066\n",
"epoch: 0, update in batch 330000/???, loss: 6.47445011138916\n",
"epoch: 0, update in batch 331000/???, loss: 6.294231414794922\n",
"epoch: 0, update in batch 332000/???, loss: 6.099027156829834\n",
"epoch: 0, update in batch 333000/???, loss: 6.986542701721191\n",
"epoch: 0, update in batch 334000/???, loss: 7.018263816833496\n",
"epoch: 0, update in batch 335000/???, loss: 6.906959533691406\n",
"epoch: 0, update in batch 336000/???, loss: 6.12356424331665\n",
"epoch: 0, update in batch 337000/???, loss: 6.316069602966309\n",
"epoch: 0, update in batch 338000/???, loss: 6.908566474914551\n",
"epoch: 0, update in batch 339000/???, loss: 5.628839492797852\n",
"epoch: 0, update in batch 340000/???, loss: 7.069979667663574\n",
"epoch: 0, update in batch 341000/???, loss: 5.350735187530518\n",
"epoch: 0, update in batch 342000/???, loss: 5.377245903015137\n",
"epoch: 0, update in batch 343000/???, loss: 5.2340989112854\n",
"epoch: 0, update in batch 344000/???, loss: 6.087491512298584\n",
"epoch: 0, update in batch 345000/???, loss: 6.162985801696777\n",
"epoch: 0, update in batch 346000/???, loss: 5.655491828918457\n",
"epoch: 0, update in batch 347000/???, loss: 5.311842918395996\n",
"epoch: 0, update in batch 348000/???, loss: 7.577170372009277\n",
"epoch: 0, update in batch 349000/???, loss: 6.730460166931152\n",
"epoch: 0, update in batch 350000/???, loss: 6.782231330871582\n",
"epoch: 0, update in batch 351000/???, loss: 6.789486885070801\n",
"epoch: 0, update in batch 352000/???, loss: 5.473587989807129\n",
"epoch: 0, update in batch 353000/???, loss: 5.531443119049072\n",
"epoch: 0, update in batch 354000/???, loss: 7.220989227294922\n",
"epoch: 0, update in batch 355000/???, loss: 5.954288005828857\n",
"epoch: 0, update in batch 356000/???, loss: 4.112783432006836\n",
"epoch: 0, update in batch 357000/???, loss: 5.409672737121582\n",
"epoch: 0, update in batch 358000/???, loss: 6.408724784851074\n",
"epoch: 0, update in batch 359000/???, loss: 6.744941711425781\n",
"epoch: 0, update in batch 360000/???, loss: 6.218225479125977\n",
"epoch: 0, update in batch 361000/???, loss: 6.071394920349121\n",
"epoch: 0, update in batch 362000/???, loss: 6.137121677398682\n",
"epoch: 0, update in batch 363000/???, loss: 5.876864433288574\n",
"epoch: 0, update in batch 364000/???, loss: 7.715426445007324\n",
"epoch: 0, update in batch 365000/???, loss: 6.217362880706787\n",
"epoch: 0, update in batch 366000/???, loss: 6.741396903991699\n",
"epoch: 0, update in batch 367000/???, loss: 6.4564313888549805\n",
"epoch: 0, update in batch 368000/???, loss: 6.994439601898193\n",
"epoch: 0, update in batch 369000/???, loss: 6.061278820037842\n",
"epoch: 0, update in batch 370000/???, loss: 4.894576549530029\n",
"epoch: 0, update in batch 371000/???, loss: 6.351264953613281\n",
"epoch: 0, update in batch 372000/???, loss: 6.826904296875\n",
"epoch: 0, update in batch 373000/???, loss: 6.090312480926514\n",
"epoch: 0, update in batch 374000/???, loss: 5.797528266906738\n",
"epoch: 0, update in batch 375000/???, loss: 7.3235673904418945\n",
"epoch: 0, update in batch 376000/???, loss: 5.5752973556518555\n",
"epoch: 0, update in batch 377000/???, loss: 6.29438591003418\n",
"epoch: 0, update in batch 378000/???, loss: 5.238917827606201\n",
"epoch: 0, update in batch 379000/???, loss: 5.542972564697266\n",
"epoch: 0, update in batch 380000/???, loss: 6.5024614334106445\n",
"epoch: 0, update in batch 381000/???, loss: 6.918997287750244\n",
"epoch: 0, update in batch 382000/???, loss: 5.694029331207275\n",
"epoch: 0, update in batch 383000/???, loss: 7.109190940856934\n",
"epoch: 0, update in batch 384000/???, loss: 5.214654445648193\n",
"epoch: 0, update in batch 385000/???, loss: 7.055975437164307\n",
"epoch: 0, update in batch 386000/???, loss: 6.443846225738525\n",
"epoch: 0, update in batch 387000/???, loss: 5.544674873352051\n",
"epoch: 0, update in batch 388000/???, loss: 6.936171531677246\n",
"epoch: 0, update in batch 389000/???, loss: 6.646860599517822\n",
"epoch: 0, update in batch 390000/???, loss: 6.193584442138672\n",
"epoch: 0, update in batch 391000/???, loss: 5.9077558517456055\n",
"epoch: 0, update in batch 392000/???, loss: 5.029908657073975\n",
"epoch: 0, update in batch 393000/???, loss: 6.725222587585449\n",
"epoch: 0, update in batch 394000/???, loss: 6.814855098724365\n",
"epoch: 0, update in batch 395000/???, loss: 7.396543979644775\n",
"epoch: 0, update in batch 396000/???, loss: 6.993375301361084\n",
"epoch: 0, update in batch 397000/???, loss: 6.224326133728027\n",
"epoch: 0, update in batch 398000/???, loss: 6.301025390625\n",
"epoch: 0, update in batch 399000/???, loss: 6.707190036773682\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 400000/???, loss: 6.7646660804748535\n",
"epoch: 0, update in batch 401000/???, loss: 5.308177471160889\n",
"epoch: 0, update in batch 402000/???, loss: 8.35996150970459\n",
"epoch: 0, update in batch 403000/???, loss: 5.825610160827637\n",
"epoch: 0, update in batch 404000/???, loss: 6.310220718383789\n",
"epoch: 0, update in batch 405000/???, loss: 5.759210109710693\n",
"epoch: 0, update in batch 406000/???, loss: 6.32699728012085\n",
"epoch: 0, update in batch 407000/???, loss: 5.659378528594971\n",
"epoch: 0, update in batch 408000/???, loss: 6.216103553771973\n",
"epoch: 0, update in batch 409000/???, loss: 5.666914463043213\n",
"epoch: 0, update in batch 410000/???, loss: 6.419122219085693\n",
"epoch: 0, update in batch 411000/???, loss: 5.372750282287598\n",
"epoch: 0, update in batch 412000/???, loss: 6.839580535888672\n",
"epoch: 0, update in batch 413000/???, loss: 6.7682647705078125\n",
"epoch: 0, update in batch 414000/???, loss: 5.951648235321045\n",
"epoch: 0, update in batch 415000/???, loss: 6.181953430175781\n",
"epoch: 0, update in batch 416000/???, loss: 5.475704669952393\n",
"epoch: 0, update in batch 417000/???, loss: 6.383082866668701\n",
"epoch: 0, update in batch 418000/???, loss: 6.8107590675354\n",
"epoch: 0, update in batch 419000/???, loss: 5.753104209899902\n",
"epoch: 0, update in batch 420000/???, loss: 5.320840835571289\n",
"epoch: 0, update in batch 421000/???, loss: 7.377203464508057\n",
"epoch: 0, update in batch 422000/???, loss: 6.5706048011779785\n",
"epoch: 0, update in batch 423000/???, loss: 5.032872676849365\n",
"epoch: 0, update in batch 424000/???, loss: 5.781243324279785\n",
"epoch: 0, update in batch 425000/???, loss: 6.160118579864502\n"
]
}
],
"source": [
"train(train_dataset, model, 1, 64)"
]
},
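Editor's note on the '???' in the log above: the rewritten print avoids len(dataloader) because TrainDataset is an IterableDataset without __len__, for which len() raises TypeError (the older print that queried len(dataloader) only worked with the map-style Dataset). A minimal sketch of that behaviour, assuming the objects defined in this notebook:

    loader = DataLoader(train_dataset, batch_size=64)
    try:
        total = len(loader)   # IterableDataset without __len__
    except TypeError:
        total = '???'         # batch total unknown ahead of time
    print(f'total batches: {total}')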
{
"cell_type": "code",
"execution_count": 58,
"id": "36a1b802",
"metadata": {},
"outputs": [],
"source": [
"def predict_probs(tokens):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
" state_h = model.init_state(len(tokens))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
" x = torch.tensor([[train_dataset.key_to_index[w] if w in key_to_index else train_dataset.key_to_index['<unk>'] for w in tokens]]).to(device)\n",
" y_pred, state_h = model(x)\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
" probs = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=probs)\n",
" \n",
" return words"
" top_words = []\n",
" for index in range(len(probs)):\n",
" if len(top_words) < 30:\n",
" top_words.append((probs[index], [index]))\n",
" else:\n",
" worst_word = None\n",
" for word in top_words:\n",
" if not worst_word:\n",
" worst_word = word\n",
" else:\n",
" if word[0] < worst_word[0]:\n",
" worst_word = word\n",
" if worst_word[0] < probs[index] and index != len(probs) - 1:\n",
" top_words.remove(worst_word)\n",
" top_words.append((probs[index], [index]))\n",
" \n",
" prediction = ''\n",
" sum_prob = 0.0\n",
" for word in top_words:\n",
" sum_prob += word[0]\n",
" word_index = word[0]\n",
" word_text = index_to_key[word[1][0]]\n",
" prediction += f'{word_text}:{word_index} '\n",
" prediction += f':{1 - sum_prob}'\n",
" \n",
" return prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcd99ea4",
"execution_count": 56,
"id": "155636b5",
"metadata": {},
"outputs": [],
"source": [
"train(data, model, 10, 128)"
"dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"test_data = pd.read_csv('test-A/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df052344",
"execution_count": 59,
"id": "99b1d944",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), 'model.pt')"
"with open('dev-0/out.tsv', 'w') as file:\n",
" for index, row in dev_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e87339b",
"execution_count": 60,
"id": "186c3269",
"metadata": {},
"outputs": [],
"source": []
"source": [
"with open('test-A/out.tsv', 'w') as file:\n",
" for index, row in test_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
}
],
"metadata": {
259 run.py (new file)
@@ -0,0 +1,259 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize


# In[2]:


torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# In[3]:


def clean_text(text):
    text = text.lower().replace('-\\n', '').replace('\\n', ' ')
    text = re.sub(r'\p{P}', '', text)
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")

    return text
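Editor's note with a minimal usage sketch (assumes nltk's 'punkt' tokenizer data is installed, e.g. via nltk.download('punkt')). One behavioural detail: re.sub(r'\p{P}', '', text) strips apostrophes before the contraction rules run, so replacements like "'t" -> " not" can no longer match; moving them above the re.sub would make them take effect.

    print(clean_text("He didn't like it."))
    # -> 'he didnt like it'  (apostrophe already gone, so the 't rule never fires)
    print(word_tokenize(clean_text("He didn't like it.")))
    # -> ['he', 'didnt', 'like', 'it']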


# In[4]:


train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)


# In[5]:


class TrainCorpus:
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            yield word_tokenize(text)
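Editor's note: TrainCorpus is a class with __iter__ rather than a plain generator because gensim may pass over the corpus more than once, and a generator would be exhausted after the first pass while the class restarts cleanly. A tiny sketch of the difference, assuming the objects above:

    gen = (tokens for tokens in [['a'], ['b']])
    print(list(gen), list(gen))                    # second pass is empty: [['a'], ['b']] []
    corpus = TrainCorpus(train_data.head(2))
    print(len(list(corpus)), len(list(corpus)))    # both passes yield rows: 2 2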


# In[6]:


train_sentences = TrainCorpus(train_data.head(100000))
w2v_model = Word2Vec(vector_size=100, min_count=10)


# In[7]:


w2v_model.build_vocab(corpus_iterable=train_sentences)

key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key

index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1

vocab_size = len(index_to_key)
print(vocab_size)
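Editor's note: only build_vocab is called and w2v_model.train() never runs, so the Word2Vec object supplies just a frequency-filtered vocabulary (min_count=10) and the two token/index mappings, not trained vectors. The manual '<unk>' append keeps the mappings mutually consistent, which a quick sanity-check sketch (assuming the variables above) confirms:

    assert index_to_key[key_to_index['<unk>']] == '<unk>'
    assert all(key_to_index[w] == i for i, w in enumerate(index_to_key))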


# In[8]:


class TrainDataset(torch.utils.data.IterableDataset):
    def __init__(self, data, index_to_key, key_to_index):
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            tokens = word_tokenize(text)
            for i in range(5, len(tokens), 1):
                input_context = tokens[i-5:i]
                target_context = tokens[i-4:i+1]
                # gap_word = tokens[i]

                input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in input_context]
                target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in target_context]
                # word_index = self.key_to_index[gap_word] if gap_word in self.key_to_index else self.key_to_index['<unk>']
                # word_embed = np.concatenate([np.zeros(word_index), np.ones(1), np.zeros(vocab_size - word_index - 1)])

                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)
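Each row therefore expands into overlapping 5-token windows, with the target shifted one token to the right. A worked micro-example with a hypothetical token list:

    tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat', 'today']
    for i in range(5, len(tokens), 1):
        print(tokens[i-5:i], '->', tokens[i-4:i+1])
    # ['the', 'cat', 'sat', 'on', 'the'] -> ['cat', 'sat', 'on', 'the', 'mat']
    # ['cat', 'sat', 'on', 'the', 'mat'] -> ['sat', 'on', 'the', 'mat', 'today']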


# In[9]:


class Model(nn.Module):
    def __init__(self, embed_size, vocab_size):
        super(Model, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.gru_size = 128
        self.num_layers = 2

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size, num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.gru_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embed(x)
        output, state = self.gru(embed, prev_state)
        logits = self.fc(output)
        probs = torch.softmax(logits, dim=1)
        return logits, state

    def init_state(self, sequence_length):
        zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)
        return (zeros, zeros)
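Editor's note, two observations rather than changes shipped in this commit: the probs computed in forward is dead code (and softmax over dim=1 normalizes the middle axis, not the vocabulary axis), and init_state still returns the (h, c) tuple inherited from the earlier LSTM version, while nn.GRU expects a single hidden tensor of shape (num_layers, batch, hidden). Nothing breaks only because the training and prediction code below call model(x) with prev_state=None. A hedged sketch of a GRU-correct init_state:

    def init_state(self, batch_size):
        # nn.GRU takes one hidden tensor, not an (h, c) pair
        return torch.zeros(self.num_layers, batch_size, self.gru_size).to(device)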


# In[10]:


from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            x = x.to(device)
            y = y.to(device)

            y_pred, state_h = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            if batch % 1000 == 0:
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
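Editor's note: the DataLoader yields x of shape (batch, seq), so self.embed(x) is (batch, seq, embed), while nn.GRU defaults to (seq, batch, feature) input; without batch_first=True the recurrence effectively runs across the batch dimension. The shapes still satisfy CrossEntropyLoss after the transpose, which is why training proceeds without an error. A small runnable sketch of the intended configuration:

    import torch
    from torch import nn

    gru = nn.GRU(input_size=100, hidden_size=128, num_layers=2,
                 dropout=0.2, batch_first=True)   # treat dim 0 as the batch
    out, h = gru(torch.randn(64, 5, 100))         # (batch, seq, embed)
    print(out.shape)                              # torch.Size([64, 5, 128])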


# In[11]:


train_dataset = TrainDataset(train_data.head(100000), index_to_key, key_to_index)


# In[12]:


model = Model(100, vocab_size).to(device)


# In[13]:


train(train_dataset, model, 1, 64)


# In[58]:


def predict_probs(tokens):
    model.eval()
    state_h = model.init_state(len(tokens))

    x = torch.tensor([[train_dataset.key_to_index[w] if w in key_to_index else train_dataset.key_to_index['<unk>'] for w in tokens]]).to(device)
    y_pred, state_h = model(x)

    last_word_logits = y_pred[0][-1]
    probs = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
    word_index = np.random.choice(len(last_word_logits), p=probs)

    top_words = []
    for index in range(len(probs)):
        if len(top_words) < 30:
            top_words.append((probs[index], [index]))
        else:
            worst_word = None
            for word in top_words:
                if not worst_word:
                    worst_word = word
                else:
                    if word[0] < worst_word[0]:
                        worst_word = word
            if worst_word[0] < probs[index] and index != len(probs) - 1:
                top_words.remove(worst_word)
                top_words.append((probs[index], [index]))

    prediction = ''
    sum_prob = 0.0
    for word in top_words:
        sum_prob += word[0]
        word_index = word[0]
        word_text = index_to_key[word[1][0]]
        prediction += f'{word_text}:{word_index} '
    prediction += f':{1 - sum_prob}'

    return prediction
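Editor's note: the hand-rolled scan above keeps the 30 highest-probability indices (while never letting the final index, the '<unk>' slot, displace an entry). The same selection can be written far more compactly; a numpy sketch, equivalent up to the '<unk>' handling and tie-breaking:

    import numpy as np

    def top_k(probs, k=30):
        idx = np.argpartition(probs, -k)[-k:]      # k largest, unordered
        return sorted(((probs[i], [i]) for i in idx), reverse=True)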


# In[56]:


dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)


# In[59]:


with open('dev-0/out.tsv', 'w') as file:
    for index, row in dev_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')
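Each out.tsv line therefore carries up to 30 word:probability pairs followed by a bare :mass entry for the leftover probability, so the masses sum to 1. A parsing sketch with a hypothetical line:

    line = 'the:0.21 of:0.12 and:0.05 :0.62'   # hypothetical example line
    pairs = [tok.rsplit(':', 1) for tok in line.split(' ')]
    assert abs(sum(float(p) for _, p in pairs) - 1.0) < 1e-9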


# In[60]:


with open('test-A/out.tsv', 'w') as file:
    for index, row in test_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')
7414 test-A/out.tsv (new file)
File diff suppressed because it is too large.