Jakub Pokrywka 2022-05-30 09:17:21 +02:00
parent 09846e1396
commit ecf931a3e0

@@ -250,12 +250,13 @@
    "outputs": [],
    "source": [
     "class EncoderRNN(nn.Module):\n",
-    "    def __init__(self, input_size, hidden_size):\n",
+    "    def __init__(self, input_size, embedding_size, hidden_size):\n",
     "        super(EncoderRNN, self).__init__()\n",
+    "        self.embedding_size = embedding_size\n",
     "        self.hidden_size = hidden_size\n",
     "\n",
-    "        self.embedding = nn.Embedding(input_size, hidden_size)\n",
-    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
+    "        self.embedding = nn.Embedding(input_size, self.embedding_size)\n",
+    "        self.gru = nn.GRU(self.embedding_size, hidden_size)\n",
     "\n",
     "    def forward(self, input, hidden):\n",
     "        embedded = self.embedding(input).view(1, 1, -1)\n",
@@ -274,12 +275,13 @@
    "outputs": [],
    "source": [
     "class DecoderRNN(nn.Module):\n",
-    "    def __init__(self, hidden_size, output_size):\n",
+    "    def __init__(self, embedding_size, hidden_size, output_size):\n",
     "        super(DecoderRNN, self).__init__()\n",
+    "        self.embedding_size = embedding_size\n",
     "        self.hidden_size = hidden_size\n",
     "\n",
-    "        self.embedding = nn.Embedding(output_size, hidden_size)\n",
-    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
+    "        self.embedding = nn.Embedding(output_size, self.embedding_size)\n",
+    "        self.gru = nn.GRU(self.embedding_size, hidden_size)\n",
     "        self.out = nn.Linear(hidden_size, output_size)\n",
     "        self.softmax = nn.LogSoftmax(dim=1)\n",
     "\n",
@@ -301,18 +303,19 @@
    "outputs": [],
    "source": [
     "class AttnDecoderRNN(nn.Module):\n",
-    "    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n",
+    "    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n",
     "        super(AttnDecoderRNN, self).__init__()\n",
+    "        self.embedding_size = embedding_size\n",
     "        self.hidden_size = hidden_size\n",
     "        self.output_size = output_size\n",
     "        self.dropout_p = dropout_p\n",
     "        self.max_length = max_length\n",
     "\n",
-    "        self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n",
-    "        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
-    "        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n",
+    "        self.embedding = nn.Embedding(self.output_size, self.embedding_size)\n",
+    "        self.attn = nn.Linear(self.hidden_size + self.embedding_size, self.max_length)\n",
+    "        self.attn_combine = nn.Linear(self.hidden_size + self.embedding_size, self.embedding_size)\n",
     "        self.dropout = nn.Dropout(self.dropout_p)\n",
-    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n",
+    "        self.gru = nn.GRU(self.embedding_size, self.hidden_size)\n",
     "        self.out = nn.Linear(self.hidden_size, self.output_size)\n",
     "\n",
     "    def forward(self, input, hidden, encoder_outputs):\n",
@@ -323,6 +326,7 @@
     "            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n",
     "        attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n",
     "                                 encoder_outputs.unsqueeze(0))\n",
+    "        #import pdb; pdb.set_trace()\n",
     "\n",
     "        output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
     "        output = self.attn_combine(output).unsqueeze(0)\n",
@@ -508,9 +512,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "embedding_size = 200\n",
     "hidden_size = 256\n",
-    "encoder1 = EncoderRNN(eng_lang.n_words, hidden_size).to(device)\n",
-    "attn_decoder1 = AttnDecoderRNN(hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)"
+    "encoder1 = EncoderRNN(eng_lang.n_words, embedding_size, hidden_size).to(device)\n",
+    "attn_decoder1 = AttnDecoderRNN(embedding_size, hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)"
    ]
   },
   {
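With the width now a constructor argument, embedding and hidden sizes can be varied independently. A hypothetical alternative configuration, not from the commit (eng_lang, pol_lang, and device come from earlier notebook cells):

embedding_size = 100  # hypothetical: narrower embeddings, same hidden state
hidden_size = 256
encoder2 = EncoderRNN(eng_lang.n_words, embedding_size, hidden_size).to(device)
attn_decoder2 = AttnDecoderRNN(embedding_size, hidden_size, pol_lang.n_words,
                               dropout_p=0.1).to(device)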
@@ -522,206 +527,206 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "iter: 50, loss: 4.699110437711081\n",
+    "iter: 50, loss: 5.042555550272503\n",
-    "iter: 100, loss: 4.241607124086411\n",
+    "iter: 100, loss: 4.143612308138894\n",
-    "iter: 150, loss: 4.14866822333563\n",
+    "iter: 150, loss: 4.258466395877656\n",
-    "iter: 200, loss: 4.175457921709334\n",
+    "iter: 200, loss: 4.078979822052849\n",
-    "iter: 250, loss: 4.304153789429438\n",
+    "iter: 250, loss: 3.9038650802657715\n",
-    "iter: 300, loss: 4.304717092377798\n",
+    "iter: 300, loss: 4.07207449336279\n",
-    "iter: 350, loss: 4.316578052808368\n",
+    "iter: 350, loss: 3.940484183538527\n",
-    "iter: 400, loss: 4.379952565056937\n",
+    "iter: 400, loss: 4.425489738524906\n",
-    "iter: 450, loss: 4.086811531929743\n",
+    "iter: 450, loss: 3.9398847290826224\n",
-    "iter: 500, loss: 4.252370147765628\n",
+    "iter: 500, loss: 4.264409653027852\n",
-    "iter: 550, loss: 4.02257244164603\n",
+    "iter: 550, loss: 4.323172234974209\n",
-    "iter: 600, loss: 4.271288591505989\n",
+    "iter: 600, loss: 4.22224827657427\n",
-    "iter: 650, loss: 4.037527732379852\n",
+    "iter: 650, loss: 4.204052018634857\n",
-    "iter: 700, loss: 3.808401109422956\n",
+    "iter: 700, loss: 3.9438682432023295\n",
-    "iter: 750, loss: 4.01287091629089\n",
+    "iter: 750, loss: 4.001692515509468\n",
-    "iter: 800, loss: 4.185342459905715\n",
+    "iter: 800, loss: 4.054982795352028\n",
-    "iter: 850, loss: 3.8268170519934763\n",
+    "iter: 850, loss: 4.119050166281443\n",
-    "iter: 900, loss: 3.9197384970074607\n",
+    "iter: 900, loss: 3.908679961704073\n",
-    "iter: 950, loss: 4.225208856279888\n",
+    "iter: 950, loss: 4.136870030266898\n",
-    "iter: 1000, loss: 4.128686094178094\n",
+    "iter: 1000, loss: 3.8147727276938297\n",
-    "iter: 1050, loss: 3.9167927505553712\n",
+    "iter: 1050, loss: 4.026022962623171\n",
-    "iter: 1100, loss: 4.015269571940103\n",
+    "iter: 1100, loss: 3.9598817706335154\n",
-    "iter: 1150, loss: 4.168424199830918\n",
+    "iter: 1150, loss: 3.848097898089696\n",
-    "iter: 1200, loss: 4.302581990559896\n",
+    "iter: 1200, loss: 4.01016833985041\n",
-    "iter: 1250, loss: 3.7335942743392225\n",
+    "iter: 1250, loss: 3.7720014858472917\n",
-    "iter: 1300, loss: 3.9526881422315334\n",
+    "iter: 1300, loss: 4.059876484976874\n",
-    "iter: 1350, loss: 3.8640213389169604\n",
+    "iter: 1350, loss: 3.8380891363658605\n",
-    "iter: 1400, loss: 4.101886716827512\n",
+    "iter: 1400, loss: 4.013203263676356\n",
-    "iter: 1450, loss: 3.6106392740067985\n",
+    "iter: 1450, loss: 4.067137318686833\n",
-    "iter: 1500, loss: 4.0689067233857665\n",
+    "iter: 1500, loss: 4.020450985673874\n",
-    "iter: 1550, loss: 4.02288844353812\n",
+    "iter: 1550, loss: 3.7160321428662244\n",
-    "iter: 1600, loss: 3.572508715992883\n",
+    "iter: 1600, loss: 3.8411714478977137\n",
-    "iter: 1650, loss: 3.972692446489183\n",
+    "iter: 1650, loss: 3.7125136051177985\n",
-    "iter: 1700, loss: 3.8709554294404525\n",
+    "iter: 1700, loss: 3.705152728769514\n",
-    "iter: 1750, loss: 3.9830204631714583\n",
+    "iter: 1750, loss: 3.9118153427441915\n",
-    "iter: 1800, loss: 3.7999766263961794\n",
+    "iter: 1800, loss: 3.857195938375262\n",
-    "iter: 1850, loss: 3.7026816112578858\n",
+    "iter: 1850, loss: 3.9566935270703025\n",
-    "iter: 1900, loss: 3.833205360775902\n",
+    "iter: 1900, loss: 3.9394864430957375\n",
-    "iter: 1950, loss: 3.650638633606925\n",
+    "iter: 1950, loss: 3.636212232317243\n",
-    "iter: 2000, loss: 3.748746382418133\n",
+    "iter: 2000, loss: 3.847666795261321\n",
-    "iter: 2050, loss: 3.762590566922748\n",
+    "iter: 2050, loss: 3.787096965411352\n",
-    "iter: 2100, loss: 3.5997376789214117\n",
+    "iter: 2100, loss: 3.4702608700933912\n",
-    "iter: 2150, loss: 3.919283335610041\n",
+    "iter: 2150, loss: 3.727882717624543\n",
-    "iter: 2200, loss: 3.8638847478684912\n",
+    "iter: 2200, loss: 3.6961711362884153\n",
-    "iter: 2250, loss: 3.4960837801675946\n",
+    "iter: 2250, loss: 3.870331466848889\n",
-    "iter: 2300, loss: 3.685049927688782\n",
+    "iter: 2300, loss: 3.8506508341743837\n",
-    "iter: 2350, loss: 3.5716699722759313\n",
+    "iter: 2350, loss: 3.803002176814609\n",
-    "iter: 2400, loss: 3.8988636863874997\n",
+    "iter: 2400, loss: 3.5700957290558586\n",
-    "iter: 2450, loss: 3.752788569586617\n",
+    "iter: 2450, loss: 3.5328896935326712\n",
-    "iter: 2500, loss: 3.802307117961702\n",
+    "iter: 2500, loss: 3.810194352997674\n",
-    "iter: 2550, loss: 3.6420236970432227\n",
+    "iter: 2550, loss: 3.713556599700262\n",
-    "iter: 2600, loss: 3.6925315249912325\n",
+    "iter: 2600, loss: 3.6131167711303345\n",
-    "iter: 2650, loss: 3.8897219879059572\n",
+    "iter: 2650, loss: 3.433012700254954\n",
-    "iter: 2700, loss: 3.6327851654537153\n",
+    "iter: 2700, loss: 3.7313271602903084\n",
-    "iter: 2750, loss: 3.396957855118645\n",
+    "iter: 2750, loss: 3.5837062497366037\n",
-    "iter: 2800, loss: 3.5258935768112307\n",
+    "iter: 2800, loss: 3.6265894929265214\n",
-    "iter: 2850, loss: 3.605109554866003\n",
+    "iter: 2850, loss: 3.5165250884616186\n",
-    "iter: 2900, loss: 3.533288128330594\n",
+    "iter: 2900, loss: 3.8752988719410366\n",
-    "iter: 2950, loss: 3.4583421086054\n",
+    "iter: 2950, loss: 3.709828086020455\n",
-    "iter: 3000, loss: 3.403592811425526\n",
+    "iter: 3000, loss: 3.742527751090035\n",
-    "iter: 3050, loss: 3.5225157889411567\n",
+    "iter: 3050, loss: 3.5926183513232646\n",
-    "iter: 3100, loss: 3.4702517202846592\n",
+    "iter: 3100, loss: 3.6629667194003157\n",
-    "iter: 3150, loss: 3.4234997159185867\n",
+    "iter: 3150, loss: 3.7953110780715944\n",
-    "iter: 3200, loss: 3.5447632862348404\n",
+    "iter: 3200, loss: 3.4833724756770663\n",
-    "iter: 3250, loss: 3.1799173504133074\n",
+    "iter: 3250, loss: 3.5239689500066977\n",
-    "iter: 3300, loss: 3.7154814013905\n",
+    "iter: 3300, loss: 3.552185758560423\n",
-    "iter: 3350, loss: 3.4188442155444445\n",
+    "iter: 3350, loss: 3.342997217700594\n",
-    "iter: 3400, loss: 3.6557525696527393\n",
+    "iter: 3400, loss: 3.7131163925897512\n",
-    "iter: 3450, loss: 3.52880564416401\n",
+    "iter: 3450, loss: 3.2172264359110874\n",
-    "iter: 3500, loss: 3.4842312318408295\n",
+    "iter: 3500, loss: 3.1694674255961464\n",
-    "iter: 3550, loss: 3.5256399853570115\n",
+    "iter: 3550, loss: 3.5181667824548386\n",
-    "iter: 3600, loss: 3.70226228499034\n",
+    "iter: 3600, loss: 3.552696303821745\n",
-    "iter: 3650, loss: 3.2043497113424633\n",
+    "iter: 3650, loss: 3.5465369727573703\n",
-    "iter: 3700, loss: 3.4575287022439256\n",
+    "iter: 3700, loss: 3.3895190108844213\n",
-    "iter: 3750, loss: 3.4197605448374664\n",
+    "iter: 3750, loss: 3.55357305569119\n",
-    "iter: 3800, loss: 3.290345760890417\n",
+    "iter: 3800, loss: 3.618841464133489\n",
-    "iter: 3850, loss: 3.300158274309976\n",
+    "iter: 3850, loss: 3.631707963504488\n",
-    "iter: 3900, loss: 3.3362661438139645\n",
+    "iter: 3900, loss: 3.705602922939119\n",
-    "iter: 3950, loss: 3.4947717628630373\n",
+    "iter: 3950, loss: 3.1555525365556987\n",
-    "iter: 4000, loss: 3.5624450731353154\n",
+    "iter: 4000, loss: 3.423284879676879\n",
-    "iter: 4050, loss: 3.438600626892514\n",
+    "iter: 4050, loss: 3.74216214027859\n",
-    "iter: 4100, loss: 3.142976412258451\n",
+    "iter: 4100, loss: 3.273874522224304\n",
-    "iter: 4150, loss: 3.332818130595344\n",
+    "iter: 4150, loss: 3.9754231488666836\n",
-    "iter: 4200, loss: 3.31952378733196\n",
+    "iter: 4200, loss: 3.255707532473973\n",
-    "iter: 4250, loss: 3.5315058948123252\n",
+    "iter: 4250, loss: 3.622867019956075\n",
-    "iter: 4300, loss: 3.6603812535074023\n",
+    "iter: 4300, loss: 3.3847267730198216\n",
-    "iter: 4350, loss: 3.35295347692853\n",
+    "iter: 4350, loss: 3.6832511274095565\n",
-    "iter: 4400, loss: 3.374297706498041\n",
+    "iter: 4400, loss: 3.265418997968946\n",
-    "iter: 4450, loss: 3.09948105843105\n",
+    "iter: 4450, loss: 3.53306358509972\n",
-    "iter: 4500, loss: 3.16787886763376\n",
+    "iter: 4500, loss: 3.2655868359520333\n",
-    "iter: 4550, loss: 3.455794033330583\n",
+    "iter: 4550, loss: 3.579948601419965\n",
-    "iter: 4600, loss: 3.1263191164258926\n",
+    "iter: 4600, loss: 3.554656519799005\n",
-    "iter: 4650, loss: 3.3723485524995\n",
+    "iter: 4650, loss: 3.324159849643708\n",
-    "iter: 4700, loss: 3.147410953930445\n",
+    "iter: 4700, loss: 3.357913894865249\n",
-    "iter: 4750, loss: 3.4546711923281346\n",
+    "iter: 4750, loss: 3.048288846031067\n",
-    "iter: 4800, loss: 3.449277176016852\n",
+    "iter: 4800, loss: 3.185154194937811\n",
-    "iter: 4850, loss: 3.197799104531606\n",
+    "iter: 4850, loss: 2.9646709245159513\n",
-    "iter: 4900, loss: 3.239384971149383\n",
+    "iter: 4900, loss: 3.4766449508288546\n",
-    "iter: 4950, loss: 3.696369633697328\n",
+    "iter: 4950, loss: 3.1528075372302338\n",
-    "iter: 5000, loss: 3.2114706332191587\n",
+    "iter: 5000, loss: 3.12558690051427\n",
-    "iter: 5050, loss: 3.400943172795432\n",
+    "iter: 5050, loss: 3.6565875165273276\n",
-    "iter: 5100, loss: 3.298932059106372\n",
+    "iter: 5100, loss: 3.113538140228817\n",
-    "iter: 5150, loss: 3.3697974183445907\n",
+    "iter: 5150, loss: 3.0463946421638366\n",
-    "iter: 5200, loss: 3.31293656670858\n",
+    "iter: 5200, loss: 3.384180574084086\n",
-    "iter: 5250, loss: 3.1415378823658773\n",
+    "iter: 5250, loss: 3.3104316232090913\n",
-    "iter: 5300, loss: 3.1587839283867494\n",
+    "iter: 5300, loss: 2.9496352179807332\n",
-    "iter: 5350, loss: 3.3505903312440903\n",
+    "iter: 5350, loss: 3.1814023027722804\n",
-    "iter: 5400, loss: 3.247191356802744\n",
+    "iter: 5400, loss: 2.9286732437345724\n",
-    "iter: 5450, loss: 3.236625145200699\n",
+    "iter: 5450, loss: 3.4691178646617464\n",
-    "iter: 5500, loss: 3.19994143747148\n",
+    "iter: 5500, loss: 3.373944672122834\n",
-    "iter: 5550, loss: 3.2911239544626265\n",
+    "iter: 5550, loss: 3.213332776455653\n",
-    "iter: 5600, loss: 3.1855649600483122\n",
+    "iter: 5600, loss: 3.3247368506931116\n",
-    "iter: 5650, loss: 3.157031875163789\n",
+    "iter: 5650, loss: 3.2702379176957272\n",
-    "iter: 5700, loss: 3.2652817099586366\n",
+    "iter: 5700, loss: 3.4554740653038025\n",
-    "iter: 5750, loss: 3.3272896775593837\n",
+    "iter: 5750, loss: 3.281306777431851\n",
-    "iter: 5800, loss: 3.3162626687458583\n",
+    "iter: 5800, loss: 2.9936736260368706\n",
-    "iter: 5850, loss: 3.1342987139338536\n",
+    "iter: 5850, loss: 3.277740831851959\n",
-    "iter: 5900, loss: 3.29665669613036\n",
+    "iter: 5900, loss: 3.120459364088754\n",
-    "iter: 5950, loss: 3.232995939807286\n",
+    "iter: 5950, loss: 3.387252744160001\n",
-    "iter: 6000, loss: 3.0922561403758935\n",
+    "iter: 6000, loss: 3.238504883735898\n",
-    "iter: 6050, loss: 3.1034776155835107\n",
+    "iter: 6050, loss: 2.738152531003195\n",
-    "iter: 6100, loss: 3.1502840874081564\n",
+    "iter: 6100, loss: 3.231002421265556\n",
-    "iter: 6150, loss: 2.915993771098909\n",
+    "iter: 6150, loss: 3.0410601262819195\n",
-    "iter: 6200, loss: 2.994096033270397\n",
+    "iter: 6200, loss: 3.093445486522856\n",
-    "iter: 6250, loss: 3.1102042265392487\n",
+    "iter: 6250, loss: 2.877119398207891\n",
-    "iter: 6300, loss: 2.8244728108587718\n",
+    "iter: 6300, loss: 3.006740029849703\n",
-    "iter: 6350, loss: 3.117810124692462\n",
+    "iter: 6350, loss: 2.8918780979504657\n",
-    "iter: 6400, loss: 3.0742526639529637\n",
+    "iter: 6400, loss: 3.3124666434015553\n",
-    "iter: 6450, loss: 2.8390014954218787\n",
+    "iter: 6450, loss: 3.170363757602752\n",
-    "iter: 6500, loss: 3.1032223067510687\n",
+    "iter: 6500, loss: 3.1445780278387527\n",
-    "iter: 6550, loss: 2.912433739840038\n",
+    "iter: 6550, loss: 3.0042706321610346\n",
-    "iter: 6600, loss: 2.9158696003490023\n",
+    "iter: 6600, loss: 2.94450242013023\n",
-    "iter: 6650, loss: 3.2617745389030093\n",
+    "iter: 6650, loss: 3.1747314814840046\n",
-    "iter: 6700, loss: 3.295657290466248\n",
+    "iter: 6700, loss: 3.325715871651966\n",
-    "iter: 6750, loss: 2.975928121767347\n",
+    "iter: 6750, loss: 3.1039765825120225\n",
-    "iter: 6800, loss: 3.0057779382069914\n",
+    "iter: 6800, loss: 3.260562201068516\n",
-    "iter: 6850, loss: 2.85224422507059\n",
+    "iter: 6850, loss: 2.95558365320024\n",
-    "iter: 6900, loss: 3.0329934195336836\n",
+    "iter: 6900, loss: 3.1284036347071327\n",
-    "iter: 6950, loss: 3.1322296761255415\n",
+    "iter: 6950, loss: 3.161784927746607\n",
-    "iter: 7000, loss: 2.893814939192363\n",
+    "iter: 7000, loss: 3.083566860369275\n",
-    "iter: 7050, loss: 2.934597730205173\n",
+    "iter: 7050, loss: 3.1606678485643296\n",
-    "iter: 7100, loss: 3.267660904082041\n",
+    "iter: 7100, loss: 3.39304134529356\n",
-    "iter: 7150, loss: 3.1199153114651867\n",
+    "iter: 7150, loss: 3.05389289476001\n",
-    "iter: 7200, loss: 2.8414319788160776\n",
+    "iter: 7200, loss: 3.171286074725408\n",
-    "iter: 7250, loss: 3.1128779797251256\n",
+    "iter: 7250, loss: 3.307133579034654\n",
-    "iter: 7300, loss: 3.1182169116565155\n",
+    "iter: 7300, loss: 2.987511603022379\n",
-    "iter: 7350, loss: 3.101384938853128\n",
+    "iter: 7350, loss: 3.1221464098370264\n",
-    "iter: 7400, loss: 2.9836614183395627\n",
+    "iter: 7400, loss: 2.9686622249966574\n",
-    "iter: 7450, loss: 2.7261425285036602\n",
+    "iter: 7450, loss: 2.874706161885035\n",
-    "iter: 7500, loss: 2.7323913456977356\n",
+    "iter: 7500, loss: 2.759323406164608\n",
-    "iter: 7550, loss: 3.284201001443559\n",
+    "iter: 7550, loss: 2.835318256658221\n",
-    "iter: 7600, loss: 2.9473503636405587\n",
+    "iter: 7600, loss: 2.896953154404958\n",
-    "iter: 7650, loss: 2.861012626541986\n",
+    "iter: 7650, loss: 2.8871691599497717\n",
-    "iter: 7700, loss: 2.6726747900872003\n",
+    "iter: 7700, loss: 3.049550093332927\n",
-    "iter: 7750, loss: 2.760957624162947\n",
+    "iter: 7750, loss: 2.9703013692507665\n",
-    "iter: 7800, loss: 2.647666095211393\n",
+    "iter: 7800, loss: 2.8142153175671893\n",
-    "iter: 7850, loss: 2.7921250426428657\n",
+    "iter: 7850, loss: 2.8352768955987604\n",
-    "iter: 7900, loss: 2.9527213778495787\n",
+    "iter: 7900, loss: 2.863677294496506\n",
-    "iter: 7950, loss: 2.790506172891647\n",
+    "iter: 7950, loss: 3.031682641491057\n",
-    "iter: 8000, loss: 2.8376009529431663\n",
+    "iter: 8000, loss: 2.9286883136809814\n",
-    "iter: 8050, loss: 3.0387913953690298\n",
+    "iter: 8050, loss: 2.9240697879488504\n",
-    "iter: 8100, loss: 2.908381733046637\n",
+    "iter: 8100, loss: 3.0172221147900546\n",
-    "iter: 8150, loss: 2.7374484727761104\n",
+    "iter: 8150, loss: 2.8361169849426027\n",
-    "iter: 8200, loss: 2.84610585779614\n",
+    "iter: 8200, loss: 2.9860127468676803\n",
-    "iter: 8250, loss: 2.8532650649736793\n",
+    "iter: 8250, loss: 2.9495567634294906\n",
-    "iter: 8300, loss: 2.856347685723078\n",
+    "iter: 8300, loss: 2.793946119104113\n",
-    "iter: 8350, loss: 2.6641267998710503\n",
+    "iter: 8350, loss: 3.2106793221594785\n",
-    "iter: 8400, loss: 2.7541870554590973\n",
+    "iter: 8400, loss: 2.736634517018757\n",
-    "iter: 8450, loss: 2.814719854824126\n",
+    "iter: 8450, loss: 2.8962079345536615\n",
-    "iter: 8500, loss: 2.6979909611694395\n",
+    "iter: 8500, loss: 2.906407202516283\n",
-    "iter: 8550, loss: 2.577483120327904\n",
+    "iter: 8550, loss: 2.6900012663281148\n",
-    "iter: 8600, loss: 2.7884950113561415\n",
+    "iter: 8600, loss: 2.8905927643056897\n",
-    "iter: 8650, loss: 3.0236114144552317\n",
+    "iter: 8650, loss: 2.950769727600945\n",
-    "iter: 8700, loss: 2.5850161893329924\n",
+    "iter: 8700, loss: 2.884238138978443\n",
-    "iter: 8750, loss: 2.992550043756999\n",
+    "iter: 8750, loss: 2.7154052526648083\n",
-    "iter: 8800, loss: 2.581544444644262\n",
+    "iter: 8800, loss: 2.8823739119030183\n",
-    "iter: 8850, loss: 2.7955539315276674\n",
+    "iter: 8850, loss: 2.93061117755799\n",
-    "iter: 8900, loss: 2.583812619288763\n",
+    "iter: 8900, loss: 2.658344201617771\n",
-    "iter: 8950, loss: 2.6446591711649825\n",
+    "iter: 8950, loss: 2.5747124820644887\n",
-    "iter: 9000, loss: 2.577330000854674\n",
+    "iter: 9000, loss: 2.8281182004307954\n",
-    "iter: 9050, loss: 2.4657566853288615\n",
+    "iter: 9050, loss: 2.6702445936959895\n",
-    "iter: 9100, loss: 2.800543680138058\n",
+    "iter: 9100, loss: 2.8030708763485865\n",
-    "iter: 9150, loss: 2.8939966171544707\n",
+    "iter: 9150, loss: 3.0742075329053966\n",
-    "iter: 9200, loss: 2.484702325525738\n",
+    "iter: 9200, loss: 2.7834522392787635\n",
-    "iter: 9250, loss: 2.9708456475469807\n",
+    "iter: 9250, loss: 2.9308865650949025\n",
-    "iter: 9300, loss: 2.8829837035148858\n",
+    "iter: 9300, loss: 2.776913931453039\n",
-    "iter: 9350, loss: 2.451061187414896\n",
+    "iter: 9350, loss: 2.7998796779011923\n",
-    "iter: 9400, loss: 3.144906068983533\n",
+    "iter: 9400, loss: 3.1615792548088795\n",
-    "iter: 9450, loss: 2.4527184899950787\n",
+    "iter: 9450, loss: 3.2742855516539673\n",
-    "iter: 9500, loss: 2.665944624832698\n",
+    "iter: 9500, loss: 2.981044085154457\n",
-    "iter: 9550, loss: 2.5468089370273406\n",
+    "iter: 9550, loss: 2.4407524968101866\n",
-    "iter: 9600, loss: 2.51169423552165\n",
+    "iter: 9600, loss: 2.624275121037923\n",
-    "iter: 9650, loss: 2.916568091210864\n",
+    "iter: 9650, loss: 2.4893303714971697\n",
-    "iter: 9700, loss: 2.8149766059640853\n",
+    "iter: 9700, loss: 2.7211539438906183\n",
-    "iter: 9750, loss: 2.6544064010362773\n",
+    "iter: 9750, loss: 2.8714180671828133\n",
-    "iter: 9800, loss: 2.300161985658464\n",
+    "iter: 9800, loss: 2.7188037380396373\n",
-    "iter: 9850, loss: 2.5070087575912483\n",
+    "iter: 9850, loss: 2.4101966271173385\n",
-    "iter: 9900, loss: 2.617770311056621\n",
+    "iter: 9900, loss: 2.9492219283542926\n",
-    "iter: 9950, loss: 2.756971993983738\n",
+    "iter: 9950, loss: 2.547067801430112\n",
-    "iter: 10000, loss: 2.629019902910504\n"
+    "iter: 10000, loss: 2.8521263429191372\n"
    ]
   }
  ],
@@ -740,45 +745,45 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "> we re both in the same class .\n",
-    "= jesteśmy oboje w tej samej klasie .\n",
-    "< jesteśmy w w . <EOS>\n",
+    "> he is a tennis player .\n",
+    "= on jest tenisistą .\n",
+    "< jest tenisistą . <EOS>\n",
     "\n",
-    "> you re telling lies again .\n",
-    "= znowu kłamiesz .\n",
-    "< znowu mi . <EOS>\n",
+    "> i m not going to change my mind .\n",
+    "= nie zamierzam zmieniać zdania .\n",
+    "< nie idę do . <EOS>\n",
     "\n",
-    "> i m glad you re back .\n",
-    "= cieszę się że wróciliście .\n",
-    "< cieszę się że . . <EOS>\n",
+    "> i m totally confused .\n",
+    "= jestem kompletnie zmieszany .\n",
+    "< jestem dziś . . <EOS>\n",
     "\n",
-    "> i m not going to have any fun .\n",
-    "= nie będę się bawił .\n",
-    "< nie wolno się . . <EOS>\n",
+    "> he is a pioneer in this field .\n",
+    "= jest pionierem w tej dziedzinie .\n",
+    "< on jest w w . . <EOS>\n",
     "\n",
-    "> i m practising judo .\n",
-    "= trenuję dżudo .\n",
-    "< jestem . . <EOS>\n",
+    "> i m so excited .\n",
+    "= jestem taki podekscytowany !\n",
+    "< jestem jestem głodny . <EOS>\n",
     "\n",
-    "> you re wasting our time .\n",
-    "= marnujesz nasz czas .\n",
-    "< masz ci na . . <EOS>\n",
+    "> they are a party of six .\n",
+    "= jest ich sześć osób .\n",
+    "< oni nie są . . <EOS>\n",
     "\n",
-    "> he is anxious about her health .\n",
-    "= on martwi się o jej zdrowie .\n",
-    "< jest bardzo z niej . . <EOS>\n",
+    "> he is the father of two children .\n",
+    "= on jest ojcem dwójki dzieci .\n",
+    "< on jest na do . . <EOS>\n",
     "\n",
-    "> you re introverted .\n",
-    "= jesteś zamknięty w sobie .\n",
-    "< masz . <EOS>\n",
+    "> i am leaving at four .\n",
+    "= wychodzę o czwartej .\n",
+    "< jestem na . <EOS>\n",
     "\n",
-    "> she s correct for sure .\n",
-    "= ona z pewnością ma rację .\n",
-    "< ona jest z z . <EOS>\n",
+    "> i m not much of a writer .\n",
+    "= pisarz ze mnie żaden .\n",
+    "< nie jestem mnie . . <EOS>\n",
     "\n",
-    "> they re armed .\n",
-    "= są uzbrojeni .\n",
-    "< są . . <EOS>\n",
+    "> you re disgusting !\n",
+    "= jesteś obrzydliwy !\n",
+    "< jesteś obrzydliwy . <EOS>\n",
     "\n"
    ]
   }