final solution

This commit is contained in:
Łukasz Jędyk 2022-05-28 15:36:41 +02:00
parent eacc27a0d2
commit 96313d4915
5 changed files with 18854 additions and 140 deletions

3
.gitignore vendored

@@ -1,2 +1,3 @@
.ipynb_checkpoints/
processed_train.txt
processed_train.txt
model/

10519
dev-0/out.tsv Normal file

File diff suppressed because it is too large

799
run.ipynb

@@ -3,22 +3,24 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "56fb2fdb",
"id": "03de852a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import regex as re\n",
"import torch\n",
"import csv\n",
"import torch\n",
"from torch import nn\n",
"from collections import Counter"
"from gensim.models import Word2Vec\n",
"from nltk.tokenize import word_tokenize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "74f269c9",
"id": "73497953",
"metadata": {},
"outputs": [],
"source": [
@@ -29,81 +31,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "c939f600",
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, sequence_length, file_path):\n",
" self.file_path = file_path\n",
" self.sequence_length = sequence_length\n",
" self.words = self.load()\n",
" self.uniq_words = self.get_uniq_words()\n",
"\n",
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
" \n",
" self.index_to_word[-1] = '<UNK>'\n",
" self.word_to_index['<UNK>'] = -1\n",
"\n",
" self.words_indexes = [self.word_to_index[w] if w in self.uniq_words else self.word_to_index['<UNK>'] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open(self.file_path, 'r') as f_in:\n",
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
" text = ' '.join(text).lower()\n",
" text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n",
" text = text.split(' ')\n",
" return text\n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words).most_common(250000)\n",
" word_counts = dict(word_counts)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
" return len(self.words_indexes) - self.sequence_length\n",
"\n",
" def __getitem__(self, index):\n",
" return (\n",
" torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n",
" torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9f178921",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.lstm_size = 128\n",
" self.embedding_dim = 256\n",
" self.num_layers = 3\n",
"\n",
" self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embedding_dim)\n",
" self.lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2)\n",
" self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embedding(x)\n",
" output, state = self.lstm(embed, prev_state)\n",
" logits = self.fc(output)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device)\n",
" return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "39a27465",
"id": "4227ef55",
"metadata": {},
"outputs": [],
"source": [
@@ -117,8 +45,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2f25401d",
"execution_count": 4,
"id": "758cf94a",
"metadata": {},
"outputs": [],
"source": [
@@ -126,142 +54,735 @@
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_data = train_data[[6, 7]]\n",
"train_data = pd.concat([train_data, train_labels], axis=1)\n",
"\n",
"train_data['text'] = train_data[6] + train_data[0] + train_data[7]\n",
"train_data = train_data[['text']]\n",
"\n",
"with open('processed_train.txt', 'w', encoding='utf-8') as file:\n",
" for _, row in train_data.iterrows():\n",
" text = clean_text(str(row['text']))\n",
" file.write(text + '\\n')"
"train_data = pd.concat([train_data, train_labels], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "384922c7",
"metadata": {},
"outputs": [],
"source": [
"class TrainCorpus:\n",
" def __init__(self, data):\n",
" self.data = data\n",
" \n",
" def __iter__(self):\n",
" for _, row in self.data.iterrows():\n",
" text = str(row[6]) + str(row[0]) + str(row[7])\n",
" text = clean_text(text)\n",
" yield word_tokenize(text)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3da5f19b",
"metadata": {},
"outputs": [],
"source": [
"train_sentences = TrainCorpus(train_data.head(100000))\n",
"w2v_model = Word2Vec(vector_size=100, min_count=10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bb55ce42",
"id": "183d43be",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mprocessed_train.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36mDataset.__init__\u001b[1;34m(self, sequence_length, file_path)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mword_to_index\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
"name": "stdout",
"output_type": "stream",
"text": [
"97122\n"
]
}
],
"source": [
"data = Dataset(5, 'processed_train.txt')"
"w2v_model.build_vocab(corpus_iterable=train_sentences)\n",
"\n",
"key_to_index = w2v_model.wv.key_to_index\n",
"index_to_key = w2v_model.wv.index_to_key\n",
"\n",
"index_to_key.append('<unk>')\n",
"key_to_index['<unk>'] = len(index_to_key) - 1\n",
"\n",
"vocab_size = len(index_to_key)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a6872ca",
"execution_count": 8,
"id": "e63dd9fe",
"metadata": {},
"outputs": [],
"source": [
"data_vocab_size = len(data.uniq_words)\n",
"print(data_vocab_size)"
"class TrainDataset(torch.utils.data.IterableDataset):\n",
" def __init__(self, data, index_to_key, key_to_index):\n",
" self.data = data\n",
" self.index_to_key = index_to_key\n",
" self.key_to_index = key_to_index\n",
" self.vocab_size = len(key_to_index)\n",
"\n",
" def __iter__(self):\n",
" for _, row in self.data.iterrows():\n",
" text = str(row[6]) + str(row[0]) + str(row[7])\n",
" text = clean_text(text)\n",
" tokens = word_tokenize(text)\n",
" for i in range(5, len(tokens), 1):\n",
" input_context = tokens[i-5:i]\n",
" target_context = tokens[i-4:i+1]\n",
" #gap_word = tokens[i]\n",
" \n",
" input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in input_context]\n",
" target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in target_context]\n",
" #word_index = self.key_to_index[gap_word] if gap_word in self.key_to_index else self.key_to_index['<unk>']\n",
" #word_embed = np.concatenate([np.zeros(word_index), np.ones(1), np.zeros(vocab_size - word_index - 1)])\n",
" \n",
" yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc097469",
"execution_count": 9,
"id": "7c60ddc1",
"metadata": {},
"outputs": [],
"source": [
"model = Model(vocab_size = data_vocab_size).to(device)"
"class Model(nn.Module):\n",
" def __init__(self, embed_size, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.embed_size = embed_size\n",
" self.vocab_size = vocab_size\n",
" self.gru_size = 128\n",
" self.num_layers = 2\n",
" \n",
" self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n",
" self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size, num_layers=self.num_layers, dropout=0.2)\n",
" self.fc = nn.Linear(self.gru_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embed(x)\n",
" output, state = self.gru(embed, prev_state)\n",
" logits = self.fc(output)\n",
" probs = torch.softmax(logits, dim=1)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)\n",
" return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9fcdfe7",
"execution_count": 10,
"id": "1c7b8fab",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"\n",
"def train(dataset, model, max_epochs, batch_size):\n",
" model.train()\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
" optimizer = Adam(model.parameters(), lr=0.001)\n",
"\n",
" for epoch in range(max_epochs):\n",
" for batch, (x, y) in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
" \n",
" x = x.to(device)\n",
" y = y.to(device)\n",
"\n",
" y_pred, (state_h, state_c) = model(x)\n",
" \n",
" y_pred, state_h = model(x)\n",
" loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() "
" \n",
" if batch % 1000 == 0:\n",
" print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38d52d9a",
"execution_count": 11,
"id": "3531d21d",
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
"train_dataset = TrainDataset(train_data.head(100000), index_to_key, key_to_index)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f72f1f6d",
"metadata": {},
"outputs": [],
"source": [
"model = Model(100, vocab_size).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d608d9fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.47425365447998\n",
"epoch: 0, update in batch 1000/???, loss: 7.042205810546875\n",
"epoch: 0, update in batch 2000/???, loss: 7.235440731048584\n",
"epoch: 0, update in batch 3000/???, loss: 7.251269340515137\n",
"epoch: 0, update in batch 4000/???, loss: 6.944191932678223\n",
"epoch: 0, update in batch 5000/???, loss: 6.263372421264648\n",
"epoch: 0, update in batch 6000/???, loss: 6.181947231292725\n",
"epoch: 0, update in batch 7000/???, loss: 6.508013725280762\n",
"epoch: 0, update in batch 8000/???, loss: 6.658236026763916\n",
"epoch: 0, update in batch 9000/???, loss: 6.536279201507568\n",
"epoch: 0, update in batch 10000/???, loss: 6.802626609802246\n",
"epoch: 0, update in batch 11000/???, loss: 6.961029052734375\n",
"epoch: 0, update in batch 12000/???, loss: 7.713824272155762\n",
"epoch: 0, update in batch 13000/???, loss: 8.100411415100098\n",
"epoch: 0, update in batch 14000/???, loss: 6.457145690917969\n",
"epoch: 0, update in batch 15000/???, loss: 6.850286960601807\n",
"epoch: 0, update in batch 16000/???, loss: 6.794063568115234\n",
"epoch: 0, update in batch 17000/???, loss: 6.311314582824707\n",
"epoch: 0, update in batch 18000/???, loss: 6.611917018890381\n",
"epoch: 0, update in batch 19000/???, loss: 5.679810523986816\n",
"epoch: 0, update in batch 20000/???, loss: 7.110655307769775\n",
"epoch: 0, update in batch 21000/???, loss: 7.170722961425781\n",
"epoch: 0, update in batch 22000/???, loss: 6.584908485412598\n",
"epoch: 0, update in batch 23000/???, loss: 7.224095344543457\n",
"epoch: 0, update in batch 24000/???, loss: 5.827445983886719\n",
"epoch: 0, update in batch 25000/???, loss: 6.444586753845215\n",
"epoch: 0, update in batch 26000/???, loss: 6.149054527282715\n",
"epoch: 0, update in batch 27000/???, loss: 6.259871482849121\n",
"epoch: 0, update in batch 28000/???, loss: 5.789839744567871\n",
"epoch: 0, update in batch 29000/???, loss: 7.025563716888428\n",
"epoch: 0, update in batch 30000/???, loss: 7.265492916107178\n",
"epoch: 0, update in batch 31000/???, loss: 4.921586036682129\n",
"epoch: 0, update in batch 32000/???, loss: 6.467754364013672\n",
"epoch: 0, update in batch 33000/???, loss: 7.393715858459473\n",
"epoch: 0, update in batch 34000/???, loss: 6.9696760177612305\n",
"epoch: 0, update in batch 35000/???, loss: 7.276318550109863\n",
"epoch: 0, update in batch 36000/???, loss: 7.011231899261475\n",
"epoch: 0, update in batch 37000/???, loss: 7.029260158538818\n",
"epoch: 0, update in batch 38000/???, loss: 6.723126411437988\n",
"epoch: 0, update in batch 39000/???, loss: 6.828773498535156\n",
"epoch: 0, update in batch 40000/???, loss: 6.069770336151123\n",
"epoch: 0, update in batch 41000/???, loss: 6.651298522949219\n",
"epoch: 0, update in batch 42000/???, loss: 7.455380916595459\n",
"epoch: 0, update in batch 43000/???, loss: 5.594773769378662\n",
"epoch: 0, update in batch 44000/???, loss: 6.102865219116211\n",
"epoch: 0, update in batch 45000/???, loss: 6.04202127456665\n",
"epoch: 0, update in batch 46000/???, loss: 6.472177982330322\n",
"epoch: 0, update in batch 47000/???, loss: 5.870923042297363\n",
"epoch: 0, update in batch 48000/???, loss: 6.286317348480225\n",
"epoch: 0, update in batch 49000/???, loss: 7.157052516937256\n",
"epoch: 0, update in batch 50000/???, loss: 5.888463020324707\n",
"epoch: 0, update in batch 51000/???, loss: 5.609915733337402\n",
"epoch: 0, update in batch 52000/???, loss: 6.565190315246582\n",
"epoch: 0, update in batch 53000/???, loss: 6.4924468994140625\n",
"epoch: 0, update in batch 54000/???, loss: 6.856420040130615\n",
"epoch: 0, update in batch 55000/???, loss: 7.389428615570068\n",
"epoch: 0, update in batch 56000/???, loss: 5.927685260772705\n",
"epoch: 0, update in batch 57000/???, loss: 7.4227423667907715\n",
"epoch: 0, update in batch 58000/???, loss: 6.46466064453125\n",
"epoch: 0, update in batch 59000/???, loss: 6.586294651031494\n",
"epoch: 0, update in batch 60000/???, loss: 5.797083854675293\n",
"epoch: 0, update in batch 61000/???, loss: 4.825878143310547\n",
"epoch: 0, update in batch 62000/???, loss: 6.911933898925781\n",
"epoch: 0, update in batch 63000/???, loss: 7.684759616851807\n",
"epoch: 0, update in batch 64000/???, loss: 5.716580390930176\n",
"epoch: 0, update in batch 65000/???, loss: 6.1738667488098145\n",
"epoch: 0, update in batch 66000/???, loss: 6.219714164733887\n",
"epoch: 0, update in batch 67000/???, loss: 5.4024128913879395\n",
"epoch: 0, update in batch 68000/???, loss: 6.912312984466553\n",
"epoch: 0, update in batch 69000/???, loss: 6.703289031982422\n",
"epoch: 0, update in batch 70000/???, loss: 7.375630855560303\n",
"epoch: 0, update in batch 71000/???, loss: 5.757082462310791\n",
"epoch: 0, update in batch 72000/???, loss: 5.992405414581299\n",
"epoch: 0, update in batch 73000/???, loss: 6.706838130950928\n",
"epoch: 0, update in batch 74000/???, loss: 7.376870155334473\n",
"epoch: 0, update in batch 75000/???, loss: 6.676860809326172\n",
"epoch: 0, update in batch 76000/???, loss: 5.904101848602295\n",
"epoch: 0, update in batch 77000/???, loss: 6.776932716369629\n",
"epoch: 0, update in batch 78000/???, loss: 5.682181358337402\n",
"epoch: 0, update in batch 79000/???, loss: 6.211178302764893\n",
"epoch: 0, update in batch 80000/???, loss: 6.366950035095215\n",
"epoch: 0, update in batch 81000/???, loss: 5.25206184387207\n",
"epoch: 0, update in batch 82000/???, loss: 6.30997371673584\n",
"epoch: 0, update in batch 83000/???, loss: 6.351908206939697\n",
"epoch: 0, update in batch 84000/???, loss: 7.659114837646484\n",
"epoch: 0, update in batch 85000/???, loss: 6.5041704177856445\n",
"epoch: 0, update in batch 86000/???, loss: 6.770291328430176\n",
"epoch: 0, update in batch 87000/???, loss: 6.530011177062988\n",
"epoch: 0, update in batch 88000/???, loss: 6.317249298095703\n",
"epoch: 0, update in batch 89000/???, loss: 6.191559314727783\n",
"epoch: 0, update in batch 90000/???, loss: 5.79150390625\n",
"epoch: 0, update in batch 91000/???, loss: 6.356796741485596\n",
"epoch: 0, update in batch 92000/???, loss: 7.3577141761779785\n",
"epoch: 0, update in batch 93000/???, loss: 6.529308319091797\n",
"epoch: 0, update in batch 94000/???, loss: 7.740485191345215\n",
"epoch: 0, update in batch 95000/???, loss: 6.348109245300293\n",
"epoch: 0, update in batch 96000/???, loss: 6.032902717590332\n",
"epoch: 0, update in batch 97000/???, loss: 4.505112648010254\n",
"epoch: 0, update in batch 98000/???, loss: 6.946290493011475\n",
"epoch: 0, update in batch 99000/???, loss: 6.237973213195801\n",
"epoch: 0, update in batch 100000/???, loss: 6.963421821594238\n",
"epoch: 0, update in batch 101000/???, loss: 5.309017181396484\n",
"epoch: 0, update in batch 102000/???, loss: 6.242384910583496\n",
"epoch: 0, update in batch 103000/???, loss: 6.8203558921813965\n",
"epoch: 0, update in batch 104000/???, loss: 6.242025852203369\n",
"epoch: 0, update in batch 105000/???, loss: 6.765100002288818\n",
"epoch: 0, update in batch 106000/???, loss: 6.8838043212890625\n",
"epoch: 0, update in batch 107000/???, loss: 6.856662750244141\n",
"epoch: 0, update in batch 108000/???, loss: 6.379549503326416\n",
"epoch: 0, update in batch 109000/???, loss: 6.797707557678223\n",
"epoch: 0, update in batch 110000/???, loss: 7.2699689865112305\n",
"epoch: 0, update in batch 111000/???, loss: 7.040057182312012\n",
"epoch: 0, update in batch 112000/???, loss: 6.7861223220825195\n",
"epoch: 0, update in batch 113000/???, loss: 6.064489364624023\n",
"epoch: 0, update in batch 114000/???, loss: 6.095967769622803\n",
"epoch: 0, update in batch 115000/???, loss: 5.757347106933594\n",
"epoch: 0, update in batch 116000/???, loss: 6.529908657073975\n",
"epoch: 0, update in batch 117000/???, loss: 6.030801296234131\n",
"epoch: 0, update in batch 118000/???, loss: 6.179767608642578\n",
"epoch: 0, update in batch 119000/???, loss: 5.436234474182129\n",
"epoch: 0, update in batch 120000/???, loss: 7.342876434326172\n",
"epoch: 0, update in batch 121000/???, loss: 6.862719535827637\n",
"epoch: 0, update in batch 122000/???, loss: 6.491606712341309\n",
"epoch: 0, update in batch 123000/???, loss: 7.195406436920166\n",
"epoch: 0, update in batch 124000/???, loss: 5.481313228607178\n",
"epoch: 0, update in batch 125000/???, loss: 7.963885307312012\n",
"epoch: 0, update in batch 126000/???, loss: 6.479039669036865\n",
"epoch: 0, update in batch 127000/???, loss: 7.037934303283691\n",
"epoch: 0, update in batch 128000/???, loss: 5.903053283691406\n",
"epoch: 0, update in batch 129000/???, loss: 6.815878391265869\n",
"epoch: 0, update in batch 130000/???, loss: 6.497969150543213\n",
"epoch: 0, update in batch 131000/???, loss: 5.623625755310059\n",
"epoch: 0, update in batch 132000/???, loss: 7.118441104888916\n",
"epoch: 0, update in batch 133000/???, loss: 5.964345455169678\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 134000/???, loss: 6.112139701843262\n",
"epoch: 0, update in batch 135000/???, loss: 6.5865373611450195\n",
"epoch: 0, update in batch 136000/???, loss: 7.498536109924316\n",
"epoch: 0, update in batch 137000/???, loss: 7.124758243560791\n",
"epoch: 0, update in batch 138000/???, loss: 6.871796607971191\n",
"epoch: 0, update in batch 139000/???, loss: 5.8565263748168945\n",
"epoch: 0, update in batch 140000/???, loss: 5.723143577575684\n",
"epoch: 0, update in batch 141000/???, loss: 5.601426124572754\n",
"epoch: 0, update in batch 142000/???, loss: 5.495566368103027\n",
"epoch: 0, update in batch 143000/???, loss: 6.936192989349365\n",
"epoch: 0, update in batch 144000/???, loss: 6.1843671798706055\n",
"epoch: 0, update in batch 145000/???, loss: 6.886034965515137\n",
"epoch: 0, update in batch 146000/???, loss: 6.655320167541504\n",
"epoch: 0, update in batch 147000/???, loss: 6.46828556060791\n",
"epoch: 0, update in batch 148000/???, loss: 5.607057571411133\n",
"epoch: 0, update in batch 149000/???, loss: 7.182212829589844\n",
"epoch: 0, update in batch 150000/???, loss: 7.241323947906494\n",
"epoch: 0, update in batch 151000/???, loss: 7.308540344238281\n",
"epoch: 0, update in batch 152000/???, loss: 5.267911434173584\n",
"epoch: 0, update in batch 153000/???, loss: 5.895949363708496\n",
"epoch: 0, update in batch 154000/???, loss: 6.629178524017334\n",
"epoch: 0, update in batch 155000/???, loss: 4.9156012535095215\n",
"epoch: 0, update in batch 156000/???, loss: 7.181819915771484\n",
"epoch: 0, update in batch 157000/???, loss: 7.438391208648682\n",
"epoch: 0, update in batch 158000/???, loss: 6.406006813049316\n",
"epoch: 0, update in batch 159000/???, loss: 6.486207008361816\n",
"epoch: 0, update in batch 160000/???, loss: 7.041951656341553\n",
"epoch: 0, update in batch 161000/???, loss: 5.310082912445068\n",
"epoch: 0, update in batch 162000/???, loss: 6.9074387550354\n",
"epoch: 0, update in batch 163000/???, loss: 6.644919395446777\n",
"epoch: 0, update in batch 164000/???, loss: 6.011733055114746\n",
"epoch: 0, update in batch 165000/???, loss: 6.494180202484131\n",
"epoch: 0, update in batch 166000/???, loss: 5.390150547027588\n",
"epoch: 0, update in batch 167000/???, loss: 6.627297401428223\n",
"epoch: 0, update in batch 168000/???, loss: 6.9020209312438965\n",
"epoch: 0, update in batch 169000/???, loss: 7.317750453948975\n",
"epoch: 0, update in batch 170000/???, loss: 5.69993782043457\n",
"epoch: 0, update in batch 171000/???, loss: 6.658817291259766\n",
"epoch: 0, update in batch 172000/???, loss: 6.422945976257324\n",
"epoch: 0, update in batch 173000/???, loss: 5.822269439697266\n",
"epoch: 0, update in batch 174000/???, loss: 6.513391017913818\n",
"epoch: 0, update in batch 175000/???, loss: 5.886590957641602\n",
"epoch: 0, update in batch 176000/???, loss: 7.119387149810791\n",
"epoch: 0, update in batch 177000/???, loss: 6.933981418609619\n",
"epoch: 0, update in batch 178000/???, loss: 6.678143501281738\n",
"epoch: 0, update in batch 179000/???, loss: 6.890423774719238\n",
"epoch: 0, update in batch 180000/???, loss: 6.932961940765381\n",
"epoch: 0, update in batch 181000/???, loss: 6.650975704193115\n",
"epoch: 0, update in batch 182000/???, loss: 6.732748985290527\n",
"epoch: 0, update in batch 183000/???, loss: 6.064764976501465\n",
"epoch: 0, update in batch 184000/???, loss: 5.282295227050781\n",
"epoch: 0, update in batch 185000/???, loss: 6.569302558898926\n",
"epoch: 0, update in batch 186000/???, loss: 5.800485610961914\n",
"epoch: 0, update in batch 187000/???, loss: 6.175991058349609\n",
"epoch: 0, update in batch 188000/???, loss: 5.405575752258301\n",
"epoch: 0, update in batch 189000/???, loss: 6.191354751586914\n",
"epoch: 0, update in batch 190000/???, loss: 6.156663417816162\n",
"epoch: 0, update in batch 191000/???, loss: 6.937534332275391\n",
"epoch: 0, update in batch 192000/???, loss: 6.562686920166016\n",
"epoch: 0, update in batch 193000/???, loss: 6.639985084533691\n",
"epoch: 0, update in batch 194000/???, loss: 7.285438537597656\n",
"epoch: 0, update in batch 195000/???, loss: 6.528258323669434\n",
"epoch: 0, update in batch 196000/???, loss: 8.326434135437012\n",
"epoch: 0, update in batch 197000/???, loss: 6.781360626220703\n",
"epoch: 0, update in batch 198000/???, loss: 7.223299980163574\n",
"epoch: 0, update in batch 199000/???, loss: 6.411007881164551\n",
"epoch: 0, update in batch 200000/???, loss: 5.885635852813721\n",
"epoch: 0, update in batch 201000/???, loss: 5.706809043884277\n",
"epoch: 0, update in batch 202000/???, loss: 6.230217933654785\n",
"epoch: 0, update in batch 203000/???, loss: 7.056562900543213\n",
"epoch: 0, update in batch 204000/???, loss: 7.2273077964782715\n",
"epoch: 0, update in batch 205000/???, loss: 6.342462539672852\n",
"epoch: 0, update in batch 206000/???, loss: 6.556817054748535\n",
"epoch: 0, update in batch 207000/???, loss: 5.882349967956543\n",
"epoch: 0, update in batch 208000/???, loss: 6.755805015563965\n",
"epoch: 0, update in batch 209000/???, loss: 6.5045623779296875\n",
"epoch: 0, update in batch 210000/???, loss: 6.525590419769287\n",
"epoch: 0, update in batch 211000/???, loss: 6.49679708480835\n",
"epoch: 0, update in batch 212000/???, loss: 6.562323093414307\n",
"epoch: 0, update in batch 213000/???, loss: 5.227139472961426\n",
"epoch: 0, update in batch 214000/???, loss: 7.044825077056885\n",
"epoch: 0, update in batch 215000/???, loss: 6.002442359924316\n",
"epoch: 0, update in batch 216000/???, loss: 6.084803581237793\n",
"epoch: 0, update in batch 217000/???, loss: 7.425839900970459\n",
"epoch: 0, update in batch 218000/???, loss: 6.818853855133057\n",
"epoch: 0, update in batch 219000/???, loss: 7.0153374671936035\n",
"epoch: 0, update in batch 220000/???, loss: 6.219962120056152\n",
"epoch: 0, update in batch 221000/???, loss: 5.9975385665893555\n",
"epoch: 0, update in batch 222000/???, loss: 6.480047702789307\n",
"epoch: 0, update in batch 223000/???, loss: 6.405727386474609\n",
"epoch: 0, update in batch 224000/???, loss: 4.7763471603393555\n",
"epoch: 0, update in batch 225000/???, loss: 6.615710258483887\n",
"epoch: 0, update in batch 226000/???, loss: 6.385044574737549\n",
"epoch: 0, update in batch 227000/???, loss: 7.260453701019287\n",
"epoch: 0, update in batch 228000/???, loss: 6.9794135093688965\n",
"epoch: 0, update in batch 229000/???, loss: 6.235735893249512\n",
"epoch: 0, update in batch 230000/???, loss: 6.478426456451416\n",
"epoch: 0, update in batch 231000/???, loss: 6.181302547454834\n",
"epoch: 0, update in batch 232000/???, loss: 5.826043128967285\n",
"epoch: 0, update in batch 233000/???, loss: 5.9517011642456055\n",
"epoch: 0, update in batch 234000/???, loss: 8.0064697265625\n",
"epoch: 0, update in batch 235000/???, loss: 6.7822675704956055\n",
"epoch: 0, update in batch 236000/???, loss: 6.293349742889404\n",
"epoch: 0, update in batch 237000/???, loss: 6.442999362945557\n",
"epoch: 0, update in batch 238000/???, loss: 6.282561302185059\n",
"epoch: 0, update in batch 239000/???, loss: 7.166723728179932\n",
"epoch: 0, update in batch 240000/???, loss: 7.189905643463135\n",
"epoch: 0, update in batch 241000/???, loss: 8.462562561035156\n",
"epoch: 0, update in batch 242000/???, loss: 7.446291923522949\n",
"epoch: 0, update in batch 243000/???, loss: 6.382981777191162\n",
"epoch: 0, update in batch 244000/???, loss: 7.635994911193848\n",
"epoch: 0, update in batch 245000/???, loss: 6.635537147521973\n",
"epoch: 0, update in batch 246000/???, loss: 6.068560600280762\n",
"epoch: 0, update in batch 247000/???, loss: 6.193384170532227\n",
"epoch: 0, update in batch 248000/???, loss: 5.702363967895508\n",
"epoch: 0, update in batch 249000/???, loss: 6.09995174407959\n",
"epoch: 0, update in batch 250000/???, loss: 6.312221050262451\n",
"epoch: 0, update in batch 251000/???, loss: 5.853858470916748\n",
"epoch: 0, update in batch 252000/???, loss: 5.886989593505859\n",
"epoch: 0, update in batch 253000/???, loss: 5.801788330078125\n",
"epoch: 0, update in batch 254000/???, loss: 6.032407760620117\n",
"epoch: 0, update in batch 255000/???, loss: 7.480917453765869\n",
"epoch: 0, update in batch 256000/???, loss: 6.578718662261963\n",
"epoch: 0, update in batch 257000/???, loss: 6.344462871551514\n",
"epoch: 0, update in batch 258000/???, loss: 5.939858436584473\n",
"epoch: 0, update in batch 259000/???, loss: 5.181772232055664\n",
"epoch: 0, update in batch 260000/???, loss: 6.640598297119141\n",
"epoch: 0, update in batch 261000/???, loss: 7.189258575439453\n",
"epoch: 0, update in batch 262000/???, loss: 6.2269287109375\n",
"epoch: 0, update in batch 263000/???, loss: 5.8858795166015625\n",
"epoch: 0, update in batch 264000/???, loss: 6.333988666534424\n",
"epoch: 0, update in batch 265000/???, loss: 6.313681602478027\n",
"epoch: 0, update in batch 266000/???, loss: 5.485809803009033\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 267000/???, loss: 6.250800609588623\n",
"epoch: 0, update in batch 268000/???, loss: 6.676806449890137\n",
"epoch: 0, update in batch 269000/???, loss: 5.6487932205200195\n",
"epoch: 0, update in batch 270000/???, loss: 6.648938179016113\n",
"epoch: 0, update in batch 271000/???, loss: 6.26931095123291\n",
"epoch: 0, update in batch 272000/???, loss: 5.343636512756348\n",
"epoch: 0, update in batch 273000/???, loss: 7.051453590393066\n",
"epoch: 0, update in batch 274000/???, loss: 4.578436851501465\n",
"epoch: 0, update in batch 275000/???, loss: 5.400996685028076\n",
"epoch: 0, update in batch 276000/???, loss: 6.129047870635986\n",
"epoch: 0, update in batch 277000/???, loss: 7.549851894378662\n",
"epoch: 0, update in batch 278000/???, loss: 6.093559265136719\n",
"epoch: 0, update in batch 279000/???, loss: 5.6921467781066895\n",
"epoch: 0, update in batch 280000/???, loss: 5.789463996887207\n",
"epoch: 0, update in batch 281000/???, loss: 5.681942939758301\n",
"epoch: 0, update in batch 282000/???, loss: 6.750497341156006\n",
"epoch: 0, update in batch 283000/???, loss: 5.960292339324951\n",
"epoch: 0, update in batch 284000/???, loss: 6.160388469696045\n",
"epoch: 0, update in batch 285000/???, loss: 7.137685298919678\n",
"epoch: 0, update in batch 286000/???, loss: 7.7431464195251465\n",
"epoch: 0, update in batch 287000/???, loss: 5.229738712310791\n",
"epoch: 0, update in batch 288000/???, loss: 6.654232025146484\n",
"epoch: 0, update in batch 289000/???, loss: 6.229329586029053\n",
"epoch: 0, update in batch 290000/???, loss: 7.188180446624756\n",
"epoch: 0, update in batch 291000/???, loss: 6.244111061096191\n",
"epoch: 0, update in batch 292000/???, loss: 7.199154853820801\n",
"epoch: 0, update in batch 293000/???, loss: 7.1866865158081055\n",
"epoch: 0, update in batch 294000/???, loss: 6.574115753173828\n",
"epoch: 0, update in batch 295000/???, loss: 6.487138271331787\n",
"epoch: 0, update in batch 296000/???, loss: 5.813161849975586\n",
"epoch: 0, update in batch 297000/???, loss: 6.159414291381836\n",
"epoch: 0, update in batch 298000/???, loss: 7.256616115570068\n",
"epoch: 0, update in batch 299000/???, loss: 7.511231899261475\n",
"epoch: 0, update in batch 300000/???, loss: 6.148821830749512\n",
"epoch: 0, update in batch 301000/???, loss: 7.108969211578369\n",
"epoch: 0, update in batch 302000/???, loss: 6.528176307678223\n",
"epoch: 0, update in batch 303000/???, loss: 6.276839256286621\n",
"epoch: 0, update in batch 304000/???, loss: 6.484020233154297\n",
"epoch: 0, update in batch 305000/???, loss: 6.38557767868042\n",
"epoch: 0, update in batch 306000/???, loss: 7.068814754486084\n",
"epoch: 0, update in batch 307000/???, loss: 5.844017505645752\n",
"epoch: 0, update in batch 308000/???, loss: 4.25785493850708\n",
"epoch: 0, update in batch 309000/???, loss: 6.709985256195068\n",
"epoch: 0, update in batch 310000/???, loss: 6.543104648590088\n",
"epoch: 0, update in batch 311000/???, loss: 6.675828456878662\n",
"epoch: 0, update in batch 312000/???, loss: 5.82969856262207\n",
"epoch: 0, update in batch 313000/???, loss: 6.05246639251709\n",
"epoch: 0, update in batch 314000/???, loss: 7.2366156578063965\n",
"epoch: 0, update in batch 315000/???, loss: 5.039820194244385\n",
"epoch: 0, update in batch 316000/???, loss: 5.943173885345459\n",
"epoch: 0, update in batch 317000/???, loss: 6.2509002685546875\n",
"epoch: 0, update in batch 318000/???, loss: 6.451228141784668\n",
"epoch: 0, update in batch 319000/???, loss: 6.6381049156188965\n",
"epoch: 0, update in batch 320000/???, loss: 6.570329189300537\n",
"epoch: 0, update in batch 321000/???, loss: 5.376622200012207\n",
"epoch: 0, update in batch 322000/???, loss: 6.487462520599365\n",
"epoch: 0, update in batch 323000/???, loss: 6.676497459411621\n",
"epoch: 0, update in batch 324000/???, loss: 6.283420562744141\n",
"epoch: 0, update in batch 325000/???, loss: 6.164648532867432\n",
"epoch: 0, update in batch 326000/???, loss: 6.839153289794922\n",
"epoch: 0, update in batch 327000/???, loss: 6.435141086578369\n",
"epoch: 0, update in batch 328000/???, loss: 6.160590171813965\n",
"epoch: 0, update in batch 329000/???, loss: 5.876160621643066\n",
"epoch: 0, update in batch 330000/???, loss: 6.47445011138916\n",
"epoch: 0, update in batch 331000/???, loss: 6.294231414794922\n",
"epoch: 0, update in batch 332000/???, loss: 6.099027156829834\n",
"epoch: 0, update in batch 333000/???, loss: 6.986542701721191\n",
"epoch: 0, update in batch 334000/???, loss: 7.018263816833496\n",
"epoch: 0, update in batch 335000/???, loss: 6.906959533691406\n",
"epoch: 0, update in batch 336000/???, loss: 6.12356424331665\n",
"epoch: 0, update in batch 337000/???, loss: 6.316069602966309\n",
"epoch: 0, update in batch 338000/???, loss: 6.908566474914551\n",
"epoch: 0, update in batch 339000/???, loss: 5.628839492797852\n",
"epoch: 0, update in batch 340000/???, loss: 7.069979667663574\n",
"epoch: 0, update in batch 341000/???, loss: 5.350735187530518\n",
"epoch: 0, update in batch 342000/???, loss: 5.377245903015137\n",
"epoch: 0, update in batch 343000/???, loss: 5.2340989112854\n",
"epoch: 0, update in batch 344000/???, loss: 6.087491512298584\n",
"epoch: 0, update in batch 345000/???, loss: 6.162985801696777\n",
"epoch: 0, update in batch 346000/???, loss: 5.655491828918457\n",
"epoch: 0, update in batch 347000/???, loss: 5.311842918395996\n",
"epoch: 0, update in batch 348000/???, loss: 7.577170372009277\n",
"epoch: 0, update in batch 349000/???, loss: 6.730460166931152\n",
"epoch: 0, update in batch 350000/???, loss: 6.782231330871582\n",
"epoch: 0, update in batch 351000/???, loss: 6.789486885070801\n",
"epoch: 0, update in batch 352000/???, loss: 5.473587989807129\n",
"epoch: 0, update in batch 353000/???, loss: 5.531443119049072\n",
"epoch: 0, update in batch 354000/???, loss: 7.220989227294922\n",
"epoch: 0, update in batch 355000/???, loss: 5.954288005828857\n",
"epoch: 0, update in batch 356000/???, loss: 4.112783432006836\n",
"epoch: 0, update in batch 357000/???, loss: 5.409672737121582\n",
"epoch: 0, update in batch 358000/???, loss: 6.408724784851074\n",
"epoch: 0, update in batch 359000/???, loss: 6.744941711425781\n",
"epoch: 0, update in batch 360000/???, loss: 6.218225479125977\n",
"epoch: 0, update in batch 361000/???, loss: 6.071394920349121\n",
"epoch: 0, update in batch 362000/???, loss: 6.137121677398682\n",
"epoch: 0, update in batch 363000/???, loss: 5.876864433288574\n",
"epoch: 0, update in batch 364000/???, loss: 7.715426445007324\n",
"epoch: 0, update in batch 365000/???, loss: 6.217362880706787\n",
"epoch: 0, update in batch 366000/???, loss: 6.741396903991699\n",
"epoch: 0, update in batch 367000/???, loss: 6.4564313888549805\n",
"epoch: 0, update in batch 368000/???, loss: 6.994439601898193\n",
"epoch: 0, update in batch 369000/???, loss: 6.061278820037842\n",
"epoch: 0, update in batch 370000/???, loss: 4.894576549530029\n",
"epoch: 0, update in batch 371000/???, loss: 6.351264953613281\n",
"epoch: 0, update in batch 372000/???, loss: 6.826904296875\n",
"epoch: 0, update in batch 373000/???, loss: 6.090312480926514\n",
"epoch: 0, update in batch 374000/???, loss: 5.797528266906738\n",
"epoch: 0, update in batch 375000/???, loss: 7.3235673904418945\n",
"epoch: 0, update in batch 376000/???, loss: 5.5752973556518555\n",
"epoch: 0, update in batch 377000/???, loss: 6.29438591003418\n",
"epoch: 0, update in batch 378000/???, loss: 5.238917827606201\n",
"epoch: 0, update in batch 379000/???, loss: 5.542972564697266\n",
"epoch: 0, update in batch 380000/???, loss: 6.5024614334106445\n",
"epoch: 0, update in batch 381000/???, loss: 6.918997287750244\n",
"epoch: 0, update in batch 382000/???, loss: 5.694029331207275\n",
"epoch: 0, update in batch 383000/???, loss: 7.109190940856934\n",
"epoch: 0, update in batch 384000/???, loss: 5.214654445648193\n",
"epoch: 0, update in batch 385000/???, loss: 7.055975437164307\n",
"epoch: 0, update in batch 386000/???, loss: 6.443846225738525\n",
"epoch: 0, update in batch 387000/???, loss: 5.544674873352051\n",
"epoch: 0, update in batch 388000/???, loss: 6.936171531677246\n",
"epoch: 0, update in batch 389000/???, loss: 6.646860599517822\n",
"epoch: 0, update in batch 390000/???, loss: 6.193584442138672\n",
"epoch: 0, update in batch 391000/???, loss: 5.9077558517456055\n",
"epoch: 0, update in batch 392000/???, loss: 5.029908657073975\n",
"epoch: 0, update in batch 393000/???, loss: 6.725222587585449\n",
"epoch: 0, update in batch 394000/???, loss: 6.814855098724365\n",
"epoch: 0, update in batch 395000/???, loss: 7.396543979644775\n",
"epoch: 0, update in batch 396000/???, loss: 6.993375301361084\n",
"epoch: 0, update in batch 397000/???, loss: 6.224326133728027\n",
"epoch: 0, update in batch 398000/???, loss: 6.301025390625\n",
"epoch: 0, update in batch 399000/???, loss: 6.707190036773682\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 400000/???, loss: 6.7646660804748535\n",
"epoch: 0, update in batch 401000/???, loss: 5.308177471160889\n",
"epoch: 0, update in batch 402000/???, loss: 8.35996150970459\n",
"epoch: 0, update in batch 403000/???, loss: 5.825610160827637\n",
"epoch: 0, update in batch 404000/???, loss: 6.310220718383789\n",
"epoch: 0, update in batch 405000/???, loss: 5.759210109710693\n",
"epoch: 0, update in batch 406000/???, loss: 6.32699728012085\n",
"epoch: 0, update in batch 407000/???, loss: 5.659378528594971\n",
"epoch: 0, update in batch 408000/???, loss: 6.216103553771973\n",
"epoch: 0, update in batch 409000/???, loss: 5.666914463043213\n",
"epoch: 0, update in batch 410000/???, loss: 6.419122219085693\n",
"epoch: 0, update in batch 411000/???, loss: 5.372750282287598\n",
"epoch: 0, update in batch 412000/???, loss: 6.839580535888672\n",
"epoch: 0, update in batch 413000/???, loss: 6.7682647705078125\n",
"epoch: 0, update in batch 414000/???, loss: 5.951648235321045\n",
"epoch: 0, update in batch 415000/???, loss: 6.181953430175781\n",
"epoch: 0, update in batch 416000/???, loss: 5.475704669952393\n",
"epoch: 0, update in batch 417000/???, loss: 6.383082866668701\n",
"epoch: 0, update in batch 418000/???, loss: 6.8107590675354\n",
"epoch: 0, update in batch 419000/???, loss: 5.753104209899902\n",
"epoch: 0, update in batch 420000/???, loss: 5.320840835571289\n",
"epoch: 0, update in batch 421000/???, loss: 7.377203464508057\n",
"epoch: 0, update in batch 422000/???, loss: 6.5706048011779785\n",
"epoch: 0, update in batch 423000/???, loss: 5.032872676849365\n",
"epoch: 0, update in batch 424000/???, loss: 5.781243324279785\n",
"epoch: 0, update in batch 425000/???, loss: 6.160118579864502\n"
]
}
],
"source": [
"train(train_dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "36a1b802",
"metadata": {},
"outputs": [],
"source": [
"def predict_probs(tokens):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
" state_h = model.init_state(len(tokens))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
" x = torch.tensor([[train_dataset.key_to_index[w] if w in key_to_index else train_dataset.key_to_index['<unk>'] for w in tokens]]).to(device)\n",
" y_pred, state_h = model(x)\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
"\n",
" return words"
" last_word_logits = y_pred[0][-1]\n",
" probs = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=probs)\n",
" \n",
" top_words = []\n",
" for index in range(len(probs)):\n",
" if len(top_words) < 30:\n",
" top_words.append((probs[index], [index]))\n",
" else:\n",
" worst_word = None\n",
" for word in top_words:\n",
" if not worst_word:\n",
" worst_word = word\n",
" else:\n",
" if word[0] < worst_word[0]:\n",
" worst_word = word\n",
" if worst_word[0] < probs[index] and index != len(probs) - 1:\n",
" top_words.remove(worst_word)\n",
" top_words.append((probs[index], [index]))\n",
" \n",
" prediction = ''\n",
" sum_prob = 0.0\n",
" for word in top_words:\n",
" sum_prob += word[0]\n",
" word_index = word[0]\n",
" word_text = index_to_key[word[1][0]]\n",
" prediction += f'{word_text}:{word_index} '\n",
" prediction += f':{1 - sum_prob}'\n",
" \n",
" return prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcd99ea4",
"execution_count": 56,
"id": "155636b5",
"metadata": {},
"outputs": [],
"source": [
"train(data, model, 10, 128)"
"dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"test_data = pd.read_csv('test-A/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df052344",
"execution_count": 59,
"id": "99b1d944",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), 'model.pt')"
"with open('dev-0/out.tsv', 'w') as file:\n",
" for index, row in dev_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e87339b",
"execution_count": 60,
"id": "186c3269",
"metadata": {},
"outputs": [],
"source": []
"source": [
"with open('test-A/out.tsv', 'w') as file:\n",
" for index, row in test_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
}
],
"metadata": {

259
run.py Normal file

@@ -0,0 +1,259 @@
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
# In[2]:
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# In[3]:
def clean_text(text):
    text = text.lower().replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ')
    text = re.sub(r'\p{P}', '', text)
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
    return text
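# Editor's note (observation, not a change to the original): re.sub(r'\p{P}', '', text)
# already strips apostrophes as punctuation, so the contraction replacements on the next
# line ("'t" -> " not", "'s" -> " is", ...) can no longer match by the time they run;
# swapping those two steps would be needed for them to take effect.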
# In[4]:
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
# In[5]:
class TrainCorpus:
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            yield word_tokenize(text)
# In[6]:
train_sentences = TrainCorpus(train_data.head(100000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
# In[7]:
w2v_model.build_vocab(corpus_iterable=train_sentences)
key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key
index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1
vocab_size = len(index_to_key)
print(vocab_size)
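# Editor's sketch (not part of the original run): any token absent from the Word2Vec
# vocabulary falls back to the '<unk>' entry appended above. An equivalent lookup with
# dict.get, using an illustrative `token` variable, would be:
#   idx = key_to_index.get(token, key_to_index['<unk>'])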
# In[8]:
class TrainDataset(torch.utils.data.IterableDataset):
    def __init__(self, data, index_to_key, key_to_index):
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            tokens = word_tokenize(text)
            for i in range(5, len(tokens), 1):
                input_context = tokens[i-5:i]
                target_context = tokens[i-4:i+1]
                #gap_word = tokens[i]
                input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in input_context]
                target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in target_context]
                #word_index = self.key_to_index[gap_word] if gap_word in self.key_to_index else self.key_to_index['<unk>']
                #word_embed = np.concatenate([np.zeros(word_index), np.ones(1), np.zeros(vocab_size - word_index - 1)])
                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)
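# Editor's sketch: for tokens ['the', 'quick', 'brown', 'fox', 'jumps', 'over'] and i == 5,
#   input_context  = ['the', 'quick', 'brown', 'fox', 'jumps']
#   target_context = ['quick', 'brown', 'fox', 'jumps', 'over']
# i.e. the target window is the input window shifted one token to the right (next-word
# prediction at every position), which is what CrossEntropyLoss consumes in train() below.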
# In[9]:
class Model(nn.Module):
    def __init__(self, embed_size, vocab_size):
        super(Model, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.gru_size = 128
        self.num_layers = 2
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size, num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.gru_size, vocab_size)

    def forward(self, x, prev_state = None):
        embed = self.embed(x)
        output, state = self.gru(embed, prev_state)
        logits = self.fc(output)
        probs = torch.softmax(logits, dim=1)
        return logits, state

    def init_state(self, sequence_length):
        zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)
        return (zeros, zeros)
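# Shape sketch (editor's note, assuming the batch size of 64 used below): x arrives as
# (64, 5) token ids -> embed (64, 5, 100) -> GRU output (64, 5, 128) -> logits
# (64, 5, vocab_size). Note that nn.GRU defaults to batch_first=False, so it actually
# treats the first (batch) dimension as the time axis; the shapes still line up, but
# batch_first=True would match the intended batch-of-windows reading. The `probs` tensor
# computed in forward() is unused; only the raw logits are returned.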
# In[10]:
from torch.utils.data import DataLoader
from torch.optim import Adam
def train(dataset, model, max_epochs, batch_size):
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            y_pred, state_h = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            if batch % 1000 == 0:
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
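# Editor's note: the running total prints as '???' because TrainDataset is an
# IterableDataset without __len__, so len(dataloader) is not available here.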
# In[11]:
train_dataset = TrainDataset(train_data.head(100000), index_to_key, key_to_index)
# In[12]:
model = Model(100, vocab_size).to(device)
# In[13]:
train(train_dataset, model, 1, 64)
# In[58]:
def predict_probs(tokens):
    model.eval()
    state_h = model.init_state(len(tokens))
    x = torch.tensor([[train_dataset.key_to_index[w] if w in key_to_index else train_dataset.key_to_index['<unk>'] for w in tokens]]).to(device)
    y_pred, state_h = model(x)
    last_word_logits = y_pred[0][-1]
    probs = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
    word_index = np.random.choice(len(last_word_logits), p=probs)

    top_words = []
    for index in range(len(probs)):
        if len(top_words) < 30:
            top_words.append((probs[index], [index]))
        else:
            worst_word = None
            for word in top_words:
                if not worst_word:
                    worst_word = word
                else:
                    if word[0] < worst_word[0]:
                        worst_word = word
            if worst_word[0] < probs[index] and index != len(probs) - 1:
                top_words.remove(worst_word)
                top_words.append((probs[index], [index]))

    prediction = ''
    sum_prob = 0.0
    for word in top_words:
        sum_prob += word[0]
        word_index = word[0]
        word_text = index_to_key[word[1][0]]
        prediction += f'{word_text}:{word_index} '
    prediction += f':{1 - sum_prob}'
    return prediction
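# Editor's sketch (illustrative, not part of the original run): predict_probs returns one
# line in the challenge's output format -- space-separated `word:probability` pairs for the
# 30 kept candidates, then `:remaining_mass` for all other words, e.g.
#   the:0.1412 of:0.0834 a:0.0517 ... :0.3621
# (the np.random.choice sample near the top of the function is left unused). A minimal
# format check could be:
#   parts = prediction.split(' ')
#   assert parts[-1].startswith(':') and all(':' in p for p in parts[:-1])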
# In[56]:
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
# In[59]:
with open('dev-0/out.tsv', 'w') as file:
    for index, row in dev_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')
# In[60]:
with open('test-A/out.tsv', 'w') as file:
    for index, row in test_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')

7414
test-A/out.tsv Normal file

File diff suppressed because it is too large