From abc73faa44abfadd298a893ee07802334ebe1d90 Mon Sep 17 00:00:00 2001 From: Jakub Pokrywka Date: Sun, 29 May 2022 21:24:53 +0200 Subject: [PATCH] 11 --- cw/11_Model_rekurencyjny_z_atencją.ipynb | 1053 ++++++--------------- 1 file changed, 313 insertions(+), 740 deletions(-) diff --git a/cw/11_Model_rekurencyjny_z_atencją.ipynb b/cw/11_Model_rekurencyjny_z_atencją.ipynb index 4b1a7e8..efca556 100644 --- a/cw/11_Model_rekurencyjny_z_atencją.ipynb +++ b/cw/11_Model_rekurencyjny_z_atencją.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -76,20 +76,7 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# def unicodeToAscii(s):\n", - "# return ''.join(\n", - "# c for c in unicodedata.normalize('NFD', s)\n", - "# if unicodedata.category(c) != 'Mn'\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -103,9 +90,6 @@ "\n", " pol_line = re.sub(r\"([.!?])\", r\" \\1\", pol_line)\n", " pol_line = re.sub(r\"[^a-zA-Z.!?ąćęłńóśźżĄĆĘŁŃÓŚŹŻ]+\", r\" \", pol_line)\n", - " \n", - "# eng_line = unicodeToAscii(eng_line)\n", - "# pol_line = unicodeToAscii(pol_line)\n", "\n", " pairs.append([eng_line, pol_line])\n", "\n", @@ -114,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -123,7 +107,7 @@ "['hi .', 'cześć .']" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -134,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -161,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -170,7 +154,7 @@ "['i m ok .', 'ze mną wszystko w porządku .']" ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -181,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -190,7 +174,7 @@ "['i m up .', 'wstałem .']" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -201,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -210,7 +194,7 @@ "['i m tom .', 'jestem tom .']" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -219,6 +203,46 @@ "pairs[2]" ] }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1828" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eng_lang.n_words" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2883" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pol_lang.n_words" + ] + }, { "cell_type": "code", "execution_count": 11, @@ -333,11 +357,11 @@ "source": [ "teacher_forcing_ratio = 0.5\n", "\n", - "def train_one_batch(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\n", + "def train_one_batch(input_tensor, target_tensor, encoder, decoder, optimizer, criterion, max_length=MAX_LENGTH):\n", " encoder_hidden = encoder.initHidden()\n", "\n", - " encoder_optimizer.zero_grad()\n", - " decoder_optimizer.zero_grad()\n", + "\n", + " optimizer.zero_grad()\n", "\n", " input_length = input_tensor.size(0)\n", " target_length = target_tensor.size(0)\n", @@ -374,8 +398,7 @@ "\n", " loss.backward()\n", "\n", - " encoder_optimizer.step()\n", - " decoder_optimizer.step()\n", + " optimizer.step()\n", "\n", " return loss.item() / target_length" ] @@ -388,9 +411,10 @@ "source": [ "def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\n", " print_loss_total = 0 # Reset every print_every\n", + " encoder.train()\n", + " decoder.train()\n", "\n", - " encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n", - " decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n", + " optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)\n", " \n", " training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n", " training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in training_pairs]\n", @@ -406,8 +430,8 @@ " target_tensor,\n", " encoder,\n", " decoder,\n", - " encoder_optimizer,\n", - " decoder_optimizer,\n", + " optimizer,\n", + "\n", " criterion)\n", " \n", " print_loss_total += loss\n", @@ -425,6 +449,8 @@ "outputs": [], "source": [ "def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\n", + " encoder.eval()\n", + " decoder.eval()\n", " with torch.no_grad():\n", " input_tensor = tensorFromSentence(sentence, eng_lang)\n", " input_length = input_tensor.size()[0]\n", @@ -433,11 +459,10 @@ " encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n", "\n", " for ei in range(input_length):\n", - " encoder_output, encoder_hidden = encoder(input_tensor[ei],\n", - " encoder_hidden)\n", + " encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n", " encoder_outputs[ei] += encoder_output[0, 0]\n", "\n", - " decoder_input = torch.tensor([[SOS_token]], device=device) # SOS\n", + " decoder_input = torch.tensor([[SOS_token]], device=device)\n", "\n", " decoder_hidden = encoder_hidden\n", "\n", @@ -481,652 +506,11 @@ "cell_type": "code", "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iter: 50, loss: 5.242557571774437\n", - "iter: 100, loss: 4.278488328888303\n", - "iter: 150, loss: 4.345584976514179\n", - "iter: 200, loss: 4.372515664100646\n", - "iter: 250, loss: 4.305285846165248\n", - "iter: 300, loss: 4.340697079749336\n", - "iter: 350, loss: 4.225462563787189\n", - "iter: 400, loss: 3.931193191331531\n", - "iter: 450, loss: 4.255609704971314\n", - "iter: 500, loss: 4.110652269242301\n", - "iter: 550, loss: 4.10534066640763\n", - "iter: 600, loss: 4.21553361088889\n", - "iter: 650, loss: 4.182825716926937\n", - "iter: 700, loss: 4.125172647006929\n", - "iter: 750, loss: 3.92194454044766\n", - "iter: 800, loss: 3.9588410959697904\n", - "iter: 850, loss: 4.045722855431693\n", - "iter: 900, loss: 3.989374102009668\n", - "iter: 950, loss: 3.8188858143791307\n", - "iter: 1000, loss: 4.083652021286979\n", - "iter: 1050, loss: 3.853677998118931\n", - "iter: 1100, loss: 4.240768341064452\n", - "iter: 1150, loss: 3.9389620546991857\n", - "iter: 1200, loss: 3.797674254826137\n", - "iter: 1250, loss: 3.9533765572441957\n", - "iter: 1300, loss: 4.227734768761528\n", - "iter: 1350, loss: 4.098812954554482\n", - "iter: 1400, loss: 3.970981749352953\n", - "iter: 1450, loss: 3.961285910742623\n", - "iter: 1500, loss: 3.7287075826554075\n", - "iter: 1550, loss: 4.011922280311584\n", - "iter: 1600, loss: 3.819334566025507\n", - "iter: 1650, loss: 4.058212718903073\n", - "iter: 1700, loss: 3.7410677611335874\n", - "iter: 1750, loss: 3.6925343862261077\n", - "iter: 1800, loss: 3.686991301317063\n", - "iter: 1850, loss: 4.038253510641673\n", - "iter: 1900, loss: 3.6877949990999133\n", - "iter: 1950, loss: 3.6264906252452302\n", - "iter: 2000, loss: 3.9525268210607845\n", - "iter: 2050, loss: 3.6234626099268596\n", - "iter: 2100, loss: 3.777446182023912\n", - "iter: 2150, loss: 3.6546845786609343\n", - "iter: 2200, loss: 3.5066249518167414\n", - "iter: 2250, loss: 3.4109464065082484\n", - "iter: 2300, loss: 3.3479860273467175\n", - "iter: 2350, loss: 3.4630342653289676\n", - "iter: 2400, loss: 3.565554211571103\n", - "iter: 2450, loss: 3.575563238794841\n", - "iter: 2500, loss: 4.010586709249588\n", - "iter: 2550, loss: 3.7003344478380105\n", - "iter: 2600, loss: 3.8373477258985003\n", - "iter: 2650, loss: 3.622782635825022\n", - "iter: 2700, loss: 3.539290331628588\n", - "iter: 2750, loss: 3.6115450941721594\n", - "iter: 2800, loss: 3.5639050323849633\n", - "iter: 2850, loss: 3.7455426768651083\n", - "iter: 2900, loss: 3.590438606050279\n", - "iter: 2950, loss: 3.6904614701346747\n", - "iter: 3000, loss: 3.8645651381356383\n", - "iter: 3050, loss: 3.704173568400126\n", - "iter: 3100, loss: 3.6185882792548534\n", - "iter: 3150, loss: 3.4422046549131\n", - "iter: 3200, loss: 3.4467696623802184\n", - "iter: 3250, loss: 3.5570392836237725\n", - "iter: 3300, loss: 3.4138823104585923\n", - "iter: 3350, loss: 3.1434556758668686\n", - "iter: 3400, loss: 3.683112149511065\n", - "iter: 3450, loss: 3.548991588297344\n", - "iter: 3500, loss: 3.8323247443380795\n", - "iter: 3550, loss: 3.51619815774191\n", - "iter: 3600, loss: 3.9732376934687297\n", - "iter: 3650, loss: 3.4628570978679356\n", - "iter: 3700, loss: 3.819421215375265\n", - "iter: 3750, loss: 3.7018858858819983\n", - "iter: 3800, loss: 3.332828684034802\n", - "iter: 3850, loss: 3.39565832292466\n", - "iter: 3900, loss: 3.6046563055099\n", - "iter: 3950, loss: 3.4032811139727404\n", - "iter: 4000, loss: 3.188702507541294\n", - "iter: 4050, loss: 3.246296736966995\n", - "iter: 4100, loss: 3.3872493017287484\n", - "iter: 4150, loss: 3.2912982750620166\n", - "iter: 4200, loss: 3.439030005250657\n", - "iter: 4250, loss: 3.6874865720536967\n", - "iter: 4300, loss: 3.2006266547081967\n", - "iter: 4350, loss: 3.3141552084968198\n", - "iter: 4400, loss: 3.1777613387107846\n", - "iter: 4450, loss: 3.306143865358262\n", - "iter: 4500, loss: 3.3490057452519726\n", - "iter: 4550, loss: 3.54855015988577\n", - "iter: 4600, loss: 3.1190093379020696\n", - "iter: 4650, loss: 3.1318349221547455\n", - "iter: 4700, loss: 3.3145397909482317\n", - "iter: 4750, loss: 3.6301960383823935\n", - "iter: 4800, loss: 3.497950396598331\n", - "iter: 4850, loss: 3.433724424384889\n", - "iter: 4900, loss: 3.099926131324163\n", - "iter: 4950, loss: 3.153078259695144\n", - "iter: 5000, loss: 3.4299044117473425\n", - "iter: 5050, loss: 3.2485543521245326\n", - "iter: 5100, loss: 3.288219501253158\n", - "iter: 5150, loss: 3.0275319642793566\n", - "iter: 5200, loss: 3.2333122690518694\n", - "iter: 5250, loss: 3.1438695950281055\n", - "iter: 5300, loss: 3.289688352705941\n", - "iter: 5350, loss: 3.4346048777671085\n", - "iter: 5400, loss: 3.3960607704435075\n", - "iter: 5450, loss: 3.134131056607716\n", - "iter: 5500, loss: 2.88015941856021\n", - "iter: 5550, loss: 3.223093853874812\n", - "iter: 5600, loss: 3.523275021235148\n", - "iter: 5650, loss: 3.2974130417430207\n", - "iter: 5700, loss: 3.291351405779521\n", - "iter: 5750, loss: 3.0594470017069857\n", - "iter: 5800, loss: 3.0294449334144593\n", - "iter: 5850, loss: 3.2555880333885314\n", - "iter: 5900, loss: 2.919657180456888\n", - "iter: 5950, loss: 3.0907614767362195\n", - "iter: 6000, loss: 2.961127914254628\n", - "iter: 6050, loss: 3.4255604133378896\n", - "iter: 6100, loss: 3.113428830744728\n", - "iter: 6150, loss: 3.2713393408457434\n", - "iter: 6200, loss: 2.808141718750909\n", - "iter: 6250, loss: 3.206718180179596\n", - "iter: 6300, loss: 2.961204339458829\n", - "iter: 6350, loss: 3.3583041914379788\n", - "iter: 6400, loss: 2.8745996781455148\n", - "iter: 6450, loss: 3.044813909867453\n", - "iter: 6500, loss: 3.0786628415698103\n", - "iter: 6550, loss: 3.1983558077206693\n", - "iter: 6600, loss: 3.2838380699536156\n", - "iter: 6650, loss: 3.299677872680482\n", - "iter: 6700, loss: 3.0458072693007336\n", - "iter: 6750, loss: 2.8759968113482937\n", - "iter: 6800, loss: 2.611457399186634\n", - "iter: 6850, loss: 3.1191990507443745\n", - "iter: 6900, loss: 2.8746687850649404\n", - "iter: 6950, loss: 3.266799050270565\n", - "iter: 7000, loss: 2.9557879123422834\n", - "iter: 7050, loss: 3.3536233253327623\n", - "iter: 7100, loss: 2.866518679376633\n", - "iter: 7150, loss: 3.0647721849698866\n", - "iter: 7200, loss: 3.0131801981396147\n", - "iter: 7250, loss: 3.3611434789687866\n", - "iter: 7300, loss: 2.896131462626987\n", - "iter: 7350, loss: 3.0051966722579224\n", - "iter: 7400, loss: 2.6453575278766577\n", - "iter: 7450, loss: 3.0411309962424027\n", - "iter: 7500, loss: 3.0933231606710523\n", - "iter: 7550, loss: 3.0312348983022908\n", - "iter: 7600, loss: 3.1757038073766797\n", - "iter: 7650, loss: 3.190331464472272\n", - "iter: 7700, loss: 2.518719242436545\n", - "iter: 7750, loss: 2.9345069105965758\n", - "iter: 7800, loss: 2.8456812357221337\n", - "iter: 7850, loss: 2.9130297107620837\n", - "iter: 7900, loss: 2.979178165594737\n", - "iter: 7950, loss: 2.901021231393965\n", - "iter: 8000, loss: 2.595174813210018\n", - "iter: 8050, loss: 2.7613930717271473\n", - "iter: 8100, loss: 2.746399310149844\n", - "iter: 8150, loss: 2.8843572297663913\n", - "iter: 8200, loss: 2.7994356728735426\n", - "iter: 8250, loss: 2.6970716561135784\n", - "iter: 8300, loss: 2.883459539050147\n", - "iter: 8350, loss: 2.7503165247099735\n", - "iter: 8400, loss: 2.9744199762647114\n", - "iter: 8450, loss: 3.0706924525518273\n", - "iter: 8500, loss: 2.888958851995922\n", - "iter: 8550, loss: 2.719320885154936\n", - "iter: 8600, loss: 2.8181346920444854\n", - "iter: 8650, loss: 2.8235925950890493\n", - "iter: 8700, loss: 3.051045098115527\n", - "iter: 8750, loss: 2.5698431457110815\n", - "iter: 8800, loss: 2.7776481211828807\n", - "iter: 8850, loss: 2.4384212581695075\n", - "iter: 8900, loss: 2.6480511212954445\n", - "iter: 8950, loss: 2.5756836236620706\n", - "iter: 9000, loss: 2.8125146527971534\n", - "iter: 9050, loss: 2.8097832722512504\n", - "iter: 9100, loss: 2.8278016069389533\n", - "iter: 9150, loss: 2.444784381949712\n", - "iter: 9200, loss: 2.8099934362154158\n", - "iter: 9250, loss: 2.984244331113876\n", - "iter: 9300, loss: 2.9806695501161005\n", - "iter: 9350, loss: 2.8827475949923187\n", - "iter: 9400, loss: 3.0439721408420137\n", - "iter: 9450, loss: 2.6807251452415706\n", - "iter: 9500, loss: 2.5094920273621875\n", - "iter: 9550, loss: 2.635116928410909\n", - "iter: 9600, loss: 2.587259496537466\n", - "iter: 9650, loss: 2.6364437070649767\n", - "iter: 9700, loss: 2.6659068493899842\n", - "iter: 9750, loss: 2.3925973146056365\n", - "iter: 9800, loss: 2.8345537455271157\n", - "iter: 9850, loss: 2.3069138811202277\n", - "iter: 9900, loss: 2.319064138798487\n", - "iter: 9950, loss: 2.4867696173228913\n", - "iter: 10000, loss: 2.614620875483468\n", - "iter: 10050, loss: 2.422453577261123\n", - "iter: 10100, loss: 2.643933411677678\n", - "iter: 10150, loss: 2.5282146744349645\n", - "iter: 10200, loss: 2.5393255345310486\n", - "iter: 10250, loss: 2.9825220655032565\n", - "iter: 10300, loss: 2.2635890780091277\n", - "iter: 10350, loss: 2.7769809711168683\n", - "iter: 10400, loss: 2.445600000275506\n", - "iter: 10450, loss: 2.453449849030328\n", - "iter: 10500, loss: 2.5520464683108854\n", - "iter: 10550, loss: 2.577900281663925\n", - "iter: 10600, loss: 2.4218848383086065\n", - "iter: 10650, loss: 2.5381085565317245\n", - "iter: 10700, loss: 2.196764139754431\n", - "iter: 10750, loss: 2.456448502972012\n", - "iter: 10800, loss: 2.560683703441468\n", - "iter: 10850, loss: 2.53125628255284\n", - "iter: 10900, loss: 2.7491349925086603\n", - "iter: 10950, loss: 2.6151628021588404\n", - "iter: 11000, loss: 2.507106682993117\n", - "iter: 11050, loss: 2.369231795661033\n", - "iter: 11100, loss: 2.5730169670676433\n", - "iter: 11150, loss: 2.3010029462254233\n", - "iter: 11200, loss: 2.633150562687526\n", - "iter: 11250, loss: 2.5919544999429163\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iter: 11300, loss: 2.534775358722323\n", - "iter: 11350, loss: 2.3763021782466343\n", - "iter: 11400, loss: 2.477060817866099\n", - "iter: 11450, loss: 2.240788399741763\n", - "iter: 11500, loss: 2.4560615454931103\n", - "iter: 11550, loss: 2.4167706055300577\n", - "iter: 11600, loss: 2.5485691031482482\n", - "iter: 11650, loss: 2.4385872491881964\n", - "iter: 11700, loss: 2.262665515203325\n", - "iter: 11750, loss: 2.3140043601195015\n", - "iter: 11800, loss: 2.3840308377969834\n", - "iter: 11850, loss: 2.3109344417519044\n", - "iter: 11900, loss: 2.3575586484140825\n", - "iter: 11950, loss: 2.2054063754535855\n", - "iter: 12000, loss: 2.256502300773348\n", - "iter: 12050, loss: 2.4794748330608245\n", - "iter: 12100, loss: 2.337028050218309\n", - "iter: 12150, loss: 2.0973778800964356\n", - "iter: 12200, loss: 2.159631293109485\n", - "iter: 12250, loss: 2.3099975161770034\n", - "iter: 12300, loss: 2.421918697101729\n", - "iter: 12350, loss: 2.531752646151043\n", - "iter: 12400, loss: 2.320960243735995\n", - "iter: 12450, loss: 2.2293582723708387\n", - "iter: 12500, loss: 2.2750969548414623\n", - "iter: 12550, loss: 1.7618893385950534\n", - "iter: 12600, loss: 2.340753597418468\n", - "iter: 12650, loss: 2.297142340731999\n", - "iter: 12700, loss: 2.3628056962887443\n", - "iter: 12750, loss: 2.6683244729306956\n", - "iter: 12800, loss: 2.129389260549394\n", - "iter: 12850, loss: 2.200342266559601\n", - "iter: 12900, loss: 2.2035769688401903\n", - "iter: 12950, loss: 2.2374771443643264\n", - "iter: 13000, loss: 2.174828187268878\n", - "iter: 13050, loss: 2.4299814297888016\n", - "iter: 13100, loss: 2.2743770096472335\n", - "iter: 13150, loss: 2.2950440486839843\n", - "iter: 13200, loss: 2.4582167831375483\n", - "iter: 13250, loss: 2.540857286835474\n", - "iter: 13300, loss: 2.322540373416174\n", - "iter: 13350, loss: 2.319241341783887\n", - "iter: 13400, loss: 2.3551435836969854\n", - "iter: 13450, loss: 2.21491847473856\n", - "iter: 13500, loss: 2.0196543374175118\n", - "iter: 13550, loss: 2.294338379172105\n", - "iter: 13600, loss: 1.8327846462726596\n", - "iter: 13650, loss: 2.035525601328366\n", - "iter: 13700, loss: 2.2429525320794843\n", - "iter: 13750, loss: 2.034286926406243\n", - "iter: 13800, loss: 2.100517734408379\n", - "iter: 13850, loss: 2.0885622107670425\n", - "iter: 13900, loss: 1.90785291909793\n", - "iter: 13950, loss: 2.273535749908478\n", - "iter: 14000, loss: 1.838191468528339\n", - "iter: 14050, loss: 2.3195868289754507\n", - "iter: 14100, loss: 1.8965250493060974\n", - "iter: 14150, loss: 2.10934934569919\n", - "iter: 14200, loss: 2.151934117366397\n", - "iter: 14250, loss: 2.0066685717522152\n", - "iter: 14300, loss: 2.3296401413433134\n", - "iter: 14350, loss: 1.9384719442223746\n", - "iter: 14400, loss: 2.2025605153564425\n", - "iter: 14450, loss: 2.263807416308494\n", - "iter: 14500, loss: 1.9864815442051205\n", - "iter: 14550, loss: 1.7038374454577763\n", - "iter: 14600, loss: 2.274628053700167\n", - "iter: 14650, loss: 2.1628303778625675\n", - "iter: 14700, loss: 1.9897215003796989\n", - "iter: 14750, loss: 1.860605546917234\n", - "iter: 14800, loss: 1.9588362134335529\n", - "iter: 14850, loss: 1.8767746505396707\n", - "iter: 14900, loss: 1.834631380274182\n", - "iter: 14950, loss: 1.9499947649410792\n", - "iter: 15000, loss: 2.0015979091269624\n", - "iter: 15050, loss: 2.0649836547412574\n", - "iter: 15100, loss: 2.249369715940384\n", - "iter: 15150, loss: 1.5817453392441307\n", - "iter: 15200, loss: 2.1706447578157695\n", - "iter: 15250, loss: 1.9688029914564558\n", - "iter: 15300, loss: 2.046964565526871\n", - "iter: 15350, loss: 1.9338763165667892\n", - "iter: 15400, loss: 1.9137448829904438\n", - "iter: 15450, loss: 1.7699638532740727\n", - "iter: 15500, loss: 2.2515631875159245\n", - "iter: 15550, loss: 1.7620117027797395\n", - "iter: 15600, loss: 1.9152411586524003\n", - "iter: 15650, loss: 2.0947861353386017\n", - "iter: 15700, loss: 1.9149094790844687\n", - "iter: 15750, loss: 1.7210240173566909\n", - "iter: 15800, loss: 2.014472983038614\n", - "iter: 15850, loss: 2.1098430752697444\n", - "iter: 15900, loss: 2.023270213549099\n", - "iter: 15950, loss: 1.9570550824488917\n", - "iter: 16000, loss: 1.895675997123832\n", - "iter: 16050, loss: 1.837380247549642\n", - "iter: 16100, loss: 1.894489290089834\n", - "iter: 16150, loss: 2.075172846224573\n", - "iter: 16200, loss: 1.8212170035555248\n", - "iter: 16250, loss: 1.8570367700694095\n", - "iter: 16300, loss: 1.6184977439187818\n", - "iter: 16350, loss: 1.7351362812415\n", - "iter: 16400, loss: 1.872060403579758\n", - "iter: 16450, loss: 1.6218276036712858\n", - "iter: 16500, loss: 1.9870286158758497\n", - "iter: 16550, loss: 1.9007116212835387\n", - "iter: 16600, loss: 1.8743730505156142\n", - "iter: 16650, loss: 1.5293502329928537\n", - "iter: 16700, loss: 1.811881399162232\n", - "iter: 16750, loss: 1.5156562756375658\n", - "iter: 16800, loss: 1.6397469798794813\n", - "iter: 16850, loss: 2.2027597563172145\n", - "iter: 16900, loss: 1.8139538214131006\n", - "iter: 16950, loss: 2.1659815680677927\n", - "iter: 17000, loss: 1.947558910210927\n", - "iter: 17050, loss: 2.0774720856149993\n", - "iter: 17100, loss: 1.7940182881762112\n", - "iter: 17150, loss: 2.1425245618441746\n", - "iter: 17200, loss: 1.6630687274876097\n", - "iter: 17250, loss: 1.7448162170535044\n", - "iter: 17300, loss: 1.8790338722637718\n", - "iter: 17350, loss: 1.96936958753495\n", - "iter: 17400, loss: 1.8035021762762753\n", - "iter: 17450, loss: 1.784786748029883\n", - "iter: 17500, loss: 1.8431302896037933\n", - "iter: 17550, loss: 1.9356805955001288\n", - "iter: 17600, loss: 1.571500998784625\n", - "iter: 17650, loss: 1.849001149414551\n", - "iter: 17700, loss: 1.5969795638758038\n", - "iter: 17750, loss: 1.6012443284591038\n", - "iter: 17800, loss: 1.5525058465600012\n", - "iter: 17850, loss: 1.450256337930286\n", - "iter: 17900, loss: 1.7983906224483532\n", - "iter: 17950, loss: 1.7381368355050921\n", - "iter: 18000, loss: 1.6177345383224033\n", - "iter: 18050, loss: 1.835479336150582\n", - "iter: 18100, loss: 1.5402896869333964\n", - "iter: 18150, loss: 1.5447071926097078\n", - "iter: 18200, loss: 1.6833134629707485\n", - "iter: 18250, loss: 1.8886855756252532\n", - "iter: 18300, loss: 1.6310479882558186\n", - "iter: 18350, loss: 1.6417460731078708\n", - "iter: 18400, loss: 1.7383878009962657\n", - "iter: 18450, loss: 1.6342206524724057\n", - "iter: 18500, loss: 1.5872581603981197\n", - "iter: 18550, loss: 1.287150528927171\n", - "iter: 18600, loss: 1.6059650084300645\n", - "iter: 18650, loss: 1.28275570456045\n", - "iter: 18700, loss: 1.439326602407864\n", - "iter: 18750, loss: 1.7180297046511894\n", - "iter: 18800, loss: 1.6227167361766575\n", - "iter: 18850, loss: 1.437303775454324\n", - "iter: 18900, loss: 1.6929941054639364\n", - "iter: 18950, loss: 1.6776369486933662\n", - "iter: 19000, loss: 1.69069007818661\n", - "iter: 19050, loss: 1.8343193885277191\n", - "iter: 19100, loss: 1.3482130224931808\n", - "iter: 19150, loss: 1.4392069308530717\n", - "iter: 19200, loss: 1.4435342772607769\n", - "iter: 19250, loss: 1.4412190558891447\n", - "iter: 19300, loss: 1.7313999670062743\n", - "iter: 19350, loss: 1.6303069564179768\n", - "iter: 19400, loss: 1.8313010199240274\n", - "iter: 19450, loss: 1.476125830580318\n", - "iter: 19500, loss: 1.784752085032917\n", - "iter: 19550, loss: 1.900799496985617\n", - "iter: 19600, loss: 1.6683086817453778\n", - "iter: 19650, loss: 1.6018399291965693\n", - "iter: 19700, loss: 1.5080324055013201\n", - "iter: 19750, loss: 1.7074753486149838\n", - "iter: 19800, loss: 1.5588394280918059\n", - "iter: 19850, loss: 1.4063752451401856\n", - "iter: 19900, loss: 1.6571519161235722\n", - "iter: 19950, loss: 1.4880279254605846\n", - "iter: 20000, loss: 1.4425315815721234\n", - "iter: 20050, loss: 1.4204049231041045\n", - "iter: 20100, loss: 1.5411449456631194\n", - "iter: 20150, loss: 1.4098666115223417\n", - "iter: 20200, loss: 1.4514436369504011\n", - "iter: 20250, loss: 1.678218051835658\n", - "iter: 20300, loss: 1.3683213942356056\n", - "iter: 20350, loss: 1.4311776501555296\n", - "iter: 20400, loss: 1.44434953142537\n", - "iter: 20450, loss: 1.4809531437674215\n", - "iter: 20500, loss: 1.498182836138067\n", - "iter: 20550, loss: 1.6891843606990486\n", - "iter: 20600, loss: 1.307836448823176\n", - "iter: 20650, loss: 1.3191714194629873\n", - "iter: 20700, loss: 1.435782224451266\n", - "iter: 20750, loss: 1.4501241854064992\n", - "iter: 20800, loss: 1.570673788651587\n", - "iter: 20850, loss: 1.6726866277487031\n", - "iter: 20900, loss: 1.490093153404811\n", - "iter: 20950, loss: 1.3381259351434216\n", - "iter: 21000, loss: 1.4293887265811833\n", - "iter: 21050, loss: 1.5261030488553495\n", - "iter: 21100, loss: 1.4049972703861338\n", - "iter: 21150, loss: 1.3666674501952667\n", - "iter: 21200, loss: 1.544151809760502\n", - "iter: 21250, loss: 1.4767180123480546\n", - "iter: 21300, loss: 1.3458678885580055\n", - "iter: 21350, loss: 1.3158163404729633\n", - "iter: 21400, loss: 1.2743006317423922\n", - "iter: 21450, loss: 1.4159044456604926\n", - "iter: 21500, loss: 1.7186118897502385\n", - "iter: 21550, loss: 1.4735830772358276\n", - "iter: 21600, loss: 1.2575308752939818\n", - "iter: 21650, loss: 1.2709813627033006\n", - "iter: 21700, loss: 1.2383236987832047\n", - "iter: 21750, loss: 1.2756263920948618\n", - "iter: 21800, loss: 1.1783258064417612\n", - "iter: 21850, loss: 1.290928970362459\n", - "iter: 21900, loss: 1.2292051843586895\n", - "iter: 21950, loss: 1.4506985603798\n", - "iter: 22000, loss: 1.2761652381798578\n", - "iter: 22050, loss: 1.258709805733628\n", - "iter: 22100, loss: 1.5169502600658507\n", - "iter: 22150, loss: 1.384094950204804\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "iter: 22200, loss: 1.6116164426425141\n", - "iter: 22250, loss: 1.2793350757350999\n", - "iter: 22300, loss: 1.3017940769725374\n", - "iter: 22350, loss: 1.3736145087998537\n", - "iter: 22400, loss: 1.138627294466609\n", - "iter: 22450, loss: 1.310434480055457\n", - "iter: 22500, loss: 1.3961503054835491\n", - "iter: 22550, loss: 1.3724343404996964\n", - "iter: 22600, loss: 1.3563060793403596\n", - "iter: 22650, loss: 1.5050592094120523\n", - "iter: 22700, loss: 1.3752844895202485\n", - "iter: 22750, loss: 1.1653902740024384\n", - "iter: 22800, loss: 1.586769177135967\n", - "iter: 22850, loss: 1.2032956434232844\n", - "iter: 22900, loss: 1.3740250343538463\n", - "iter: 22950, loss: 1.1050300681430192\n", - "iter: 23000, loss: 1.2519222570203599\n", - "iter: 23050, loss: 1.42347088217688\n", - "iter: 23100, loss: 1.4116886373684991\n", - "iter: 23150, loss: 1.1194891034222785\n", - "iter: 23200, loss: 1.2006812089085581\n", - "iter: 23250, loss: 1.4476829275015806\n", - "iter: 23300, loss: 1.3017136716679447\n", - "iter: 23350, loss: 1.3050824448059475\n", - "iter: 23400, loss: 1.0946670590383667\n", - "iter: 23450, loss: 1.0267400648877734\n", - "iter: 23500, loss: 1.3339079014630544\n", - "iter: 23550, loss: 1.192698698204543\n", - "iter: 23600, loss: 1.2177529160323597\n", - "iter: 23650, loss: 1.0835593536988137\n", - "iter: 23700, loss: 1.0673790020138498\n", - "iter: 23750, loss: 1.3244361790158448\n", - "iter: 23800, loss: 1.1204376552663156\n", - "iter: 23850, loss: 1.2015556238284189\n", - "iter: 23900, loss: 1.4070206288629112\n", - "iter: 23950, loss: 1.3217124197293841\n", - "iter: 24000, loss: 1.0488405134649506\n", - "iter: 24050, loss: 1.1713006507924626\n", - "iter: 24100, loss: 1.2951658061903621\n", - "iter: 24150, loss: 1.2749495941726934\n", - "iter: 24200, loss: 1.2141112795103162\n", - "iter: 24250, loss: 1.3290269901402414\n", - "iter: 24300, loss: 1.094365536570549\n", - "iter: 24350, loss: 1.133138121331495\n", - "iter: 24400, loss: 1.3418169331314074\n", - "iter: 24450, loss: 0.9847527233404773\n", - "iter: 24500, loss: 1.1087985199188426\n", - "iter: 24550, loss: 1.4006639858985706\n", - "iter: 24600, loss: 1.1466205491246213\n", - "iter: 24650, loss: 1.1214664732799642\n", - "iter: 24700, loss: 1.1177163749280432\n", - "iter: 24750, loss: 1.05219458625714\n", - "iter: 24800, loss: 1.1949661584328566\n", - "iter: 24850, loss: 0.9802025896845353\n", - "iter: 24900, loss: 1.1272975780748655\n", - "iter: 24950, loss: 1.0976827581269404\n", - "iter: 25000, loss: 0.9013028181819688\n", - "iter: 25050, loss: 1.3180778384589484\n", - "iter: 25100, loss: 1.0977117643299557\n", - "iter: 25150, loss: 0.9444285869991021\n", - "iter: 25200, loss: 1.336973425586545\n", - "iter: 25250, loss: 1.2987125150627556\n", - "iter: 25300, loss: 1.0681130346740995\n", - "iter: 25350, loss: 0.9836248108498631\n", - "iter: 25400, loss: 1.1549646752675378\n", - "iter: 25450, loss: 1.0397938099617048\n", - "iter: 25500, loss: 1.4253321852816476\n", - "iter: 25550, loss: 1.397831358559548\n", - "iter: 25600, loss: 0.8681884022117372\n", - "iter: 25650, loss: 0.949661937920347\n", - "iter: 25700, loss: 1.018096680844114\n", - "iter: 25750, loss: 1.0033835446210135\n", - "iter: 25800, loss: 0.9399867170606815\n", - "iter: 25850, loss: 0.9365767020531115\n", - "iter: 25900, loss: 1.2080267537056453\n", - "iter: 25950, loss: 1.0215099297222634\n", - "iter: 26000, loss: 0.9733565044677448\n", - "iter: 26050, loss: 1.0712914910318834\n", - "iter: 26100, loss: 0.8407332850779805\n", - "iter: 26150, loss: 0.9271211279460363\n", - "iter: 26200, loss: 0.9953902960416108\n", - "iter: 26250, loss: 1.0131704654341178\n", - "iter: 26300, loss: 1.0885028305432156\n", - "iter: 26350, loss: 1.0190075791875521\n", - "iter: 26400, loss: 1.009052420553707\n", - "iter: 26450, loss: 1.0815212898623379\n", - "iter: 26500, loss: 0.9892340009240876\n", - "iter: 26550, loss: 1.0516380755560737\n", - "iter: 26600, loss: 0.9344196589528803\n", - "iter: 26650, loss: 0.8953249894132216\n", - "iter: 26700, loss: 0.9229552195980435\n", - "iter: 26750, loss: 0.7424087155598496\n", - "iter: 26800, loss: 0.911013327536129\n", - "iter: 26850, loss: 1.1781759474883002\n", - "iter: 26900, loss: 1.196274289493523\n", - "iter: 26950, loss: 1.0227981455389943\n", - "iter: 27000, loss: 0.9916679235928586\n", - "iter: 27050, loss: 0.9636169400480058\n", - "iter: 27100, loss: 0.8002338881918359\n", - "iter: 27150, loss: 0.800919440870247\n", - "iter: 27200, loss: 0.8211033871329966\n", - "iter: 27250, loss: 0.8155000005123162\n", - "iter: 27300, loss: 0.876837944473539\n", - "iter: 27350, loss: 1.1260614515467298\n", - "iter: 27400, loss: 1.058864346462583\n", - "iter: 27450, loss: 1.1114834662898192\n", - "iter: 27500, loss: 0.9796440882084387\n", - "iter: 27550, loss: 1.0277935135303036\n", - "iter: 27600, loss: 0.6979781560635284\n", - "iter: 27650, loss: 0.770827453770808\n", - "iter: 27700, loss: 1.1471699211550135\n", - "iter: 27750, loss: 0.8712478535033409\n", - "iter: 27800, loss: 0.7957819575319688\n", - "iter: 27850, loss: 1.0939111870155924\n", - "iter: 27900, loss: 0.9194521397224494\n", - "iter: 27950, loss: 0.8920607558945345\n", - "iter: 28000, loss: 0.8829188095186908\n", - "iter: 28050, loss: 0.9212011002366033\n", - "iter: 28100, loss: 0.7731392620366715\n", - "iter: 28150, loss: 1.056102939241699\n", - "iter: 28200, loss: 0.9831677025327132\n", - "iter: 28250, loss: 1.071929881365999\n", - "iter: 28300, loss: 0.9135961269267967\n", - "iter: 28350, loss: 0.8095226630355632\n", - "iter: 28400, loss: 0.9595384959959911\n", - "iter: 28450, loss: 0.7839641324215465\n", - "iter: 28500, loss: 0.9889460563829968\n", - "iter: 28550, loss: 1.0575634596305232\n", - "iter: 28600, loss: 1.05014324463218\n", - "iter: 28650, loss: 0.9521020337228501\n", - "iter: 28700, loss: 0.8122104218034515\n", - "iter: 28750, loss: 0.9600319336676408\n", - "iter: 28800, loss: 0.7290925218548092\n", - "iter: 28850, loss: 0.8589948218661168\n", - "iter: 28900, loss: 0.8876770969496832\n", - "iter: 28950, loss: 0.7668700665647076\n", - "iter: 29000, loss: 0.8810090623952094\n", - "iter: 29050, loss: 0.9807037507650397\n", - "iter: 29100, loss: 0.6704667443845952\n", - "iter: 29150, loss: 0.6698679181308975\n", - "iter: 29200, loss: 0.8776328837161972\n", - "iter: 29250, loss: 0.8806386950718503\n", - "iter: 29300, loss: 0.6410340730618862\n", - "iter: 29350, loss: 0.8755547849377472\n", - "iter: 29400, loss: 0.8818342795334163\n", - "iter: 29450, loss: 0.7442211986517623\n", - "iter: 29500, loss: 0.8927219600469348\n", - "iter: 29550, loss: 1.019919359203842\n", - "iter: 29600, loss: 0.8808109327583087\n", - "iter: 29650, loss: 0.8205070998280766\n", - "iter: 29700, loss: 1.019214930534363\n", - "iter: 29750, loss: 0.8730531409016206\n", - "iter: 29800, loss: 0.7633821407521053\n", - "iter: 29850, loss: 0.796077705860138\n", - "iter: 29900, loss: 0.7018148700419874\n", - "iter: 29950, loss: 1.1195493836871218\n", - "iter: 30000, loss: 0.8907366043790467\n", - "iter: 30050, loss: 0.9264667236958704\n", - "iter: 30100, loss: 1.0352731366356211\n", - "iter: 30150, loss: 0.7005343800724028\n", - "iter: 30200, loss: 0.9168639244217249\n", - "iter: 30250, loss: 0.8539114402177789\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [19]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m encoder1 \u001b[38;5;241m=\u001b[39m EncoderRNN(eng_lang\u001b[38;5;241m.\u001b[39mn_words, hidden_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 3\u001b[0m attn_decoder1 \u001b[38;5;241m=\u001b[39m AttnDecoderRNN(hidden_size, pol_lang\u001b[38;5;241m.\u001b[39mn_words, dropout_p\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainIters\u001b[49m\u001b[43m(\u001b[49m\u001b[43mencoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_decoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m75000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprint_every\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n", - "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtrainIters\u001b[0;34m(encoder, decoder, n_iters, print_every, learning_rate)\u001b[0m\n\u001b[1;32m 14\u001b[0m input_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 15\u001b[0m target_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m---> 17\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_one_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m print_loss_total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m print_every \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", - "Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36mtrain_one_batch\u001b[0;34m(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decoder_input\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m==\u001b[39m EOS_token:\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m---> 42\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 44\u001b[0m encoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 45\u001b[0m decoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n", - "File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/_tensor.py:363\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 356\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 357\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 362\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 363\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "hidden_size = 256\n", "encoder1 = EncoderRNN(eng_lang.n_words, hidden_size).to(device)\n", - "attn_decoder1 = AttnDecoderRNN(hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)\n", - "\n", - "trainIters(encoder1, attn_decoder1, 75000, print_every=50)" + "attn_decoder1 = AttnDecoderRNN(hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)" ] }, { @@ -1138,51 +522,211 @@ "name": "stdout", "output_type": "stream", "text": [ - "> he s a very important person .\n", - "= on jest bardzo ważnym człowiekiem .\n", - "< on jest bardzo ważnym człowiekiem . \n", - "\n", - "> i m beautiful .\n", - "= jestem piękny .\n", - "< jestem piękna . \n", - "\n", - "> we re quite certain of that .\n", - "= jesteśmy tego całkiem pewni .\n", - "< jesteśmy tego całkiem pewni . \n", - "\n", - "> we are all looking forward to seeing you .\n", - "= miło nam będzie ponownie się z panem spotkać .\n", - "< miło nam nam ponownie się z tobą . . \n", - "\n", - "> i m inside .\n", - "= jestem w środku .\n", - "< jestem w środku . \n", - "\n", - "> i m giving up smoking .\n", - "= rzucam palenie .\n", - "< rzuciłem palenie . \n", - "\n", - "> we re not arguing .\n", - "= nie kłócimy się .\n", - "< nie wychodzimy . \n", - "\n", - "> i m not prepared to do that yet .\n", - "= nie jestem jeszcze przygotowany żeby to zrobić .\n", - "< nie jestem jeszcze przygotowany żeby to zrobić . \n", - "\n", - "> i m a free man .\n", - "= jestem wolnym człowiekiem .\n", - "< jestem wolnym człowiekiem . \n", - "\n", - "> i m still on the clock .\n", - "= jeszcze jestem w pracy .\n", - "< wciąż jestem w domu . \n", - "\n" + "iter: 50, loss: 4.699110437711081\n", + "iter: 100, loss: 4.241607124086411\n", + "iter: 150, loss: 4.14866822333563\n", + "iter: 200, loss: 4.175457921709334\n", + "iter: 250, loss: 4.304153789429438\n", + "iter: 300, loss: 4.304717092377798\n", + "iter: 350, loss: 4.316578052808368\n", + "iter: 400, loss: 4.379952565056937\n", + "iter: 450, loss: 4.086811531929743\n", + "iter: 500, loss: 4.252370147765628\n", + "iter: 550, loss: 4.02257244164603\n", + "iter: 600, loss: 4.271288591505989\n", + "iter: 650, loss: 4.037527732379852\n", + "iter: 700, loss: 3.808401109422956\n", + "iter: 750, loss: 4.01287091629089\n", + "iter: 800, loss: 4.185342459905715\n", + "iter: 850, loss: 3.8268170519934763\n", + "iter: 900, loss: 3.9197384970074607\n", + "iter: 950, loss: 4.225208856279888\n", + "iter: 1000, loss: 4.128686094178094\n", + "iter: 1050, loss: 3.9167927505553712\n", + "iter: 1100, loss: 4.015269571940103\n", + "iter: 1150, loss: 4.168424199830918\n", + "iter: 1200, loss: 4.302581990559896\n", + "iter: 1250, loss: 3.7335942743392225\n", + "iter: 1300, loss: 3.9526881422315334\n", + "iter: 1350, loss: 3.8640213389169604\n", + "iter: 1400, loss: 4.101886716827512\n", + "iter: 1450, loss: 3.6106392740067985\n", + "iter: 1500, loss: 4.0689067233857665\n", + "iter: 1550, loss: 4.02288844353812\n", + "iter: 1600, loss: 3.572508715992883\n", + "iter: 1650, loss: 3.972692446489183\n", + "iter: 1700, loss: 3.8709554294404525\n", + "iter: 1750, loss: 3.9830204631714583\n", + "iter: 1800, loss: 3.7999766263961794\n", + "iter: 1850, loss: 3.7026816112578858\n", + "iter: 1900, loss: 3.833205360775902\n", + "iter: 1950, loss: 3.650638633606925\n", + "iter: 2000, loss: 3.748746382418133\n", + "iter: 2050, loss: 3.762590566922748\n", + "iter: 2100, loss: 3.5997376789214117\n", + "iter: 2150, loss: 3.919283335610041\n", + "iter: 2200, loss: 3.8638847478684912\n", + "iter: 2250, loss: 3.4960837801675946\n", + "iter: 2300, loss: 3.685049927688782\n", + "iter: 2350, loss: 3.5716699722759313\n", + "iter: 2400, loss: 3.8988636863874997\n", + "iter: 2450, loss: 3.752788569586617\n", + "iter: 2500, loss: 3.802307117961702\n", + "iter: 2550, loss: 3.6420236970432227\n", + "iter: 2600, loss: 3.6925315249912325\n", + "iter: 2650, loss: 3.8897219879059572\n", + "iter: 2700, loss: 3.6327851654537153\n", + "iter: 2750, loss: 3.396957855118645\n", + "iter: 2800, loss: 3.5258935768112307\n", + "iter: 2850, loss: 3.605109554866003\n", + "iter: 2900, loss: 3.533288128330594\n", + "iter: 2950, loss: 3.4583421086054\n", + "iter: 3000, loss: 3.403592811425526\n", + "iter: 3050, loss: 3.5225157889411567\n", + "iter: 3100, loss: 3.4702517202846592\n", + "iter: 3150, loss: 3.4234997159185867\n", + "iter: 3200, loss: 3.5447632862348404\n", + "iter: 3250, loss: 3.1799173504133074\n", + "iter: 3300, loss: 3.7154814013905\n", + "iter: 3350, loss: 3.4188442155444445\n", + "iter: 3400, loss: 3.6557525696527393\n", + "iter: 3450, loss: 3.52880564416401\n", + "iter: 3500, loss: 3.4842312318408295\n", + "iter: 3550, loss: 3.5256399853570115\n", + "iter: 3600, loss: 3.70226228499034\n", + "iter: 3650, loss: 3.2043497113424633\n", + "iter: 3700, loss: 3.4575287022439256\n", + "iter: 3750, loss: 3.4197605448374664\n", + "iter: 3800, loss: 3.290345760890417\n", + "iter: 3850, loss: 3.300158274309976\n", + "iter: 3900, loss: 3.3362661438139645\n", + "iter: 3950, loss: 3.4947717628630373\n", + "iter: 4000, loss: 3.5624450731353154\n", + "iter: 4050, loss: 3.438600626892514\n", + "iter: 4100, loss: 3.142976412258451\n", + "iter: 4150, loss: 3.332818130595344\n", + "iter: 4200, loss: 3.31952378733196\n", + "iter: 4250, loss: 3.5315058948123252\n", + "iter: 4300, loss: 3.6603812535074023\n", + "iter: 4350, loss: 3.35295347692853\n", + "iter: 4400, loss: 3.374297706498041\n", + "iter: 4450, loss: 3.09948105843105\n", + "iter: 4500, loss: 3.16787886763376\n", + "iter: 4550, loss: 3.455794033330583\n", + "iter: 4600, loss: 3.1263191164258926\n", + "iter: 4650, loss: 3.3723485524995\n", + "iter: 4700, loss: 3.147410953930445\n", + "iter: 4750, loss: 3.4546711923281346\n", + "iter: 4800, loss: 3.449277176016852\n", + "iter: 4850, loss: 3.197799104531606\n", + "iter: 4900, loss: 3.239384971149383\n", + "iter: 4950, loss: 3.696369633697328\n", + "iter: 5000, loss: 3.2114706332191587\n", + "iter: 5050, loss: 3.400943172795432\n", + "iter: 5100, loss: 3.298932059106372\n", + "iter: 5150, loss: 3.3697974183445907\n", + "iter: 5200, loss: 3.31293656670858\n", + "iter: 5250, loss: 3.1415378823658773\n", + "iter: 5300, loss: 3.1587839283867494\n", + "iter: 5350, loss: 3.3505903312440903\n", + "iter: 5400, loss: 3.247191356802744\n", + "iter: 5450, loss: 3.236625145200699\n", + "iter: 5500, loss: 3.19994143747148\n", + "iter: 5550, loss: 3.2911239544626265\n", + "iter: 5600, loss: 3.1855649600483122\n", + "iter: 5650, loss: 3.157031875163789\n", + "iter: 5700, loss: 3.2652817099586366\n", + "iter: 5750, loss: 3.3272896775593837\n", + "iter: 5800, loss: 3.3162626687458583\n", + "iter: 5850, loss: 3.1342987139338536\n", + "iter: 5900, loss: 3.29665669613036\n", + "iter: 5950, loss: 3.232995939807286\n", + "iter: 6000, loss: 3.0922561403758935\n", + "iter: 6050, loss: 3.1034776155835107\n", + "iter: 6100, loss: 3.1502840874081564\n", + "iter: 6150, loss: 2.915993771098909\n", + "iter: 6200, loss: 2.994096033270397\n", + "iter: 6250, loss: 3.1102042265392487\n", + "iter: 6300, loss: 2.8244728108587718\n", + "iter: 6350, loss: 3.117810124692462\n", + "iter: 6400, loss: 3.0742526639529637\n", + "iter: 6450, loss: 2.8390014954218787\n", + "iter: 6500, loss: 3.1032223067510687\n", + "iter: 6550, loss: 2.912433739840038\n", + "iter: 6600, loss: 2.9158696003490023\n", + "iter: 6650, loss: 3.2617745389030093\n", + "iter: 6700, loss: 3.295657290466248\n", + "iter: 6750, loss: 2.975928121767347\n", + "iter: 6800, loss: 3.0057779382069914\n", + "iter: 6850, loss: 2.85224422507059\n", + "iter: 6900, loss: 3.0329934195336836\n", + "iter: 6950, loss: 3.1322296761255415\n", + "iter: 7000, loss: 2.893814939192363\n", + "iter: 7050, loss: 2.934597730205173\n", + "iter: 7100, loss: 3.267660904082041\n", + "iter: 7150, loss: 3.1199153114651867\n", + "iter: 7200, loss: 2.8414319788160776\n", + "iter: 7250, loss: 3.1128779797251256\n", + "iter: 7300, loss: 3.1182169116565155\n", + "iter: 7350, loss: 3.101384938853128\n", + "iter: 7400, loss: 2.9836614183395627\n", + "iter: 7450, loss: 2.7261425285036602\n", + "iter: 7500, loss: 2.7323913456977356\n", + "iter: 7550, loss: 3.284201001443559\n", + "iter: 7600, loss: 2.9473503636405587\n", + "iter: 7650, loss: 2.861012626541986\n", + "iter: 7700, loss: 2.6726747900872003\n", + "iter: 7750, loss: 2.760957624162947\n", + "iter: 7800, loss: 2.647666095211393\n", + "iter: 7850, loss: 2.7921250426428657\n", + "iter: 7900, loss: 2.9527213778495787\n", + "iter: 7950, loss: 2.790506172891647\n", + "iter: 8000, loss: 2.8376009529431663\n", + "iter: 8050, loss: 3.0387913953690298\n", + "iter: 8100, loss: 2.908381733046637\n", + "iter: 8150, loss: 2.7374484727761104\n", + "iter: 8200, loss: 2.84610585779614\n", + "iter: 8250, loss: 2.8532650649736793\n", + "iter: 8300, loss: 2.856347685723078\n", + "iter: 8350, loss: 2.6641267998710503\n", + "iter: 8400, loss: 2.7541870554590973\n", + "iter: 8450, loss: 2.814719854824126\n", + "iter: 8500, loss: 2.6979909611694395\n", + "iter: 8550, loss: 2.577483120327904\n", + "iter: 8600, loss: 2.7884950113561415\n", + "iter: 8650, loss: 3.0236114144552317\n", + "iter: 8700, loss: 2.5850161893329924\n", + "iter: 8750, loss: 2.992550043756999\n", + "iter: 8800, loss: 2.581544444644262\n", + "iter: 8850, loss: 2.7955539315276674\n", + "iter: 8900, loss: 2.583812619288763\n", + "iter: 8950, loss: 2.6446591711649825\n", + "iter: 9000, loss: 2.577330000854674\n", + "iter: 9050, loss: 2.4657566853288615\n", + "iter: 9100, loss: 2.800543680138058\n", + "iter: 9150, loss: 2.8939966171544707\n", + "iter: 9200, loss: 2.484702325525738\n", + "iter: 9250, loss: 2.9708456475469807\n", + "iter: 9300, loss: 2.8829837035148858\n", + "iter: 9350, loss: 2.451061187414896\n", + "iter: 9400, loss: 3.144906068983533\n", + "iter: 9450, loss: 2.4527184899950787\n", + "iter: 9500, loss: 2.665944624832698\n", + "iter: 9550, loss: 2.5468089370273406\n", + "iter: 9600, loss: 2.51169423552165\n", + "iter: 9650, loss: 2.916568091210864\n", + "iter: 9700, loss: 2.8149766059640853\n", + "iter: 9750, loss: 2.6544064010362773\n", + "iter: 9800, loss: 2.300161985658464\n", + "iter: 9850, loss: 2.5070087575912483\n", + "iter: 9900, loss: 2.617770311056621\n", + "iter: 9950, loss: 2.756971993983738\n", + "iter: 10000, loss: 2.629019902910504\n" ] } ], "source": [ - "evaluateRandomly(encoder1, attn_decoder1)" + "trainIters(encoder1, attn_decoder1, 10_000, print_every=50)" ] }, { @@ -1191,26 +735,55 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "['i m ok .', 'ze mną wszystko w porządku .']" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "> we re both in the same class .\n", + "= jesteśmy oboje w tej samej klasie .\n", + "< jesteśmy w w . \n", + "\n", + "> you re telling lies again .\n", + "= znowu kłamiesz .\n", + "< znowu mi . \n", + "\n", + "> i m glad you re back .\n", + "= cieszę się że wróciliście .\n", + "< cieszę się że . . \n", + "\n", + "> i m not going to have any fun .\n", + "= nie będę się bawił .\n", + "< nie wolno się . . \n", + "\n", + "> i m practising judo .\n", + "= trenuję dżudo .\n", + "< jestem . . \n", + "\n", + "> you re wasting our time .\n", + "= marnujesz nasz czas .\n", + "< masz ci na . . \n", + "\n", + "> he is anxious about her health .\n", + "= on martwi się o jej zdrowie .\n", + "< jest bardzo z niej . . \n", + "\n", + "> you re introverted .\n", + "= jesteś zamknięty w sobie .\n", + "< masz . \n", + "\n", + "> she s correct for sure .\n", + "= ona z pewnością ma rację .\n", + "< ona jest z z . \n", + "\n", + "> they re armed .\n", + "= są uzbrojeni .\n", + "< są . . \n", + "\n" + ] } ], "source": [ - "pairs[0]" + "evaluateRandomly(encoder1, attn_decoder1)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {