{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "8Lrtmwas7Y6x" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n", "\n" ] } ], "source": [ "import numpy as np\n", "import keras\n", "import os\n", "from pathlib import Path\n", "import requests\n", "import zipfile" ] }, { "cell_type": "markdown", "metadata": { "id": "BkwAAv7b7i9u" }, "source": [ "### Download the data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "oiE7pYD18jr-" }, "outputs": [], "source": [ "url = \"https://www.manythings.org/anki/pol-eng.zip\"\n", "local_filename = url.split('/')[-1]\n", "dirpath = Path().absolute()\n", "\n", "headers = {\n", " 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'\n", "}\n", "\n", "response = requests.get(url, headers=headers)\n", "if response.status_code == 200:\n", " with open(local_filename, 'wb') as f:\n", " f.write(response.content)\n", " with zipfile.ZipFile(local_filename, 'r') as zip_ref:\n", " zip_ref.extractall(dirpath)" ] }, { "cell_type": "markdown", "metadata": { "id": "he25FWpj-5Lk" }, "source": [ "### Configuration" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "X8pq11jj-90z" }, "outputs": [], "source": [ "batch_size = 64 # Batch size for training.\n", "epochs = 200 # Number of epochs to train for.\n", "latent_dim = 256 # Latent dimensionality of the encoding space.\n", "num_samples = 10000 # Number of samples to train on.\n", "# Path to the data txt file on disk.\n", "data_path = os.path.join(dirpath, \"pol.txt\")" ] }, { "cell_type": "markdown", "metadata": { "id": "P9Ivwa14_Er2" }, "source": [ "### Prepare the data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eZbcKeoF-3Gv", "outputId": "a2e0bc6e-fd9b-40c9-8757-3da4436ddfdb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of samples: 10000\n", "Number of unique input tokens: 70\n", "Number of unique output tokens: 87\n", "Max sequence length for inputs: 19\n", "Max sequence length for outputs: 52\n" ] } ], "source": [ "# Vectorize the data.\n", "input_texts = []\n", "target_texts = []\n", "input_characters = set()\n", "target_characters = set()\n", "with open(data_path, \"r\", encoding=\"utf-8\") as f:\n", " lines = f.read().split(\"\\n\")\n", "for line in lines[: min(num_samples, len(lines) - 1)]:\n", " input_text, target_text, _ = line.split(\"\\t\")\n", " # We use \"tab\" as the \"start sequence\" character\n", " # for the targets, and \"\\n\" as \"end sequence\" character.\n", " target_text = \"\\t\" + target_text + \"\\n\"\n", " input_texts.append(input_text)\n", " target_texts.append(target_text)\n", " for char in input_text:\n", " if char not in input_characters:\n", " input_characters.add(char)\n", " for char in target_text:\n", " if char not in target_characters:\n", " target_characters.add(char)\n", "\n", "input_characters = sorted(list(input_characters))\n", "target_characters = sorted(list(target_characters))\n", "num_encoder_tokens = len(input_characters)\n", "num_decoder_tokens = len(target_characters)\n", "max_encoder_seq_length = max([len(txt) for txt in input_texts])\n", "max_decoder_seq_length = max([len(txt) for txt in target_texts])\n", "\n", "print(\"Number of samples:\", len(input_texts))\n", "print(\"Number of unique input tokens:\", num_encoder_tokens)\n", "print(\"Number of unique output tokens:\", num_decoder_tokens)\n", "print(\"Max sequence length for inputs:\", max_encoder_seq_length)\n", "print(\"Max sequence length for outputs:\", max_decoder_seq_length)\n", "\n", "input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])\n", "target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])\n", "\n", "encoder_input_data = np.zeros(\n", " (len(input_texts), max_encoder_seq_length, num_encoder_tokens),\n", " dtype=\"float32\",\n", ")\n", "decoder_input_data = np.zeros(\n", " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", " dtype=\"float32\",\n", ")\n", "decoder_target_data = np.zeros(\n", " (len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n", " dtype=\"float32\",\n", ")\n", "\n", "for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n", " for t, char in enumerate(input_text):\n", " encoder_input_data[i, t, input_token_index[char]] = 1.0\n", " encoder_input_data[i, t + 1 :, input_token_index[\" \"]] = 1.0\n", " for t, char in enumerate(target_text):\n", " # decoder_target_data is ahead of decoder_input_data by one timestep\n", " decoder_input_data[i, t, target_token_index[char]] = 1.0\n", " if t > 0:\n", " # decoder_target_data will be ahead by one timestep\n", " # and will not include the start character.\n", " decoder_target_data[i, t - 1, target_token_index[char]] = 1.0\n", " decoder_input_data[i, t + 1 :, target_token_index[\" \"]] = 1.0\n", " decoder_target_data[i, t:, target_token_index[\" \"]] = 1.0" ] }, { "cell_type": "markdown", "metadata": { "id": "xuxzuM6dBpil" }, "source": [ "### Build the model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "SyABMORlBtGX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\backend.py:1398: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n", "\n" ] } ], "source": [ "# Define an input sequence and process it.\n", "encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))\n", "encoder = keras.layers.LSTM(latent_dim, return_state=True)\n", "encoder_outputs, state_h, state_c = encoder(encoder_inputs)\n", "\n", "# We discard `encoder_outputs` and only keep the states.\n", "encoder_states = [state_h, state_c]\n", "\n", "# Set up the decoder, using `encoder_states` as initial state.\n", "decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))\n", "\n", "# We set up our decoder to return full output sequences,\n", "# and to return internal states as well. We don't use the\n", "# return states in the training model, but we will use them in inference.\n", "decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)\n", "decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)\n", "decoder_dense = keras.layers.Dense(num_decoder_tokens, activation=\"softmax\")\n", "decoder_outputs = decoder_dense(decoder_outputs)\n", "\n", "# Define the model that will turn\n", "# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`\n", "model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "4Kjv669-BxFF" }, "source": [ "### Train the model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rsebz819Bzfu", "outputId": "0996aac7-eeda-46ee-e80e-1fb55693976f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\optimizers\\__init__.py:309: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", "\n", "Epoch 1/200\n", "WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\utils\\tf_utils.py:492: The name tf.ragged.RaggedTensorValue is deprecated. Please use tf.compat.v1.ragged.RaggedTensorValue instead.\n", "\n", "WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\engine\\base_layer_utils.py:384: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n", "\n", "125/125 [==============================] - 20s 121ms/step - loss: 1.3745 - accuracy: 0.6987 - val_loss: 1.4872 - val_accuracy: 0.6530\n", "Epoch 2/200\n", "125/125 [==============================] - 12s 94ms/step - loss: 1.0822 - accuracy: 0.7185 - val_loss: 1.2406 - val_accuracy: 0.6688\n", "Epoch 3/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.9703 - accuracy: 0.7424 - val_loss: 1.1089 - val_accuracy: 0.7065\n", "Epoch 4/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.8733 - accuracy: 0.7671 - val_loss: 1.0149 - val_accuracy: 0.7165\n", "Epoch 5/200\n", "125/125 [==============================] - 18s 145ms/step - loss: 0.7839 - accuracy: 0.7835 - val_loss: 0.9353 - val_accuracy: 0.7399\n", "Epoch 6/200\n", "125/125 [==============================] - 16s 129ms/step - loss: 0.7367 - accuracy: 0.7946 - val_loss: 0.8968 - val_accuracy: 0.7464\n", "Epoch 7/200\n", "125/125 [==============================] - 15s 121ms/step - loss: 0.7030 - accuracy: 0.8006 - val_loss: 0.8677 - val_accuracy: 0.7572\n", "Epoch 8/200\n", "125/125 [==============================] - 14s 112ms/step - loss: 0.6812 - accuracy: 0.8056 - val_loss: 0.8536 - val_accuracy: 0.7566\n", "Epoch 9/200\n", "125/125 [==============================] - 14s 114ms/step - loss: 0.6642 - accuracy: 0.8100 - val_loss: 0.8439 - val_accuracy: 0.7569\n", "Epoch 10/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.6492 - accuracy: 0.8133 - val_loss: 0.8181 - val_accuracy: 0.7648\n", "Epoch 11/200\n", "125/125 [==============================] - 17s 136ms/step - loss: 0.6369 - accuracy: 0.8162 - val_loss: 0.8118 - val_accuracy: 0.7664\n", "Epoch 12/200\n", "125/125 [==============================] - 14s 115ms/step - loss: 0.6252 - accuracy: 0.8193 - val_loss: 0.7932 - val_accuracy: 0.7709\n", "Epoch 13/200\n", "125/125 [==============================] - 15s 124ms/step - loss: 0.6147 - accuracy: 0.8216 - val_loss: 0.7842 - val_accuracy: 0.7722\n", "Epoch 14/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.6040 - accuracy: 0.8244 - val_loss: 0.7801 - val_accuracy: 0.7749\n", "Epoch 15/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.5944 - accuracy: 0.8273 - val_loss: 0.7741 - val_accuracy: 0.7758\n", "Epoch 16/200\n", "125/125 [==============================] - 13s 108ms/step - loss: 0.5841 - accuracy: 0.8302 - val_loss: 0.7631 - val_accuracy: 0.7788\n", "Epoch 17/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.5739 - accuracy: 0.8337 - val_loss: 0.7413 - val_accuracy: 0.7838\n", "Epoch 18/200\n", "125/125 [==============================] - 14s 108ms/step - loss: 0.5627 - accuracy: 0.8363 - val_loss: 0.7355 - val_accuracy: 0.7862\n", "Epoch 19/200\n", "125/125 [==============================] - 14s 110ms/step - loss: 0.5532 - accuracy: 0.8390 - val_loss: 0.7239 - val_accuracy: 0.7899\n", "Epoch 20/200\n", "125/125 [==============================] - 15s 121ms/step - loss: 0.5441 - accuracy: 0.8418 - val_loss: 0.7179 - val_accuracy: 0.7916\n", "Epoch 21/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.5360 - accuracy: 0.8437 - val_loss: 0.7164 - val_accuracy: 0.7915\n", "Epoch 22/200\n", "125/125 [==============================] - 15s 123ms/step - loss: 0.5270 - accuracy: 0.8469 - val_loss: 0.7067 - val_accuracy: 0.7947\n", "Epoch 23/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.5186 - accuracy: 0.8493 - val_loss: 0.7051 - val_accuracy: 0.7962\n", "Epoch 24/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.5107 - accuracy: 0.8519 - val_loss: 0.6904 - val_accuracy: 0.7994\n", "Epoch 25/200\n", "125/125 [==============================] - 18s 142ms/step - loss: 0.5022 - accuracy: 0.8542 - val_loss: 0.6898 - val_accuracy: 0.7998\n", "Epoch 26/200\n", "125/125 [==============================] - 14s 111ms/step - loss: 0.4945 - accuracy: 0.8565 - val_loss: 0.6775 - val_accuracy: 0.8041\n", "Epoch 27/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.4875 - accuracy: 0.8584 - val_loss: 0.6717 - val_accuracy: 0.8064\n", "Epoch 28/200\n", "125/125 [==============================] - 14s 114ms/step - loss: 0.4806 - accuracy: 0.8608 - val_loss: 0.6649 - val_accuracy: 0.8077\n", "Epoch 29/200\n", "125/125 [==============================] - 15s 122ms/step - loss: 0.4728 - accuracy: 0.8627 - val_loss: 0.6616 - val_accuracy: 0.8085\n", "Epoch 30/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.4659 - accuracy: 0.8649 - val_loss: 0.6574 - val_accuracy: 0.8094\n", "Epoch 31/200\n", "125/125 [==============================] - 15s 120ms/step - loss: 0.4596 - accuracy: 0.8672 - val_loss: 0.6557 - val_accuracy: 0.8113\n", "Epoch 32/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.4527 - accuracy: 0.8690 - val_loss: 0.6500 - val_accuracy: 0.8124\n", "Epoch 33/200\n", "125/125 [==============================] - 15s 119ms/step - loss: 0.4464 - accuracy: 0.8706 - val_loss: 0.6459 - val_accuracy: 0.8141\n", "Epoch 34/200\n", "125/125 [==============================] - 14s 111ms/step - loss: 0.4403 - accuracy: 0.8727 - val_loss: 0.6470 - val_accuracy: 0.8141\n", "Epoch 35/200\n", "125/125 [==============================] - 13s 108ms/step - loss: 0.4343 - accuracy: 0.8741 - val_loss: 0.6338 - val_accuracy: 0.8175\n", "Epoch 36/200\n", "125/125 [==============================] - 14s 108ms/step - loss: 0.4282 - accuracy: 0.8757 - val_loss: 0.6314 - val_accuracy: 0.8190\n", "Epoch 37/200\n", "125/125 [==============================] - 14s 108ms/step - loss: 0.4223 - accuracy: 0.8776 - val_loss: 0.6343 - val_accuracy: 0.8183\n", "Epoch 38/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.4165 - accuracy: 0.8792 - val_loss: 0.6327 - val_accuracy: 0.8183\n", "Epoch 39/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.4107 - accuracy: 0.8812 - val_loss: 0.6277 - val_accuracy: 0.8211\n", "Epoch 40/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.4057 - accuracy: 0.8828 - val_loss: 0.6275 - val_accuracy: 0.8206\n", "Epoch 41/200\n", "125/125 [==============================] - 14s 115ms/step - loss: 0.3998 - accuracy: 0.8842 - val_loss: 0.6186 - val_accuracy: 0.8234\n", "Epoch 42/200\n", "125/125 [==============================] - 14s 111ms/step - loss: 0.3948 - accuracy: 0.8852 - val_loss: 0.6186 - val_accuracy: 0.8238\n", "Epoch 43/200\n", "125/125 [==============================] - 14s 112ms/step - loss: 0.3894 - accuracy: 0.8873 - val_loss: 0.6225 - val_accuracy: 0.8232\n", "Epoch 44/200\n", "125/125 [==============================] - 15s 119ms/step - loss: 0.3842 - accuracy: 0.8887 - val_loss: 0.6209 - val_accuracy: 0.8237\n", "Epoch 45/200\n", "125/125 [==============================] - 15s 121ms/step - loss: 0.3797 - accuracy: 0.8898 - val_loss: 0.6139 - val_accuracy: 0.8262\n", "Epoch 46/200\n", "125/125 [==============================] - 15s 124ms/step - loss: 0.3745 - accuracy: 0.8918 - val_loss: 0.6112 - val_accuracy: 0.8271\n", "Epoch 47/200\n", "125/125 [==============================] - 15s 118ms/step - loss: 0.3695 - accuracy: 0.8931 - val_loss: 0.6122 - val_accuracy: 0.8265\n", "Epoch 48/200\n", "125/125 [==============================] - 15s 122ms/step - loss: 0.3641 - accuracy: 0.8943 - val_loss: 0.6184 - val_accuracy: 0.8256\n", "Epoch 49/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.3599 - accuracy: 0.8960 - val_loss: 0.6116 - val_accuracy: 0.8274\n", "Epoch 50/200\n", "125/125 [==============================] - 15s 121ms/step - loss: 0.3552 - accuracy: 0.8972 - val_loss: 0.6109 - val_accuracy: 0.8279\n", "Epoch 51/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.3505 - accuracy: 0.8986 - val_loss: 0.6106 - val_accuracy: 0.8289\n", "Epoch 52/200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "125/125 [==============================] - 15s 122ms/step - loss: 0.3459 - accuracy: 0.9002 - val_loss: 0.6133 - val_accuracy: 0.8282\n", "Epoch 53/200\n", "125/125 [==============================] - 14s 115ms/step - loss: 0.3415 - accuracy: 0.9015 - val_loss: 0.6171 - val_accuracy: 0.8282\n", "Epoch 54/200\n", "125/125 [==============================] - 15s 119ms/step - loss: 0.3375 - accuracy: 0.9020 - val_loss: 0.6109 - val_accuracy: 0.8300\n", "Epoch 55/200\n", "125/125 [==============================] - 18s 143ms/step - loss: 0.3326 - accuracy: 0.9039 - val_loss: 0.6103 - val_accuracy: 0.8302\n", "Epoch 56/200\n", "125/125 [==============================] - 16s 130ms/step - loss: 0.3283 - accuracy: 0.9047 - val_loss: 0.6132 - val_accuracy: 0.8305\n", "Epoch 57/200\n", "125/125 [==============================] - 16s 132ms/step - loss: 0.3243 - accuracy: 0.9062 - val_loss: 0.6136 - val_accuracy: 0.8303\n", "Epoch 58/200\n", "125/125 [==============================] - 15s 122ms/step - loss: 0.3195 - accuracy: 0.9077 - val_loss: 0.6206 - val_accuracy: 0.8295\n", "Epoch 59/200\n", "125/125 [==============================] - 16s 131ms/step - loss: 0.3152 - accuracy: 0.9093 - val_loss: 0.6196 - val_accuracy: 0.8305\n", "Epoch 60/200\n", "125/125 [==============================] - 16s 126ms/step - loss: 0.3114 - accuracy: 0.9098 - val_loss: 0.6217 - val_accuracy: 0.8291\n", "Epoch 61/200\n", "125/125 [==============================] - 16s 127ms/step - loss: 0.3073 - accuracy: 0.9113 - val_loss: 0.6295 - val_accuracy: 0.8282\n", "Epoch 62/200\n", "125/125 [==============================] - 16s 125ms/step - loss: 0.3025 - accuracy: 0.9124 - val_loss: 0.6231 - val_accuracy: 0.8292\n", "Epoch 63/200\n", "125/125 [==============================] - 16s 127ms/step - loss: 0.2990 - accuracy: 0.9134 - val_loss: 0.6255 - val_accuracy: 0.8305\n", "Epoch 64/200\n", "125/125 [==============================] - 16s 129ms/step - loss: 0.2956 - accuracy: 0.9147 - val_loss: 0.6279 - val_accuracy: 0.8302\n", "Epoch 65/200\n", "125/125 [==============================] - 16s 125ms/step - loss: 0.2910 - accuracy: 0.9158 - val_loss: 0.6302 - val_accuracy: 0.8302\n", "Epoch 66/200\n", "125/125 [==============================] - 17s 132ms/step - loss: 0.2868 - accuracy: 0.9170 - val_loss: 0.6319 - val_accuracy: 0.8300\n", "Epoch 67/200\n", "125/125 [==============================] - 16s 127ms/step - loss: 0.2833 - accuracy: 0.9179 - val_loss: 0.6364 - val_accuracy: 0.8298\n", "Epoch 68/200\n", "125/125 [==============================] - 16s 130ms/step - loss: 0.2795 - accuracy: 0.9189 - val_loss: 0.6327 - val_accuracy: 0.8315\n", "Epoch 69/200\n", "125/125 [==============================] - 16s 131ms/step - loss: 0.2756 - accuracy: 0.9205 - val_loss: 0.6398 - val_accuracy: 0.8302\n", "Epoch 70/200\n", "125/125 [==============================] - 17s 138ms/step - loss: 0.2719 - accuracy: 0.9216 - val_loss: 0.6430 - val_accuracy: 0.8306\n", "Epoch 71/200\n", "125/125 [==============================] - 15s 123ms/step - loss: 0.2685 - accuracy: 0.9225 - val_loss: 0.6439 - val_accuracy: 0.8295\n", "Epoch 72/200\n", "125/125 [==============================] - 16s 126ms/step - loss: 0.2640 - accuracy: 0.9238 - val_loss: 0.6478 - val_accuracy: 0.8296\n", "Epoch 73/200\n", "125/125 [==============================] - 13s 104ms/step - loss: 0.2609 - accuracy: 0.9249 - val_loss: 0.6485 - val_accuracy: 0.8305\n", "Epoch 74/200\n", "125/125 [==============================] - 16s 125ms/step - loss: 0.2568 - accuracy: 0.9261 - val_loss: 0.6519 - val_accuracy: 0.8311\n", "Epoch 75/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.2542 - accuracy: 0.9263 - val_loss: 0.6572 - val_accuracy: 0.8287\n", "Epoch 76/200\n", "125/125 [==============================] - 16s 127ms/step - loss: 0.2502 - accuracy: 0.9282 - val_loss: 0.6625 - val_accuracy: 0.8289\n", "Epoch 77/200\n", "125/125 [==============================] - 14s 110ms/step - loss: 0.2468 - accuracy: 0.9289 - val_loss: 0.6624 - val_accuracy: 0.8300\n", "Epoch 78/200\n", "125/125 [==============================] - 13s 106ms/step - loss: 0.2435 - accuracy: 0.9298 - val_loss: 0.6642 - val_accuracy: 0.8301\n", "Epoch 79/200\n", "125/125 [==============================] - 13s 105ms/step - loss: 0.2401 - accuracy: 0.9307 - val_loss: 0.6743 - val_accuracy: 0.8290\n", "Epoch 80/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.2369 - accuracy: 0.9316 - val_loss: 0.6680 - val_accuracy: 0.8301\n", "Epoch 81/200\n", "125/125 [==============================] - 13s 104ms/step - loss: 0.2333 - accuracy: 0.9326 - val_loss: 0.6743 - val_accuracy: 0.8286\n", "Epoch 82/200\n", "125/125 [==============================] - 15s 119ms/step - loss: 0.2300 - accuracy: 0.9336 - val_loss: 0.6763 - val_accuracy: 0.8292\n", "Epoch 83/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.2269 - accuracy: 0.9345 - val_loss: 0.6805 - val_accuracy: 0.8293\n", "Epoch 84/200\n", "125/125 [==============================] - 13s 107ms/step - loss: 0.2238 - accuracy: 0.9357 - val_loss: 0.6812 - val_accuracy: 0.8292\n", "Epoch 85/200\n", "125/125 [==============================] - 15s 118ms/step - loss: 0.2201 - accuracy: 0.9363 - val_loss: 0.6801 - val_accuracy: 0.8305\n", "Epoch 86/200\n", "125/125 [==============================] - 14s 110ms/step - loss: 0.2176 - accuracy: 0.9370 - val_loss: 0.6918 - val_accuracy: 0.8298\n", "Epoch 87/200\n", "125/125 [==============================] - 13s 106ms/step - loss: 0.2142 - accuracy: 0.9384 - val_loss: 0.7037 - val_accuracy: 0.8277\n", "Epoch 88/200\n", "125/125 [==============================] - 13s 104ms/step - loss: 0.2110 - accuracy: 0.9394 - val_loss: 0.7012 - val_accuracy: 0.8290\n", "Epoch 89/200\n", "125/125 [==============================] - 14s 111ms/step - loss: 0.2078 - accuracy: 0.9402 - val_loss: 0.7013 - val_accuracy: 0.8297\n", "Epoch 90/200\n", "125/125 [==============================] - 17s 137ms/step - loss: 0.2052 - accuracy: 0.9409 - val_loss: 0.7069 - val_accuracy: 0.8285\n", "Epoch 91/200\n", "125/125 [==============================] - 19s 150ms/step - loss: 0.2021 - accuracy: 0.9419 - val_loss: 0.7072 - val_accuracy: 0.8285\n", "Epoch 92/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.1989 - accuracy: 0.9425 - val_loss: 0.7139 - val_accuracy: 0.8289\n", "Epoch 93/200\n", "125/125 [==============================] - 14s 113ms/step - loss: 0.1966 - accuracy: 0.9432 - val_loss: 0.7173 - val_accuracy: 0.8288\n", "Epoch 94/200\n", "125/125 [==============================] - 21s 167ms/step - loss: 0.1940 - accuracy: 0.9440 - val_loss: 0.7222 - val_accuracy: 0.8288\n", "Epoch 95/200\n", "125/125 [==============================] - 18s 146ms/step - loss: 0.1911 - accuracy: 0.9449 - val_loss: 0.7298 - val_accuracy: 0.8283\n", "Epoch 96/200\n", "125/125 [==============================] - 18s 146ms/step - loss: 0.1882 - accuracy: 0.9455 - val_loss: 0.7287 - val_accuracy: 0.8279\n", "Epoch 97/200\n", "125/125 [==============================] - 18s 146ms/step - loss: 0.1856 - accuracy: 0.9463 - val_loss: 0.7367 - val_accuracy: 0.8277\n", "Epoch 98/200\n", "125/125 [==============================] - 19s 149ms/step - loss: 0.1827 - accuracy: 0.9476 - val_loss: 0.7370 - val_accuracy: 0.8283\n", "Epoch 99/200\n", "125/125 [==============================] - 19s 149ms/step - loss: 0.1802 - accuracy: 0.9484 - val_loss: 0.7434 - val_accuracy: 0.8282\n", "Epoch 100/200\n", "125/125 [==============================] - 18s 144ms/step - loss: 0.1778 - accuracy: 0.9486 - val_loss: 0.7473 - val_accuracy: 0.8274\n", "Epoch 101/200\n", "125/125 [==============================] - 18s 148ms/step - loss: 0.1751 - accuracy: 0.9497 - val_loss: 0.7547 - val_accuracy: 0.8266\n", "Epoch 102/200\n", "125/125 [==============================] - 18s 140ms/step - loss: 0.1723 - accuracy: 0.9507 - val_loss: 0.7575 - val_accuracy: 0.8266\n", "Epoch 103/200\n", "125/125 [==============================] - 17s 135ms/step - loss: 0.1701 - accuracy: 0.9511 - val_loss: 0.7621 - val_accuracy: 0.8269\n", "Epoch 104/200\n", "125/125 [==============================] - 18s 143ms/step - loss: 0.1677 - accuracy: 0.9519 - val_loss: 0.7708 - val_accuracy: 0.8263\n", "Epoch 105/200\n", "125/125 [==============================] - 18s 145ms/step - loss: 0.1659 - accuracy: 0.9522 - val_loss: 0.7669 - val_accuracy: 0.8277\n", "Epoch 106/200\n", "125/125 [==============================] - 20s 162ms/step - loss: 0.1628 - accuracy: 0.9533 - val_loss: 0.7766 - val_accuracy: 0.8267\n", "Epoch 107/200\n", "125/125 [==============================] - 18s 143ms/step - loss: 0.1606 - accuracy: 0.9538 - val_loss: 0.7791 - val_accuracy: 0.8275\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 108/200\n", "125/125 [==============================] - 18s 145ms/step - loss: 0.1583 - accuracy: 0.9548 - val_loss: 0.7820 - val_accuracy: 0.8268\n", "Epoch 109/200\n", "125/125 [==============================] - 18s 143ms/step - loss: 0.1564 - accuracy: 0.9550 - val_loss: 0.7863 - val_accuracy: 0.8266\n", "Epoch 110/200\n", "125/125 [==============================] - 19s 149ms/step - loss: 0.1544 - accuracy: 0.9558 - val_loss: 0.7927 - val_accuracy: 0.8259\n", "Epoch 111/200\n", "125/125 [==============================] - 18s 146ms/step - loss: 0.1517 - accuracy: 0.9566 - val_loss: 0.7972 - val_accuracy: 0.8257\n", "Epoch 112/200\n", "125/125 [==============================] - 19s 150ms/step - loss: 0.1490 - accuracy: 0.9573 - val_loss: 0.8031 - val_accuracy: 0.8262\n", "Epoch 113/200\n", "125/125 [==============================] - 21s 172ms/step - loss: 0.1478 - accuracy: 0.9575 - val_loss: 0.8029 - val_accuracy: 0.8261\n", "Epoch 114/200\n", "125/125 [==============================] - 27s 215ms/step - loss: 0.1459 - accuracy: 0.9580 - val_loss: 0.8135 - val_accuracy: 0.8256\n", "Epoch 115/200\n", "125/125 [==============================] - 19s 154ms/step - loss: 0.1437 - accuracy: 0.9585 - val_loss: 0.8143 - val_accuracy: 0.8267\n", "Epoch 116/200\n", "125/125 [==============================] - 14s 113ms/step - loss: 0.1415 - accuracy: 0.9591 - val_loss: 0.8153 - val_accuracy: 0.8260\n", "Epoch 117/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.1394 - accuracy: 0.9601 - val_loss: 0.8186 - val_accuracy: 0.8257\n", "Epoch 118/200\n", "125/125 [==============================] - 14s 108ms/step - loss: 0.1377 - accuracy: 0.9602 - val_loss: 0.8334 - val_accuracy: 0.8253\n", "Epoch 119/200\n", "125/125 [==============================] - 14s 109ms/step - loss: 0.1355 - accuracy: 0.9612 - val_loss: 0.8374 - val_accuracy: 0.8251\n", "Epoch 120/200\n", "125/125 [==============================] - 14s 110ms/step - loss: 0.1335 - accuracy: 0.9619 - val_loss: 0.8370 - val_accuracy: 0.8250\n", "Epoch 121/200\n", "125/125 [==============================] - 16s 132ms/step - loss: 0.1317 - accuracy: 0.9620 - val_loss: 0.8438 - val_accuracy: 0.8245\n", "Epoch 122/200\n", "125/125 [==============================] - 17s 135ms/step - loss: 0.1302 - accuracy: 0.9626 - val_loss: 0.8457 - val_accuracy: 0.8246\n", "Epoch 123/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.1287 - accuracy: 0.9628 - val_loss: 0.8478 - val_accuracy: 0.8246\n", "Epoch 124/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.1265 - accuracy: 0.9636 - val_loss: 0.8581 - val_accuracy: 0.8247\n", "Epoch 125/200\n", "125/125 [==============================] - 12s 97ms/step - loss: 0.1247 - accuracy: 0.9642 - val_loss: 0.8579 - val_accuracy: 0.8244\n", "Epoch 126/200\n", "125/125 [==============================] - 14s 112ms/step - loss: 0.1237 - accuracy: 0.9643 - val_loss: 0.8601 - val_accuracy: 0.8256\n", "Epoch 127/200\n", "125/125 [==============================] - 17s 136ms/step - loss: 0.1226 - accuracy: 0.9646 - val_loss: 0.8719 - val_accuracy: 0.8234\n", "Epoch 128/200\n", "125/125 [==============================] - 19s 149ms/step - loss: 0.1203 - accuracy: 0.9655 - val_loss: 0.8740 - val_accuracy: 0.8233\n", "Epoch 129/200\n", "125/125 [==============================] - 18s 142ms/step - loss: 0.1183 - accuracy: 0.9662 - val_loss: 0.8812 - val_accuracy: 0.8233\n", "Epoch 130/200\n", "125/125 [==============================] - 15s 122ms/step - loss: 0.1162 - accuracy: 0.9669 - val_loss: 0.8817 - val_accuracy: 0.8233\n", "Epoch 131/200\n", "125/125 [==============================] - 16s 131ms/step - loss: 0.1148 - accuracy: 0.9672 - val_loss: 0.8795 - val_accuracy: 0.8242\n", "Epoch 132/200\n", "125/125 [==============================] - 17s 139ms/step - loss: 0.1132 - accuracy: 0.9677 - val_loss: 0.8941 - val_accuracy: 0.8234\n", "Epoch 133/200\n", "125/125 [==============================] - 14s 116ms/step - loss: 0.1118 - accuracy: 0.9679 - val_loss: 0.8929 - val_accuracy: 0.8233\n", "Epoch 134/200\n", "125/125 [==============================] - 16s 126ms/step - loss: 0.1101 - accuracy: 0.9686 - val_loss: 0.8995 - val_accuracy: 0.8236\n", "Epoch 135/200\n", "125/125 [==============================] - 16s 126ms/step - loss: 0.1089 - accuracy: 0.9688 - val_loss: 0.9013 - val_accuracy: 0.8231\n", "Epoch 136/200\n", "125/125 [==============================] - 15s 116ms/step - loss: 0.1078 - accuracy: 0.9690 - val_loss: 0.9056 - val_accuracy: 0.8225\n", "Epoch 137/200\n", "125/125 [==============================] - 13s 106ms/step - loss: 0.1061 - accuracy: 0.9697 - val_loss: 0.9144 - val_accuracy: 0.8223\n", "Epoch 138/200\n", "125/125 [==============================] - 16s 128ms/step - loss: 0.1040 - accuracy: 0.9706 - val_loss: 0.9144 - val_accuracy: 0.8227\n", "Epoch 139/200\n", "125/125 [==============================] - 15s 120ms/step - loss: 0.1038 - accuracy: 0.9703 - val_loss: 0.9172 - val_accuracy: 0.8239\n", "Epoch 140/200\n", "125/125 [==============================] - 15s 117ms/step - loss: 0.1018 - accuracy: 0.9712 - val_loss: 0.9252 - val_accuracy: 0.8225\n", "Epoch 141/200\n", "125/125 [==============================] - 15s 116ms/step - loss: 0.1005 - accuracy: 0.9714 - val_loss: 0.9298 - val_accuracy: 0.8234\n", "Epoch 142/200\n", "125/125 [==============================] - 13s 105ms/step - loss: 0.0991 - accuracy: 0.9717 - val_loss: 0.9344 - val_accuracy: 0.8219\n", "Epoch 143/200\n", "125/125 [==============================] - 13s 102ms/step - loss: 0.0972 - accuracy: 0.9724 - val_loss: 0.9379 - val_accuracy: 0.8222\n", "Epoch 144/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0964 - accuracy: 0.9727 - val_loss: 0.9478 - val_accuracy: 0.8226\n", "Epoch 145/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0961 - accuracy: 0.9725 - val_loss: 0.9507 - val_accuracy: 0.8222\n", "Epoch 146/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0942 - accuracy: 0.9731 - val_loss: 0.9532 - val_accuracy: 0.8219\n", "Epoch 147/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0934 - accuracy: 0.9736 - val_loss: 0.9513 - val_accuracy: 0.8217\n", "Epoch 148/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0917 - accuracy: 0.9740 - val_loss: 0.9572 - val_accuracy: 0.8225\n", "Epoch 149/200\n", "125/125 [==============================] - 13s 104ms/step - loss: 0.0900 - accuracy: 0.9744 - val_loss: 0.9587 - val_accuracy: 0.8223\n", "Epoch 150/200\n", "125/125 [==============================] - 13s 104ms/step - loss: 0.0892 - accuracy: 0.9749 - val_loss: 0.9657 - val_accuracy: 0.8225\n", "Epoch 151/200\n", "125/125 [==============================] - 13s 102ms/step - loss: 0.0876 - accuracy: 0.9750 - val_loss: 0.9735 - val_accuracy: 0.8203\n", "Epoch 152/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0866 - accuracy: 0.9754 - val_loss: 0.9647 - val_accuracy: 0.8231\n", "Epoch 153/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0852 - accuracy: 0.9759 - val_loss: 0.9787 - val_accuracy: 0.8225\n", "Epoch 154/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0841 - accuracy: 0.9761 - val_loss: 0.9868 - val_accuracy: 0.8217\n", "Epoch 155/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0828 - accuracy: 0.9765 - val_loss: 0.9908 - val_accuracy: 0.8215\n", "Epoch 156/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0820 - accuracy: 0.9769 - val_loss: 0.9904 - val_accuracy: 0.8228\n", "Epoch 157/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0813 - accuracy: 0.9769 - val_loss: 0.9992 - val_accuracy: 0.8212\n", "Epoch 158/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0802 - accuracy: 0.9772 - val_loss: 1.0037 - val_accuracy: 0.8204\n", "Epoch 159/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0792 - accuracy: 0.9776 - val_loss: 1.0000 - val_accuracy: 0.8209\n", "Epoch 160/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0789 - accuracy: 0.9777 - val_loss: 1.0039 - val_accuracy: 0.8212\n", "Epoch 161/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0774 - accuracy: 0.9780 - val_loss: 1.0144 - val_accuracy: 0.8210\n", "Epoch 162/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0767 - accuracy: 0.9782 - val_loss: 1.0230 - val_accuracy: 0.8202\n", "Epoch 163/200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "125/125 [==============================] - 12s 98ms/step - loss: 0.0751 - accuracy: 0.9787 - val_loss: 1.0155 - val_accuracy: 0.8201\n", "Epoch 164/200\n", "125/125 [==============================] - 12s 97ms/step - loss: 0.0749 - accuracy: 0.9786 - val_loss: 1.0207 - val_accuracy: 0.8214\n", "Epoch 165/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0734 - accuracy: 0.9792 - val_loss: 1.0302 - val_accuracy: 0.8214\n", "Epoch 166/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0725 - accuracy: 0.9796 - val_loss: 1.0320 - val_accuracy: 0.8208\n", "Epoch 167/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0716 - accuracy: 0.9797 - val_loss: 1.0349 - val_accuracy: 0.8209\n", "Epoch 168/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0706 - accuracy: 0.9799 - val_loss: 1.0365 - val_accuracy: 0.8204\n", "Epoch 169/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0694 - accuracy: 0.9803 - val_loss: 1.0432 - val_accuracy: 0.8215\n", "Epoch 170/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.0692 - accuracy: 0.9805 - val_loss: 1.0522 - val_accuracy: 0.8202\n", "Epoch 171/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0683 - accuracy: 0.9806 - val_loss: 1.0662 - val_accuracy: 0.8193\n", "Epoch 172/200\n", "125/125 [==============================] - 13s 102ms/step - loss: 0.0676 - accuracy: 0.9808 - val_loss: 1.0561 - val_accuracy: 0.8206\n", "Epoch 173/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0670 - accuracy: 0.9808 - val_loss: 1.0581 - val_accuracy: 0.8207\n", "Epoch 174/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0654 - accuracy: 0.9816 - val_loss: 1.0665 - val_accuracy: 0.8190\n", "Epoch 175/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0643 - accuracy: 0.9816 - val_loss: 1.0628 - val_accuracy: 0.8205\n", "Epoch 176/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0638 - accuracy: 0.9819 - val_loss: 1.0731 - val_accuracy: 0.8192\n", "Epoch 177/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.0630 - accuracy: 0.9822 - val_loss: 1.0767 - val_accuracy: 0.8195\n", "Epoch 178/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0631 - accuracy: 0.9821 - val_loss: 1.0787 - val_accuracy: 0.8204\n", "Epoch 179/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0625 - accuracy: 0.9824 - val_loss: 1.0760 - val_accuracy: 0.8201\n", "Epoch 180/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0610 - accuracy: 0.9826 - val_loss: 1.0882 - val_accuracy: 0.8198\n", "Epoch 181/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0601 - accuracy: 0.9832 - val_loss: 1.0883 - val_accuracy: 0.8196\n", "Epoch 182/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0596 - accuracy: 0.9830 - val_loss: 1.0883 - val_accuracy: 0.8208\n", "Epoch 183/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0589 - accuracy: 0.9834 - val_loss: 1.0956 - val_accuracy: 0.8199\n", "Epoch 184/200\n", "125/125 [==============================] - 13s 102ms/step - loss: 0.0582 - accuracy: 0.9836 - val_loss: 1.0977 - val_accuracy: 0.8194\n", "Epoch 185/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0582 - accuracy: 0.9834 - val_loss: 1.0947 - val_accuracy: 0.8197\n", "Epoch 186/200\n", "125/125 [==============================] - 12s 99ms/step - loss: 0.0565 - accuracy: 0.9840 - val_loss: 1.1029 - val_accuracy: 0.8200\n", "Epoch 187/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0560 - accuracy: 0.9842 - val_loss: 1.1064 - val_accuracy: 0.8202\n", "Epoch 188/200\n", "125/125 [==============================] - 13s 100ms/step - loss: 0.0552 - accuracy: 0.9844 - val_loss: 1.1135 - val_accuracy: 0.8193\n", "Epoch 189/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0551 - accuracy: 0.9844 - val_loss: 1.1150 - val_accuracy: 0.8194\n", "Epoch 190/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0538 - accuracy: 0.9847 - val_loss: 1.1186 - val_accuracy: 0.8200\n", "Epoch 191/200\n", "125/125 [==============================] - 13s 103ms/step - loss: 0.0534 - accuracy: 0.9849 - val_loss: 1.1152 - val_accuracy: 0.8191\n", "Epoch 192/200\n", "125/125 [==============================] - 12s 97ms/step - loss: 0.0530 - accuracy: 0.9850 - val_loss: 1.1191 - val_accuracy: 0.8188\n", "Epoch 193/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.0530 - accuracy: 0.9849 - val_loss: 1.1290 - val_accuracy: 0.8188\n", "Epoch 194/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0525 - accuracy: 0.9850 - val_loss: 1.1289 - val_accuracy: 0.8190\n", "Epoch 195/200\n", "125/125 [==============================] - 12s 97ms/step - loss: 0.0511 - accuracy: 0.9856 - val_loss: 1.1290 - val_accuracy: 0.8192\n", "Epoch 196/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.0507 - accuracy: 0.9856 - val_loss: 1.1378 - val_accuracy: 0.8190\n", "Epoch 197/200\n", "125/125 [==============================] - 13s 101ms/step - loss: 0.0505 - accuracy: 0.9856 - val_loss: 1.1372 - val_accuracy: 0.8192\n", "Epoch 198/200\n", "125/125 [==============================] - 12s 97ms/step - loss: 0.0498 - accuracy: 0.9858 - val_loss: 1.1386 - val_accuracy: 0.8188\n", "Epoch 199/200\n", "125/125 [==============================] - 12s 98ms/step - loss: 0.0484 - accuracy: 0.9864 - val_loss: 1.1423 - val_accuracy: 0.8189\n", "Epoch 200/200\n", "125/125 [==============================] - 12s 100ms/step - loss: 0.0484 - accuracy: 0.9864 - val_loss: 1.1581 - val_accuracy: 0.8176\n" ] } ], "source": [ "model.compile(\n", " optimizer=\"rmsprop\", loss=\"categorical_crossentropy\", metrics=[\"accuracy\"]\n", ")\n", "model.fit(\n", " [encoder_input_data, decoder_input_data],\n", " decoder_target_data,\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " validation_split=0.2,\n", ")\n", "\n", "model.save(\"s2s_model.keras\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4zhoiYDkISvy" }, "source": [ "### Run interference" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "cgNB5j6qIWdo" }, "outputs": [], "source": [ "encoder_inputs = model.input[0] # input_1\n", "encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output # lstm_1\n", "encoder_states = [state_h_enc, state_c_enc]\n", "encoder_model = keras.Model(encoder_inputs, encoder_states)\n", "\n", "decoder_inputs = model.input[1] # input_2\n", "decoder_state_input_h = keras.Input(shape=(latent_dim,))\n", "decoder_state_input_c = keras.Input(shape=(latent_dim,))\n", "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n", "decoder_lstm = model.layers[3]\n", "decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(\n", " decoder_inputs, initial_state=decoder_states_inputs\n", ")\n", "decoder_states = [state_h_dec, state_c_dec]\n", "decoder_dense = model.layers[4]\n", "decoder_outputs = decoder_dense(decoder_outputs)\n", "decoder_model = keras.Model(\n", " [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states\n", ")\n", "\n", "# Reverse-lookup token index to decode sequences back to\n", "# something readable.\n", "reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())\n", "reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())\n", "\n", "\n", "def decode_sequence(input_seq):\n", " # Encode the input as state vectors.\n", " states_value = encoder_model.predict(input_seq, verbose=0)\n", "\n", " # Generate empty target sequence of length 1.\n", " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", " # Populate the first character of target sequence with the start character.\n", " target_seq[0, 0, target_token_index[\"\\t\"]] = 1.0\n", "\n", " # Sampling loop for a batch of sequences\n", " # (to simplify, here we assume a batch of size 1).\n", " stop_condition = False\n", " decoded_sentence = \"\"\n", " while not stop_condition:\n", " output_tokens, h, c = decoder_model.predict(\n", " [target_seq] + states_value, verbose=0\n", " )\n", "\n", " # Sample a token\n", " sampled_token_index = np.argmax(output_tokens[0, -1, :])\n", " sampled_char = reverse_target_char_index[sampled_token_index]\n", " decoded_sentence += sampled_char\n", "\n", " # Exit condition: either hit max length\n", " # or find stop character.\n", " if sampled_char == \"\\n\" or len(decoded_sentence) > max_decoder_seq_length:\n", " stop_condition = True\n", "\n", " # Update the target sequence (of length 1).\n", " target_seq = np.zeros((1, 1, num_decoder_tokens))\n", " target_seq[0, 0, sampled_token_index] = 1.0\n", "\n", " # Update states\n", " states_value = [h, c]\n", " return decoded_sentence" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1GhhZn_uIiVJ", "outputId": "6930ea66-0a44-4975-b613-00464dc31179" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-\n", "Input sentence: Go.\n", "Decoded sentence: Idź.\n", "\n", "-\n", "Input sentence: Hi.\n", "Decoded sentence: Cześć.\n", "\n", "-\n", "Input sentence: Run!\n", "Decoded sentence: Uciekaj!\n", "\n", "-\n", "Input sentence: Run.\n", "Decoded sentence: Uciekaj.\n", "\n", "-\n", "Input sentence: Run.\n", "Decoded sentence: Uciekaj.\n", "\n", "-\n", "Input sentence: Who?\n", "Decoded sentence: Kto?\n", "\n", "-\n", "Input sentence: Wow!\n", "Decoded sentence: O, dziama?\n", "\n", "-\n", "Input sentence: Wow!\n", "Decoded sentence: O, dziama?\n", "\n", "-\n", "Input sentence: Duck!\n", "Decoded sentence: Unik!\n", "\n", "-\n", "Input sentence: Fire!\n", "Decoded sentence: Strzelaj!\n", "\n", "-\n", "Input sentence: Fire!\n", "Decoded sentence: Strzelaj!\n", "\n", "-\n", "Input sentence: Fire!\n", "Decoded sentence: Strzelaj!\n", "\n", "-\n", "Input sentence: Help!\n", "Decoded sentence: Pomocy!\n", "\n", "-\n", "Input sentence: Hide.\n", "Decoded sentence: Schowaj się.\n", "\n", "-\n", "Input sentence: Jump!\n", "Decoded sentence: Skacz!\n", "\n", "-\n", "Input sentence: Jump.\n", "Decoded sentence: Skok.\n", "\n", "-\n", "Input sentence: Stay.\n", "Decoded sentence: Zostań.\n", "\n", "-\n", "Input sentence: Stop!\n", "Decoded sentence: Zatrzymaj się!\n", "\n", "-\n", "Input sentence: Stop!\n", "Decoded sentence: Zatrzymaj się!\n", "\n", "-\n", "Input sentence: Wait!\n", "Decoded sentence: Zaczekajcie!\n", "\n" ] } ], "source": [ "for seq_index in range(20):\n", " # Take one sequence (part of the training set)\n", " # for trying out decoding.\n", " input_seq = encoder_input_data[seq_index : seq_index + 1]\n", " decoded_sentence = decode_sequence(input_seq)\n", " print(\"-\")\n", " print(\"Input sentence:\", input_texts[seq_index])\n", " print(\"Decoded sentence:\", decoded_sentence)" ] }, { "cell_type": "markdown", "metadata": { "id": "nLylna2WI5VY" }, "source": [ "### Bleu evaluation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "background_save": true, "base_uri": "https://localhost:8080/" }, "id": "YOYWSVlpJe_M", "outputId": "07ca0d9f-1729-42f5-939d-a9616fc127cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.5965545320892247\n", "0.5981615012338071\n", "0.5930356751147987\n", "0.5942200927187505\n", "0.5951351487924014\n", "0.5919465093543784\n", "0.590832790904886\n", "0.5919919489002385\n", "0.5452219773056319\n", "BLEU score: 0.5095920718253371\n" ] } ], "source": [ "import nltk.translate.bleu_score as bleu\n", "smoothing_function = bleu.SmoothingFunction().method4\n", "bleu_scores = 0\n", "for seq_index in range(len(encoder_input_data)):\n", " input_seq = encoder_input_data[seq_index : seq_index + 1]\n", " decoded_sentence = decode_sequence(input_seq)\n", " reference = [list(target_texts[seq_index])]\n", " candidate = list(decoded_sentence.strip())\n", " bleu_score = bleu.sentence_bleu(reference, candidate, smoothing_function=smoothing_function)\n", " bleu_scores += bleu_score\n", " if seq_index % (len(encoder_input_data)//10) == 0 and seq_index != 0:\n", " print(bleu_scores/seq_index)\n", "bleu_final = bleu_scores/len(encoder_input_data)\n", "print(\"BLEU score:\", bleu_final)" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 1 }