11 in progress

This commit is contained in:
Jakub Pokrywka 2022-05-29 19:05:03 +02:00
parent faa0ded62f
commit 195fc25de1
3 changed files with 46462 additions and 248 deletions

View File

@@ -80,11 +80,11 @@
"metadata": {},
"outputs": [],
"source": [
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )"
"# def unicodeToAscii(s):\n",
"# return ''.join(\n",
"# c for c in unicodedata.normalize('NFD', s)\n",
"# if unicodedata.category(c) != 'Mn'\n",
"# )"
]
},
{
@@ -94,20 +94,20 @@
"outputs": [],
"source": [
"pairs = []\n",
"with open('data/eng-fra.txt') as f:\n",
"with open('data/eng-pol.txt') as f:\n",
" for line in f:\n",
" eng_line, fra_line = line.lower().rstrip().split('\\t')\n",
" eng_line, pol_line = line.lower().rstrip().split('\\t')\n",
"\n",
" eng_line = re.sub(r\"([.!?])\", r\" \\1\", eng_line)\n",
" eng_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", eng_line)\n",
"\n",
" fra_line = re.sub(r\"([.!?])\", r\" \\1\", fra_line)\n",
" fra_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", fra_line)\n",
" pol_line = re.sub(r\"([.!?])\", r\" \\1\", pol_line)\n",
" pol_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", pol_line)\n",
" \n",
" eng_line = unicodeToAscii(eng_line)\n",
" fra_line = unicodeToAscii(fra_line)\n",
"# eng_line = unicodeToAscii(eng_line)\n",
"# pol_line = unicodeToAscii(pol_line)\n",
"\n",
" pairs.append([eng_line, fra_line])\n",
" pairs.append([eng_line, pol_line])\n",
"\n",
"\n"
]
@@ -120,7 +120,7 @@
{
"data": {
"text/plain": [
"['run !', 'cours !']"
"['hi .', 'cze .']"
]
},
"execution_count": 6,
@@ -152,11 +152,11 @@
"pairs = [p for p in pairs if p[0].startswith(eng_prefixes)]\n",
"\n",
"eng_lang = Lang()\n",
"fra_lang = Lang()\n",
"pol_lang = Lang()\n",
"\n",
"for pair in pairs:\n",
" eng_lang.addSentence(pair[0])\n",
" fra_lang.addSentence(pair[1])"
" pol_lang.addSentence(pair[1])"
]
},
{
@@ -167,7 +167,7 @@
{
"data": {
"text/plain": [
"['i m .', 'j ai ans .']"
"['i m ok .', 'ze mn wszystko w porz dku .']"
]
},
"execution_count": 8,
@@ -187,7 +187,7 @@
{
"data": {
"text/plain": [
"['i m ok .', 'je vais bien .']"
"['i m up .', 'wsta em .']"
]
},
"execution_count": 9,
@@ -207,7 +207,7 @@
{
"data": {
"text/plain": [
"['i m ok .', ' a va .']"
"['i m tom .', 'jestem tom .']"
]
},
"execution_count": 10,
@@ -347,8 +347,7 @@
" loss = 0\n",
"\n",
" for ei in range(input_length):\n",
" encoder_output, encoder_hidden = encoder(\n",
" input_tensor[ei], encoder_hidden)\n",
" encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
" encoder_outputs[ei] = encoder_output[0, 0]\n",
"\n",
" decoder_input = torch.tensor([[SOS_token]], device=device)\n",
@@ -358,18 +357,14 @@
" use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\n",
"\n",
" if use_teacher_forcing:\n",
" # Teacher forcing: Feed the target as the next input\n",
" for di in range(target_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(\n",
" decoder_input, decoder_hidden, encoder_outputs)\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
" loss += criterion(decoder_output, target_tensor[di])\n",
" decoder_input = target_tensor[di] # Teacher forcing\n",
"\n",
" else:\n",
" # Without teacher forcing: use its own predictions as the next input\n",
" for di in range(target_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(\n",
" decoder_input, decoder_hidden, encoder_outputs)\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
" topv, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze().detach() # detach from history as input\n",
"\n",
@@ -398,7 +393,7 @@
" decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n",
" \n",
" training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n",
" training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], fra_lang)) for p in training_pairs]\n",
" training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in training_pairs]\n",
" \n",
" criterion = nn.NLLLoss()\n",
"\n",
@@ -410,7 +405,7 @@
" loss = train_one_batch(input_tensor,\n",
" target_tensor,\n",
" encoder,\n",
" encoder,\n",
" decoder,\n",
" encoder_optimizer,\n",
" decoder_optimizer,\n",
" criterion)\n",
@@ -458,7 +453,7 @@
" decoded_words.append('<EOS>')\n",
" break\n",
" else:\n",
" decoded_words.append(fra_lang.index2word[topi.item()])\n",
" decoded_words.append(pol_lang.index2word[topi.item()])\n",
"\n",
" decoder_input = topi.squeeze().detach()\n",
"\n",
@@ -484,244 +479,251 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iter: 50, loss: 4.78930813773473\n",
"iter: 100, loss: 4.554949267220875\n",
"iter: 150, loss: 4.238516052685087\n",
"iter: 200, loss: 4.279887475513276\n",
"iter: 250, loss: 4.1802274973884455\n",
"iter: 300, loss: 4.2113521892305394\n",
"iter: 350, loss: 4.266180963228619\n",
"iter: 400, loss: 4.225914733432588\n",
"iter: 450, loss: 4.1369073431075565\n",
"iter: 500, loss: 3.9906799076019768\n",
"iter: 550, loss: 3.842005534717016\n",
"iter: 600, loss: 4.081443620484972\n",
"iter: 650, loss: 4.030401878296383\n",
"iter: 700, loss: 3.869014380984837\n",
"iter: 750, loss: 3.8505467753031906\n",
"iter: 800, loss: 3.855170104072209\n",
"iter: 850, loss: 3.675745445599631\n",
"iter: 900, loss: 3.9147777624584386\n",
"iter: 950, loss: 3.766264297788106\n",
"iter: 1000, loss: 3.6813155986997814\n",
"iter: 1050, loss: 3.9307321495934144\n",
"iter: 1100, loss: 3.9047770059525027\n",
"iter: 1150, loss: 3.655722749588981\n",
"iter: 1200, loss: 3.540693810886806\n",
"iter: 1250, loss: 3.790360960324605\n",
"iter: 1300, loss: 3.7472636015907153\n",
"iter: 1350, loss: 3.641857419574072\n",
"iter: 1400, loss: 3.717327400631375\n",
"iter: 1450, loss: 3.4848567311423166\n",
"iter: 1500, loss: 3.56774485397339\n",
"iter: 1550, loss: 3.460277635226175\n",
"iter: 1600, loss: 3.241899683013796\n",
"iter: 1650, loss: 3.50151977614751\n",
"iter: 1700, loss: 3.621569488313462\n",
"iter: 1750, loss: 3.3851226735947626\n",
"iter: 1800, loss: 3.346289497057597\n",
"iter: 1850, loss: 3.5180823354569695\n",
"iter: 1900, loss: 3.433616197676886\n",
"iter: 1950, loss: 3.6162788327080864\n",
"iter: 2000, loss: 3.4990604458763492\n",
"iter: 2050, loss: 3.3144700173423405\n",
"iter: 2100, loss: 3.2962356294980135\n",
"iter: 2150, loss: 3.1448448797861728\n",
"iter: 2200, loss: 3.6958242581534018\n",
"iter: 2250, loss: 3.5269318538241925\n",
"iter: 2300, loss: 3.180744191850934\n",
"iter: 2350, loss: 3.317159715145354\n",
"iter: 2400, loss: 3.638545340795366\n",
"iter: 2450, loss: 3.7591161967988995\n",
"iter: 2500, loss: 3.3513535446742218\n",
"iter: 2550, loss: 3.4554441847271393\n",
"iter: 2600, loss: 2.9394915195343994\n",
"iter: 2650, loss: 3.370902210848673\n",
"iter: 2700, loss: 3.4259227318839423\n",
"iter: 2750, loss: 3.4058353806904393\n",
"iter: 2800, loss: 3.467306881359647\n",
"iter: 2850, loss: 3.222254538074372\n",
"iter: 2900, loss: 3.3392559226808087\n",
"iter: 2950, loss: 3.4203980594362533\n",
"iter: 3000, loss: 3.3507530433563955\n",
"iter: 3050, loss: 3.4326547555317966\n",
"iter: 3100, loss: 3.1755515496390205\n",
"iter: 3150, loss: 3.3925877854634847\n",
"iter: 3200, loss: 3.223531436912598\n",
"iter: 3250, loss: 3.3089625614862603\n",
"iter: 3300, loss: 3.367763715501815\n",
"iter: 3350, loss: 3.4278301871163497\n",
"iter: 3400, loss: 3.373292277381534\n",
"iter: 3450, loss: 3.3497054475829717\n",
"iter: 3500, loss: 3.402910869681646\n",
"iter: 3550, loss: 3.072571641732776\n",
"iter: 3600, loss: 3.2611226563832116\n",
"iter: 3650, loss: 3.231520605495998\n",
"iter: 3700, loss: 3.3788801974569043\n",
"iter: 3750, loss: 3.176644308181036\n",
"iter: 3800, loss: 3.2255533708693496\n",
"iter: 3850, loss: 3.2362594686387083\n",
"iter: 3900, loss: 3.095807164230044\n",
"iter: 3950, loss: 3.2343999077024916\n",
"iter: 4000, loss: 3.3681417366512245\n",
"iter: 4050, loss: 3.0732023419879737\n",
"iter: 4100, loss: 3.0663742440617283\n",
"iter: 4150, loss: 3.396770855048347\n",
"iter: 4200, loss: 3.4262332421522292\n",
"iter: 4250, loss: 3.060121847773354\n",
"iter: 4300, loss: 2.895130627753243\n",
"iter: 4350, loss: 3.017712699065133\n",
"iter: 4400, loss: 3.1289404028559487\n",
"iter: 4450, loss: 3.163725920904249\n",
"iter: 4500, loss: 3.3627441662606743\n",
"iter: 4550, loss: 3.409984823173947\n",
"iter: 4600, loss: 2.8944704760899618\n",
"iter: 4650, loss: 3.0016444209568083\n",
"iter: 4700, loss: 2.8574393688837683\n",
"iter: 4750, loss: 3.1946328716656525\n",
"iter: 4800, loss: 2.768447057353125\n",
"iter: 4850, loss: 3.075327144675784\n",
"iter: 4900, loss: 3.268370175997416\n",
"iter: 4950, loss: 3.1798231331053235\n",
"iter: 5000, loss: 3.3217560536218063\n",
"iter: 5050, loss: 3.006732604223585\n",
"iter: 5100, loss: 3.3575944598061698\n",
"iter: 5150, loss: 2.9057663469655175\n",
"iter: 5200, loss: 2.8928466574502374\n",
"iter: 5250, loss: 3.061066797528948\n",
"iter: 5300, loss: 3.35562970057745\n",
"iter: 5350, loss: 2.9118076042901895\n",
"iter: 5400, loss: 2.9514354321918783\n",
"iter: 5450, loss: 2.9334804391406832\n",
"iter: 5500, loss: 3.204634138440329\n",
"iter: 5550, loss: 2.8140748963961526\n",
"iter: 5600, loss: 3.011708143741365\n",
"iter: 5650, loss: 3.323859388586074\n",
"iter: 5700, loss: 2.8442912295810756\n",
"iter: 5750, loss: 2.80684267281729\n",
"iter: 5800, loss: 3.1174840584860903\n",
"iter: 5850, loss: 2.6991389470478837\n",
"iter: 5900, loss: 2.9698236653237116\n",
"iter: 5950, loss: 3.0238281039586137\n",
"iter: 6000, loss: 2.8812837354947645\n",
"iter: 6050, loss: 3.1709352504639394\n",
"iter: 6100, loss: 2.937920509209709\n",
"iter: 6150, loss: 3.178728113076043\n",
"iter: 6200, loss: 2.8974244089429337\n",
"iter: 6250, loss: 2.809626478180052\n",
"iter: 6300, loss: 2.781241159703996\n",
"iter: 6350, loss: 2.9004218400395105\n",
"iter: 6400, loss: 2.9118271145669246\n",
"iter: 6450, loss: 2.8842602037096787\n",
"iter: 6500, loss: 2.9489114957536966\n",
"iter: 6550, loss: 2.9503131193130736\n",
"iter: 6600, loss: 2.8961831474304187\n",
"iter: 6650, loss: 3.002027267266834\n",
"iter: 6700, loss: 3.0047303264103236\n",
"iter: 6750, loss: 2.958453589060949\n",
"iter: 6800, loss: 2.9524990789852446\n",
"iter: 6850, loss: 2.935619188210321\n",
"iter: 6900, loss: 2.9734530233807033\n",
"iter: 6950, loss: 2.785320390822396\n",
"iter: 7000, loss: 3.1911680922054106\n",
"iter: 7050, loss: 2.7732513120363635\n",
"iter: 7100, loss: 2.7432456348282948\n",
"iter: 7150, loss: 2.823985375283256\n",
"iter: 7200, loss: 2.927504679808541\n",
"iter: 7250, loss: 3.0693400076760184\n",
"iter: 7300, loss: 2.666468213043515\n",
"iter: 7350, loss: 2.808132514378382\n",
"iter: 7400, loss: 2.558679431067573\n",
"iter: 7450, loss: 2.6974468813850763\n",
"iter: 7500, loss: 2.8497490201223457\n",
"iter: 7550, loss: 2.7490190564337236\n",
"iter: 7600, loss: 2.8300208840067427\n",
"iter: 7650, loss: 2.793417969741518\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [19]\u001b[0m, in \u001b[0;36m<cell line: 5>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m encoder1 \u001b[38;5;241m=\u001b[39m EncoderRNN(eng_lang\u001b[38;5;241m.\u001b[39mn_words, hidden_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 3\u001b[0m attn_decoder1 \u001b[38;5;241m=\u001b[39m AttnDecoderRNN(hidden_size, fra_lang\u001b[38;5;241m.\u001b[39mn_words, dropout_p\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainIters\u001b[49m\u001b[43m(\u001b[49m\u001b[43mencoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_decoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m75000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprint_every\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtrainIters\u001b[0;34m(encoder, decoder, n_iters, print_every, plot_every, learning_rate)\u001b[0m\n\u001b[1;32m 16\u001b[0m input_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 17\u001b[0m target_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m---> 19\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdecoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m print_loss_total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\n\u001b[1;32m 22\u001b[0m plot_loss_total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\n",
"Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decoder_input\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m==\u001b[39m EOS_token:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m---> 48\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m encoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 51\u001b[0m decoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
"File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/_tensor.py:363\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 356\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 357\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 362\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 363\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"iter: 50, loss: 5.000807713402643\n",
"iter: 100, loss: 4.439269823452783\n",
"iter: 150, loss: 3.9193654258516095\n",
"iter: 200, loss: 4.392944496881395\n",
"iter: 250, loss: 4.093038458445715\n",
"iter: 300, loss: 4.424980944542659\n",
"iter: 350, loss: 3.981394485715835\n",
"iter: 400, loss: 4.333685203370593\n",
"iter: 450, loss: 3.9591501615615123\n",
"iter: 500, loss: 3.9112882453070745\n",
"iter: 550, loss: 4.02278929338001\n",
"iter: 600, loss: 4.193090805341327\n",
"iter: 650, loss: 4.0906043112315835\n",
"iter: 700, loss: 4.469698131742931\n",
"iter: 750, loss: 4.176360895232548\n",
"iter: 800, loss: 3.961828211148579\n",
"iter: 850, loss: 4.261641813959393\n",
"iter: 900, loss: 4.051715278474111\n",
"iter: 950, loss: 3.936853767228505\n",
"iter: 1000, loss: 4.225432455638099\n",
"iter: 1050, loss: 4.045197415472971\n",
"iter: 1100, loss: 4.320344743092855\n",
"iter: 1150, loss: 4.053225604799058\n",
"iter: 1200, loss: 3.743754985476297\n",
"iter: 1250, loss: 4.0527504539035615\n",
"iter: 1300, loss: 3.84758229040721\n",
"iter: 1350, loss: 4.045899712789627\n",
"iter: 1400, loss: 4.027170557158334\n",
"iter: 1450, loss: 4.250136273232718\n",
"iter: 1500, loss: 3.895784365865919\n",
"iter: 1550, loss: 4.033517143960983\n",
"iter: 1600, loss: 4.067692023458934\n",
"iter: 1650, loss: 3.943578155487303\n",
"iter: 1700, loss: 3.638787496930078\n",
"iter: 1750, loss: 3.6410217295752636\n",
"iter: 1800, loss: 3.8924306627757965\n",
"iter: 1850, loss: 4.000204294613429\n",
"iter: 1900, loss: 3.8232511097136\n",
"iter: 1950, loss: 3.878676666108388\n",
"iter: 2000, loss: 3.9427240886536845\n",
"iter: 2050, loss: 3.7359260752693064\n",
"iter: 2100, loss: 3.583097653464666\n",
"iter: 2150, loss: 3.8278237684265024\n",
"iter: 2200, loss: 3.9119961933408463\n",
"iter: 2250, loss: 3.8753220474152346\n",
"iter: 2300, loss: 3.8338965735359802\n",
"iter: 2350, loss: 3.4894873487381712\n",
"iter: 2400, loss: 3.566151720009153\n",
"iter: 2450, loss: 3.937922410420009\n",
"iter: 2500, loss: 3.5345082195070057\n",
"iter: 2550, loss: 3.775564970758225\n",
"iter: 2600, loss: 3.864645612398783\n",
"iter: 2650, loss: 3.9066238069837063\n",
"iter: 2700, loss: 4.0819177106524265\n",
"iter: 2750, loss: 3.655153612878587\n",
"iter: 2800, loss: 3.832113747127473\n",
"iter: 2850, loss: 3.5925060623335456\n",
"iter: 2900, loss: 3.491001639260187\n",
"iter: 2950, loss: 3.5009806160094232\n",
"iter: 3000, loss: 3.6677673985693184\n",
"iter: 3050, loss: 3.781239900210547\n",
"iter: 3100, loss: 3.473299116104368\n",
"iter: 3150, loss: 3.7532493569813066\n",
"iter: 3200, loss: 3.7904585500293306\n",
"iter: 3250, loss: 3.6127893707487324\n",
"iter: 3300, loss: 3.4757489145445453\n",
"iter: 3350, loss: 3.7090715601784847\n",
"iter: 3400, loss: 3.8198574437792336\n",
"iter: 3450, loss: 3.509964802068377\n",
"iter: 3500, loss: 3.612169361614045\n",
"iter: 3550, loss: 3.641026579652514\n",
"iter: 3600, loss: 3.8201526030434483\n",
"iter: 3650, loss: 3.5652526591997287\n",
"iter: 3700, loss: 3.742421626257518\n",
"iter: 3750, loss: 4.003867071651277\n",
"iter: 3800, loss: 3.659059532135253\n",
"iter: 3850, loss: 3.641981271872445\n",
"iter: 3900, loss: 3.5502949162059356\n",
"iter: 3950, loss: 3.560595460755485\n",
"iter: 4000, loss: 3.5651848596542597\n",
"iter: 4050, loss: 3.980170504395925\n",
"iter: 4100, loss: 3.3924002220214367\n",
"iter: 4150, loss: 3.6649077605217233\n",
"iter: 4200, loss: 3.340204861981528\n",
"iter: 4250, loss: 3.722639773754848\n",
"iter: 4300, loss: 3.589223196249159\n",
"iter: 4350, loss: 3.4467484310770793\n",
"iter: 4400, loss: 3.4151901176921897\n",
"iter: 4450, loss: 3.4896546234630392\n",
"iter: 4500, loss: 3.2113779149963744\n",
"iter: 4550, loss: 3.5685467066235015\n",
"iter: 4600, loss: 3.005555194105421\n",
"iter: 4650, loss: 3.6020915983820716\n",
"iter: 4700, loss: 3.633627172273303\n",
"iter: 4750, loss: 3.4529481847551127\n",
"iter: 4800, loss: 3.4479807695207154\n",
"iter: 4850, loss: 3.370973790963491\n",
"iter: 4900, loss: 3.539276809162564\n",
"iter: 4950, loss: 3.3183354888189416\n",
"iter: 5000, loss: 3.521332158444421\n",
"iter: 5050, loss: 3.314378255844116\n",
"iter: 5100, loss: 3.291964127449762\n",
"iter: 5150, loss: 3.4429656072344086\n",
"iter: 5200, loss: 3.5413768560848538\n",
"iter: 5250, loss: 3.585603856238107\n",
"iter: 5300, loss: 3.470469724049644\n",
"iter: 5350, loss: 3.4666152168379893\n",
"iter: 5400, loss: 3.1305627430885563\n",
"iter: 5450, loss: 3.337137906922235\n",
"iter: 5500, loss: 3.481247283072699\n",
"iter: 5550, loss: 3.517226897428906\n",
"iter: 5600, loss: 3.1901850409886183\n",
"iter: 5650, loss: 3.136146711447883\n",
"iter: 5700, loss: 3.404250585170019\n",
"iter: 5750, loss: 3.3665729104375073\n",
"iter: 5800, loss: 3.382146033839574\n",
"iter: 5850, loss: 3.4272568195433837\n",
"iter: 5900, loss: 3.322702169350215\n",
"iter: 5950, loss: 3.156406671554324\n",
"iter: 6000, loss: 3.194001044719938\n",
"iter: 6050, loss: 3.3348103672814755\n",
"iter: 6100, loss: 3.150647495882852\n",
"iter: 6150, loss: 3.1009463010212728\n",
"iter: 6200, loss: 3.3785942046377393\n",
"iter: 6250, loss: 3.3160466527711776\n",
"iter: 6300, loss: 3.1596272509590024\n",
"iter: 6350, loss: 3.2589193917304753\n",
"iter: 6400, loss: 3.297462665050749\n",
"iter: 6450, loss: 3.3298678997206306\n",
"iter: 6500, loss: 3.219574876160848\n",
"iter: 6550, loss: 3.3395619553195104\n",
"iter: 6600, loss: 2.9891018758047196\n",
"iter: 6650, loss: 3.1851753817437185\n",
"iter: 6700, loss: 3.0209535363590896\n",
"iter: 6750, loss: 3.15220423432759\n",
"iter: 6800, loss: 3.181441980475471\n",
"iter: 6850, loss: 2.918750543064541\n",
"iter: 6900, loss: 3.2590200382944134\n",
"iter: 6950, loss: 3.187785402199578\n",
"iter: 7000, loss: 3.1073317580677213\n",
"iter: 7050, loss: 3.2191209546497896\n",
"iter: 7100, loss: 3.2027250674868397\n",
"iter: 7150, loss: 2.828316307037596\n",
"iter: 7200, loss: 2.8388766735886777\n",
"iter: 7250, loss: 2.778842180978684\n",
"iter: 7300, loss: 3.285732759347039\n",
"iter: 7350, loss: 3.0465498041349734\n",
"iter: 7400, loss: 2.90309523902999\n",
"iter: 7450, loss: 2.7295303400736004\n",
"iter: 7500, loss: 2.907297393454446\n",
"iter: 7550, loss: 3.1439063924077963\n",
"iter: 7600, loss: 3.2378484228376356\n",
"iter: 7650, loss: 3.0929804128919316\n",
"iter: 7700, loss: 3.0129570432239117\n",
"iter: 7750, loss: 2.707492174629181\n",
"iter: 7800, loss: 2.852806848832539\n",
"iter: 7850, loss: 2.983840656045883\n",
"iter: 7900, loss: 2.6098039440124756\n",
"iter: 7950, loss: 2.8175843656252293\n",
"iter: 8000, loss: 3.017819283258348\n",
"iter: 8050, loss: 2.728099891352275\n",
"iter: 8100, loss: 2.94138666140087\n",
"iter: 8150, loss: 3.004456134924813\n",
"iter: 8200, loss: 2.909780698662713\n",
"iter: 8250, loss: 2.8520988211707463\n",
"iter: 8300, loss: 2.9205126920351905\n",
"iter: 8350, loss: 3.1615525522080685\n",
"iter: 8400, loss: 2.8823572458918134\n",
"iter: 8450, loss: 2.990696503003438\n",
"iter: 8500, loss: 2.722038128603072\n",
"iter: 8550, loss: 2.7890086468212183\n",
"iter: 8600, loss: 2.7701356183233714\n",
"iter: 8650, loss: 2.8187452931555486\n",
"iter: 8700, loss: 2.927999514186192\n",
"iter: 8750, loss: 3.0153564615930826\n",
"iter: 8800, loss: 2.988208478534032\n",
"iter: 8850, loss: 3.053433906763319\n",
"iter: 8900, loss: 2.8472830426125295\n",
"iter: 8950, loss: 2.9679218861943206\n",
"iter: 9000, loss: 2.722358681913406\n",
"iter: 9050, loss: 2.995666239821722\n",
"iter: 9100, loss: 2.8067044997139585\n",
"iter: 9150, loss: 2.762981554493072\n",
"iter: 9200, loss: 2.8366338660906236\n",
"iter: 9250, loss: 2.877190364905766\n",
"iter: 9300, loss: 2.6378051905518487\n",
"iter: 9350, loss: 3.064765093697442\n",
"iter: 9400, loss: 2.5961536618868513\n",
"iter: 9450, loss: 2.786036056007658\n",
"iter: 9500, loss: 2.6443762784609715\n",
"iter: 9550, loss: 2.7273754563028847\n",
"iter: 9600, loss: 2.68890615716813\n",
"iter: 9650, loss: 2.525617115732223\n",
"iter: 9700, loss: 2.711592395033155\n",
"iter: 9750, loss: 2.540444574356079\n",
"iter: 9800, loss: 2.8242833649090358\n",
"iter: 9850, loss: 2.644202707573535\n",
"iter: 9900, loss: 2.7373070236084946\n",
"iter: 9950, loss: 3.0115960283960614\n",
"iter: 10000, loss: 2.8879434264046813\n",
"iter: 10050, loss: 2.562242189869048\n",
"iter: 10100, loss: 2.8641940906653325\n",
"iter: 10150, loss: 2.7755310885944056\n",
"iter: 10200, loss: 2.633019772166298\n",
"iter: 10250, loss: 2.6914108280454356\n",
"iter: 10300, loss: 2.764466069902692\n",
"iter: 10350, loss: 2.638823566330804\n",
"iter: 10400, loss: 2.6221462763756036\n",
"iter: 10450, loss: 2.8230800466423944\n",
"iter: 10500, loss: 2.772455602169037\n",
"iter: 10550, loss: 2.600414518220085\n",
"iter: 10600, loss: 2.7080593706161262\n",
"iter: 10650, loss: 2.4712089688513013\n",
"iter: 10700, loss: 2.6253130605485704\n",
"iter: 10750, loss: 2.558527778141082\n",
"iter: 10800, loss: 2.7869244644944633\n",
"iter: 10850, loss: 2.585347386742394\n",
"iter: 10900, loss: 2.5044392397517248\n",
"iter: 10950, loss: 2.596850109872364\n",
"iter: 11000, loss: 2.928512234038776\n",
"iter: 11050, loss: 2.5913034356851425\n",
"iter: 11100, loss: 2.679621921558229\n"
]
}
],
"source": [
"hidden_size = 256\n",
"encoder1 = EncoderRNN(eng_lang.n_words, hidden_size).to(device)\n",
"attn_decoder1 = AttnDecoderRNN(hidden_size, fra_lang.n_words, dropout_p=0.1).to(device)\n",
"attn_decoder1 = AttnDecoderRNN(hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)\n",
"\n",
"trainIters(encoder1, attn_decoder1, 75000, print_every=50)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> you re sad .\n",
"= tu es triste .\n",
"< vous tes . . <EOS>\n",
"\n",
"> she is sewing a dress .\n",
"= elle coud une robe .\n",
"< elle est une une . . <EOS>\n",
"\n",
"> he is suffering from a headache .\n",
"= il souffre d un mal de t te .\n",
"< il est un un un un . <EOS>\n",
"\n",
"> i m glad to see you .\n",
"= je suis heureux de vous voir .\n",
"< je suis content de vous voir . <EOS>\n",
"\n",
"> you are only young once .\n",
"= on n est jeune qu une fois .\n",
"< vous tes trop plus une enfant . <EOS>\n",
"\n",
"> you re so sweet .\n",
"= vous tes si gentille !\n",
"< vous tes trop si . <EOS>\n",
"\n",
"> i m running out of closet space .\n",
"= je manque d espace dans mon placard .\n",
"< je suis un de de <EOS>\n",
"\n",
"> i m sort of an extrovert .\n",
"= je suis en quelque sorte extraverti .\n",
"< je suis un un . . <EOS>\n",
"\n",
"> i m out of practice .\n",
"= je manque de pratique .\n",
"< j ai ai pas de <EOS>\n",
"\n",
"> you re the last hope for humanity .\n",
"= tu es le dernier espoir de l humanit .\n",
"< vous tes le la la . . <EOS>\n",
"\n"
]
}
],
"outputs": [],
"source": [
"evaluateRandomly(encoder1, attn_decoder1)"
]

46211
cw/data/eng-pol.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
Plik eng-pol.txt pochodzi z https://www.manythings.org/anki/