51 KiB
51 KiB
import numpy as np
import keras
import os
from pathlib import Path
import requests
import zipfile
WARNING:tensorflow:From C:\Users\Pawel\anaconda3\Lib\site-packages\keras\src\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.
# Download the data
# Download and extract the English-Polish sentence pairs from manythings.org
# into the current working directory.
url = "https://www.manythings.org/anki/pol-eng.zip"
local_filename = url.split('/')[-1]
dirpath = Path().absolute()
# The site rejects requests without a browser-like User-Agent header.
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
response = requests.get(url, headers=headers, timeout=60)
# Fail loudly on any HTTP error. The original silently skipped the download
# on a non-200 status, which made the later open(data_path) fail confusingly.
response.raise_for_status()
with open(local_filename, 'wb') as f:
    f.write(response.content)
with zipfile.ZipFile(local_filename, 'r') as zip_ref:
    zip_ref.extractall(dirpath)
# Configuration
# Training hyper-parameters.
batch_size = 64  # Sequences per gradient update.
epochs = 200  # Full passes over the training set.
latent_dim = 256  # Width of the LSTM hidden/cell state vectors.
num_samples = 10000  # How many sentence pairs to load from the corpus.
# Location of the extracted parallel-text file on disk.
data_path = os.path.join(dirpath, "pol.txt")
# Prepare the data
# Vectorize the data: read the parallel corpus, build the character
# vocabularies, and encode every pair as one-hot tensors.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")
for line in lines[: min(num_samples, len(lines) - 1)]:
    # Each line is "english<TAB>polish<TAB>attribution". maxsplit=2 keeps
    # the 3-way unpack robust even if the attribution field contains tabs.
    input_text, target_text, _ = line.split("\t", 2)
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    # set.update adds every character; the per-character membership test
    # of the original was redundant for sets.
    input_characters.update(input_text)
    target_characters.update(target_text)
input_characters = sorted(input_characters)
target_characters = sorted(target_characters)
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)
print("Number of samples:", len(input_texts))
print("Number of unique input tokens:", num_encoder_tokens)
print("Number of unique output tokens:", num_decoder_tokens)
print("Max sequence length for inputs:", max_encoder_seq_length)
print("Max sequence length for outputs:", max_decoder_seq_length)
# Map each character to its index in the sorted vocabulary.
input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}
# One-hot tensors of shape (samples, timesteps, vocab_size).
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype="float32",
)
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.0
    # Pad the remaining timesteps with the space character.
    encoder_input_data[i, t + 1 :, input_token_index[" "]] = 1.0
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.0
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
    decoder_input_data[i, t + 1 :, target_token_index[" "]] = 1.0
    decoder_target_data[i, t:, target_token_index[" "]] = 1.0
Number of samples: 10000 Number of unique input tokens: 70 Number of unique output tokens: 87 Max sequence length for inputs: 19 Max sequence length for outputs: 52
# Build the model
# Define an input sequence and process it.
# Encoder: variable-length one-hot character sequences over the input
# vocabulary; only its final hidden/cell states are passed to the decoder.
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
encoder = keras.layers.LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
# Per-timestep softmax over the target vocabulary.
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
# (teacher forcing: the decoder sees the ground-truth previous character).
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
WARNING:tensorflow:From C:\Users\Pawel\anaconda3\Lib\site-packages\keras\src\backend.py:1398: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.
# Train the model
# Compile with categorical cross-entropy over the per-timestep one-hot
# character targets, then train with a 20% validation hold-out.
model.compile(
    optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"]
)
# NOTE(review): the attached run log shows val_loss bottoming out around
# epoch ~50 and rising afterwards while training loss keeps falling
# (overfitting); consider an EarlyStopping callback — confirm.
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
)
# Persist the trained model in the native Keras format for later inference.
model.save("s2s_model.keras")
WARNING:tensorflow:From C:\Users\Pawel\anaconda3\Lib\site-packages\keras\src\optimizers\__init__.py:309: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead. Epoch 1/200 WARNING:tensorflow:From C:\Users\Pawel\anaconda3\Lib\site-packages\keras\src\utils\tf_utils.py:492: The name tf.ragged.RaggedTensorValue is deprecated. Please use tf.compat.v1.ragged.RaggedTensorValue instead. WARNING:tensorflow:From C:\Users\Pawel\anaconda3\Lib\site-packages\keras\src\engine\base_layer_utils.py:384: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead. 125/125 [==============================] - 20s 121ms/step - loss: 1.3745 - accuracy: 0.6987 - val_loss: 1.4872 - val_accuracy: 0.6530 Epoch 2/200 125/125 [==============================] - 12s 94ms/step - loss: 1.0822 - accuracy: 0.7185 - val_loss: 1.2406 - val_accuracy: 0.6688 Epoch 3/200 125/125 [==============================] - 14s 109ms/step - loss: 0.9703 - accuracy: 0.7424 - val_loss: 1.1089 - val_accuracy: 0.7065 Epoch 4/200 125/125 [==============================] - 14s 109ms/step - loss: 0.8733 - accuracy: 0.7671 - val_loss: 1.0149 - val_accuracy: 0.7165 Epoch 5/200 125/125 [==============================] - 18s 145ms/step - loss: 0.7839 - accuracy: 0.7835 - val_loss: 0.9353 - val_accuracy: 0.7399 Epoch 6/200 125/125 [==============================] - 16s 129ms/step - loss: 0.7367 - accuracy: 0.7946 - val_loss: 0.8968 - val_accuracy: 0.7464 Epoch 7/200 125/125 [==============================] - 15s 121ms/step - loss: 0.7030 - accuracy: 0.8006 - val_loss: 0.8677 - val_accuracy: 0.7572 Epoch 8/200 125/125 [==============================] - 14s 112ms/step - loss: 0.6812 - accuracy: 0.8056 - val_loss: 0.8536 - val_accuracy: 0.7566 Epoch 9/200 125/125 [==============================] - 14s 114ms/step - loss: 0.6642 - accuracy: 0.8100 - val_loss: 0.8439 - val_accuracy: 0.7569 Epoch 10/200 125/125 
[==============================] - 15s 117ms/step - loss: 0.6492 - accuracy: 0.8133 - val_loss: 0.8181 - val_accuracy: 0.7648 Epoch 11/200 125/125 [==============================] - 17s 136ms/step - loss: 0.6369 - accuracy: 0.8162 - val_loss: 0.8118 - val_accuracy: 0.7664 Epoch 12/200 125/125 [==============================] - 14s 115ms/step - loss: 0.6252 - accuracy: 0.8193 - val_loss: 0.7932 - val_accuracy: 0.7709 Epoch 13/200 125/125 [==============================] - 15s 124ms/step - loss: 0.6147 - accuracy: 0.8216 - val_loss: 0.7842 - val_accuracy: 0.7722 Epoch 14/200 125/125 [==============================] - 14s 109ms/step - loss: 0.6040 - accuracy: 0.8244 - val_loss: 0.7801 - val_accuracy: 0.7749 Epoch 15/200 125/125 [==============================] - 14s 109ms/step - loss: 0.5944 - accuracy: 0.8273 - val_loss: 0.7741 - val_accuracy: 0.7758 Epoch 16/200 125/125 [==============================] - 13s 108ms/step - loss: 0.5841 - accuracy: 0.8302 - val_loss: 0.7631 - val_accuracy: 0.7788 Epoch 17/200 125/125 [==============================] - 14s 109ms/step - loss: 0.5739 - accuracy: 0.8337 - val_loss: 0.7413 - val_accuracy: 0.7838 Epoch 18/200 125/125 [==============================] - 14s 108ms/step - loss: 0.5627 - accuracy: 0.8363 - val_loss: 0.7355 - val_accuracy: 0.7862 Epoch 19/200 125/125 [==============================] - 14s 110ms/step - loss: 0.5532 - accuracy: 0.8390 - val_loss: 0.7239 - val_accuracy: 0.7899 Epoch 20/200 125/125 [==============================] - 15s 121ms/step - loss: 0.5441 - accuracy: 0.8418 - val_loss: 0.7179 - val_accuracy: 0.7916 Epoch 21/200 125/125 [==============================] - 15s 117ms/step - loss: 0.5360 - accuracy: 0.8437 - val_loss: 0.7164 - val_accuracy: 0.7915 Epoch 22/200 125/125 [==============================] - 15s 123ms/step - loss: 0.5270 - accuracy: 0.8469 - val_loss: 0.7067 - val_accuracy: 0.7947 Epoch 23/200 125/125 [==============================] - 14s 116ms/step - loss: 0.5186 - accuracy: 0.8493 - 
val_loss: 0.7051 - val_accuracy: 0.7962 Epoch 24/200 125/125 [==============================] - 14s 109ms/step - loss: 0.5107 - accuracy: 0.8519 - val_loss: 0.6904 - val_accuracy: 0.7994 Epoch 25/200 125/125 [==============================] - 18s 142ms/step - loss: 0.5022 - accuracy: 0.8542 - val_loss: 0.6898 - val_accuracy: 0.7998 Epoch 26/200 125/125 [==============================] - 14s 111ms/step - loss: 0.4945 - accuracy: 0.8565 - val_loss: 0.6775 - val_accuracy: 0.8041 Epoch 27/200 125/125 [==============================] - 14s 116ms/step - loss: 0.4875 - accuracy: 0.8584 - val_loss: 0.6717 - val_accuracy: 0.8064 Epoch 28/200 125/125 [==============================] - 14s 114ms/step - loss: 0.4806 - accuracy: 0.8608 - val_loss: 0.6649 - val_accuracy: 0.8077 Epoch 29/200 125/125 [==============================] - 15s 122ms/step - loss: 0.4728 - accuracy: 0.8627 - val_loss: 0.6616 - val_accuracy: 0.8085 Epoch 30/200 125/125 [==============================] - 15s 117ms/step - loss: 0.4659 - accuracy: 0.8649 - val_loss: 0.6574 - val_accuracy: 0.8094 Epoch 31/200 125/125 [==============================] - 15s 120ms/step - loss: 0.4596 - accuracy: 0.8672 - val_loss: 0.6557 - val_accuracy: 0.8113 Epoch 32/200 125/125 [==============================] - 15s 117ms/step - loss: 0.4527 - accuracy: 0.8690 - val_loss: 0.6500 - val_accuracy: 0.8124 Epoch 33/200 125/125 [==============================] - 15s 119ms/step - loss: 0.4464 - accuracy: 0.8706 - val_loss: 0.6459 - val_accuracy: 0.8141 Epoch 34/200 125/125 [==============================] - 14s 111ms/step - loss: 0.4403 - accuracy: 0.8727 - val_loss: 0.6470 - val_accuracy: 0.8141 Epoch 35/200 125/125 [==============================] - 13s 108ms/step - loss: 0.4343 - accuracy: 0.8741 - val_loss: 0.6338 - val_accuracy: 0.8175 Epoch 36/200 125/125 [==============================] - 14s 108ms/step - loss: 0.4282 - accuracy: 0.8757 - val_loss: 0.6314 - val_accuracy: 0.8190 Epoch 37/200 125/125 
[==============================] - 14s 108ms/step - loss: 0.4223 - accuracy: 0.8776 - val_loss: 0.6343 - val_accuracy: 0.8183 Epoch 38/200 125/125 [==============================] - 14s 109ms/step - loss: 0.4165 - accuracy: 0.8792 - val_loss: 0.6327 - val_accuracy: 0.8183 Epoch 39/200 125/125 [==============================] - 14s 109ms/step - loss: 0.4107 - accuracy: 0.8812 - val_loss: 0.6277 - val_accuracy: 0.8211 Epoch 40/200 125/125 [==============================] - 14s 116ms/step - loss: 0.4057 - accuracy: 0.8828 - val_loss: 0.6275 - val_accuracy: 0.8206 Epoch 41/200 125/125 [==============================] - 14s 115ms/step - loss: 0.3998 - accuracy: 0.8842 - val_loss: 0.6186 - val_accuracy: 0.8234 Epoch 42/200 125/125 [==============================] - 14s 111ms/step - loss: 0.3948 - accuracy: 0.8852 - val_loss: 0.6186 - val_accuracy: 0.8238 Epoch 43/200 125/125 [==============================] - 14s 112ms/step - loss: 0.3894 - accuracy: 0.8873 - val_loss: 0.6225 - val_accuracy: 0.8232 Epoch 44/200 125/125 [==============================] - 15s 119ms/step - loss: 0.3842 - accuracy: 0.8887 - val_loss: 0.6209 - val_accuracy: 0.8237 Epoch 45/200 125/125 [==============================] - 15s 121ms/step - loss: 0.3797 - accuracy: 0.8898 - val_loss: 0.6139 - val_accuracy: 0.8262 Epoch 46/200 125/125 [==============================] - 15s 124ms/step - loss: 0.3745 - accuracy: 0.8918 - val_loss: 0.6112 - val_accuracy: 0.8271 Epoch 47/200 125/125 [==============================] - 15s 118ms/step - loss: 0.3695 - accuracy: 0.8931 - val_loss: 0.6122 - val_accuracy: 0.8265 Epoch 48/200 125/125 [==============================] - 15s 122ms/step - loss: 0.3641 - accuracy: 0.8943 - val_loss: 0.6184 - val_accuracy: 0.8256 Epoch 49/200 125/125 [==============================] - 15s 117ms/step - loss: 0.3599 - accuracy: 0.8960 - val_loss: 0.6116 - val_accuracy: 0.8274 Epoch 50/200 125/125 [==============================] - 15s 121ms/step - loss: 0.3552 - accuracy: 0.8972 - 
val_loss: 0.6109 - val_accuracy: 0.8279 Epoch 51/200 125/125 [==============================] - 14s 116ms/step - loss: 0.3505 - accuracy: 0.8986 - val_loss: 0.6106 - val_accuracy: 0.8289 Epoch 52/200 125/125 [==============================] - 15s 122ms/step - loss: 0.3459 - accuracy: 0.9002 - val_loss: 0.6133 - val_accuracy: 0.8282 Epoch 53/200 125/125 [==============================] - 14s 115ms/step - loss: 0.3415 - accuracy: 0.9015 - val_loss: 0.6171 - val_accuracy: 0.8282 Epoch 54/200 125/125 [==============================] - 15s 119ms/step - loss: 0.3375 - accuracy: 0.9020 - val_loss: 0.6109 - val_accuracy: 0.8300 Epoch 55/200 125/125 [==============================] - 18s 143ms/step - loss: 0.3326 - accuracy: 0.9039 - val_loss: 0.6103 - val_accuracy: 0.8302 Epoch 56/200 125/125 [==============================] - 16s 130ms/step - loss: 0.3283 - accuracy: 0.9047 - val_loss: 0.6132 - val_accuracy: 0.8305 Epoch 57/200 125/125 [==============================] - 16s 132ms/step - loss: 0.3243 - accuracy: 0.9062 - val_loss: 0.6136 - val_accuracy: 0.8303 Epoch 58/200 125/125 [==============================] - 15s 122ms/step - loss: 0.3195 - accuracy: 0.9077 - val_loss: 0.6206 - val_accuracy: 0.8295 Epoch 59/200 125/125 [==============================] - 16s 131ms/step - loss: 0.3152 - accuracy: 0.9093 - val_loss: 0.6196 - val_accuracy: 0.8305 Epoch 60/200 125/125 [==============================] - 16s 126ms/step - loss: 0.3114 - accuracy: 0.9098 - val_loss: 0.6217 - val_accuracy: 0.8291 Epoch 61/200 125/125 [==============================] - 16s 127ms/step - loss: 0.3073 - accuracy: 0.9113 - val_loss: 0.6295 - val_accuracy: 0.8282 Epoch 62/200 125/125 [==============================] - 16s 125ms/step - loss: 0.3025 - accuracy: 0.9124 - val_loss: 0.6231 - val_accuracy: 0.8292 Epoch 63/200 125/125 [==============================] - 16s 127ms/step - loss: 0.2990 - accuracy: 0.9134 - val_loss: 0.6255 - val_accuracy: 0.8305 Epoch 64/200 125/125 
[==============================] - 16s 129ms/step - loss: 0.2956 - accuracy: 0.9147 - val_loss: 0.6279 - val_accuracy: 0.8302 Epoch 65/200 125/125 [==============================] - 16s 125ms/step - loss: 0.2910 - accuracy: 0.9158 - val_loss: 0.6302 - val_accuracy: 0.8302 Epoch 66/200 125/125 [==============================] - 17s 132ms/step - loss: 0.2868 - accuracy: 0.9170 - val_loss: 0.6319 - val_accuracy: 0.8300 Epoch 67/200 125/125 [==============================] - 16s 127ms/step - loss: 0.2833 - accuracy: 0.9179 - val_loss: 0.6364 - val_accuracy: 0.8298 Epoch 68/200 125/125 [==============================] - 16s 130ms/step - loss: 0.2795 - accuracy: 0.9189 - val_loss: 0.6327 - val_accuracy: 0.8315 Epoch 69/200 125/125 [==============================] - 16s 131ms/step - loss: 0.2756 - accuracy: 0.9205 - val_loss: 0.6398 - val_accuracy: 0.8302 Epoch 70/200 125/125 [==============================] - 17s 138ms/step - loss: 0.2719 - accuracy: 0.9216 - val_loss: 0.6430 - val_accuracy: 0.8306 Epoch 71/200 125/125 [==============================] - 15s 123ms/step - loss: 0.2685 - accuracy: 0.9225 - val_loss: 0.6439 - val_accuracy: 0.8295 Epoch 72/200 125/125 [==============================] - 16s 126ms/step - loss: 0.2640 - accuracy: 0.9238 - val_loss: 0.6478 - val_accuracy: 0.8296 Epoch 73/200 125/125 [==============================] - 13s 104ms/step - loss: 0.2609 - accuracy: 0.9249 - val_loss: 0.6485 - val_accuracy: 0.8305 Epoch 74/200 125/125 [==============================] - 16s 125ms/step - loss: 0.2568 - accuracy: 0.9261 - val_loss: 0.6519 - val_accuracy: 0.8311 Epoch 75/200 125/125 [==============================] - 14s 116ms/step - loss: 0.2542 - accuracy: 0.9263 - val_loss: 0.6572 - val_accuracy: 0.8287 Epoch 76/200 125/125 [==============================] - 16s 127ms/step - loss: 0.2502 - accuracy: 0.9282 - val_loss: 0.6625 - val_accuracy: 0.8289 Epoch 77/200 125/125 [==============================] - 14s 110ms/step - loss: 0.2468 - accuracy: 0.9289 - 
val_loss: 0.6624 - val_accuracy: 0.8300 Epoch 78/200 125/125 [==============================] - 13s 106ms/step - loss: 0.2435 - accuracy: 0.9298 - val_loss: 0.6642 - val_accuracy: 0.8301 Epoch 79/200 125/125 [==============================] - 13s 105ms/step - loss: 0.2401 - accuracy: 0.9307 - val_loss: 0.6743 - val_accuracy: 0.8290 Epoch 80/200 125/125 [==============================] - 13s 103ms/step - loss: 0.2369 - accuracy: 0.9316 - val_loss: 0.6680 - val_accuracy: 0.8301 Epoch 81/200 125/125 [==============================] - 13s 104ms/step - loss: 0.2333 - accuracy: 0.9326 - val_loss: 0.6743 - val_accuracy: 0.8286 Epoch 82/200 125/125 [==============================] - 15s 119ms/step - loss: 0.2300 - accuracy: 0.9336 - val_loss: 0.6763 - val_accuracy: 0.8292 Epoch 83/200 125/125 [==============================] - 12s 98ms/step - loss: 0.2269 - accuracy: 0.9345 - val_loss: 0.6805 - val_accuracy: 0.8293 Epoch 84/200 125/125 [==============================] - 13s 107ms/step - loss: 0.2238 - accuracy: 0.9357 - val_loss: 0.6812 - val_accuracy: 0.8292 Epoch 85/200 125/125 [==============================] - 15s 118ms/step - loss: 0.2201 - accuracy: 0.9363 - val_loss: 0.6801 - val_accuracy: 0.8305 Epoch 86/200 125/125 [==============================] - 14s 110ms/step - loss: 0.2176 - accuracy: 0.9370 - val_loss: 0.6918 - val_accuracy: 0.8298 Epoch 87/200 125/125 [==============================] - 13s 106ms/step - loss: 0.2142 - accuracy: 0.9384 - val_loss: 0.7037 - val_accuracy: 0.8277 Epoch 88/200 125/125 [==============================] - 13s 104ms/step - loss: 0.2110 - accuracy: 0.9394 - val_loss: 0.7012 - val_accuracy: 0.8290 Epoch 89/200 125/125 [==============================] - 14s 111ms/step - loss: 0.2078 - accuracy: 0.9402 - val_loss: 0.7013 - val_accuracy: 0.8297 Epoch 90/200 125/125 [==============================] - 17s 137ms/step - loss: 0.2052 - accuracy: 0.9409 - val_loss: 0.7069 - val_accuracy: 0.8285 Epoch 91/200 125/125 
[==============================] - 19s 150ms/step - loss: 0.2021 - accuracy: 0.9419 - val_loss: 0.7072 - val_accuracy: 0.8285 Epoch 92/200 125/125 [==============================] - 14s 116ms/step - loss: 0.1989 - accuracy: 0.9425 - val_loss: 0.7139 - val_accuracy: 0.8289 Epoch 93/200 125/125 [==============================] - 14s 113ms/step - loss: 0.1966 - accuracy: 0.9432 - val_loss: 0.7173 - val_accuracy: 0.8288 Epoch 94/200 125/125 [==============================] - 21s 167ms/step - loss: 0.1940 - accuracy: 0.9440 - val_loss: 0.7222 - val_accuracy: 0.8288 Epoch 95/200 125/125 [==============================] - 18s 146ms/step - loss: 0.1911 - accuracy: 0.9449 - val_loss: 0.7298 - val_accuracy: 0.8283 Epoch 96/200 125/125 [==============================] - 18s 146ms/step - loss: 0.1882 - accuracy: 0.9455 - val_loss: 0.7287 - val_accuracy: 0.8279 Epoch 97/200 125/125 [==============================] - 18s 146ms/step - loss: 0.1856 - accuracy: 0.9463 - val_loss: 0.7367 - val_accuracy: 0.8277 Epoch 98/200 125/125 [==============================] - 19s 149ms/step - loss: 0.1827 - accuracy: 0.9476 - val_loss: 0.7370 - val_accuracy: 0.8283 Epoch 99/200 125/125 [==============================] - 19s 149ms/step - loss: 0.1802 - accuracy: 0.9484 - val_loss: 0.7434 - val_accuracy: 0.8282 Epoch 100/200 125/125 [==============================] - 18s 144ms/step - loss: 0.1778 - accuracy: 0.9486 - val_loss: 0.7473 - val_accuracy: 0.8274 Epoch 101/200 125/125 [==============================] - 18s 148ms/step - loss: 0.1751 - accuracy: 0.9497 - val_loss: 0.7547 - val_accuracy: 0.8266 Epoch 102/200 125/125 [==============================] - 18s 140ms/step - loss: 0.1723 - accuracy: 0.9507 - val_loss: 0.7575 - val_accuracy: 0.8266 Epoch 103/200 125/125 [==============================] - 17s 135ms/step - loss: 0.1701 - accuracy: 0.9511 - val_loss: 0.7621 - val_accuracy: 0.8269 Epoch 104/200 125/125 [==============================] - 18s 143ms/step - loss: 0.1677 - accuracy: 0.9519 
- val_loss: 0.7708 - val_accuracy: 0.8263 Epoch 105/200 125/125 [==============================] - 18s 145ms/step - loss: 0.1659 - accuracy: 0.9522 - val_loss: 0.7669 - val_accuracy: 0.8277 Epoch 106/200 125/125 [==============================] - 20s 162ms/step - loss: 0.1628 - accuracy: 0.9533 - val_loss: 0.7766 - val_accuracy: 0.8267 Epoch 107/200 125/125 [==============================] - 18s 143ms/step - loss: 0.1606 - accuracy: 0.9538 - val_loss: 0.7791 - val_accuracy: 0.8275 Epoch 108/200 125/125 [==============================] - 18s 145ms/step - loss: 0.1583 - accuracy: 0.9548 - val_loss: 0.7820 - val_accuracy: 0.8268 Epoch 109/200 125/125 [==============================] - 18s 143ms/step - loss: 0.1564 - accuracy: 0.9550 - val_loss: 0.7863 - val_accuracy: 0.8266 Epoch 110/200 125/125 [==============================] - 19s 149ms/step - loss: 0.1544 - accuracy: 0.9558 - val_loss: 0.7927 - val_accuracy: 0.8259 Epoch 111/200 125/125 [==============================] - 18s 146ms/step - loss: 0.1517 - accuracy: 0.9566 - val_loss: 0.7972 - val_accuracy: 0.8257 Epoch 112/200 125/125 [==============================] - 19s 150ms/step - loss: 0.1490 - accuracy: 0.9573 - val_loss: 0.8031 - val_accuracy: 0.8262 Epoch 113/200 125/125 [==============================] - 21s 172ms/step - loss: 0.1478 - accuracy: 0.9575 - val_loss: 0.8029 - val_accuracy: 0.8261 Epoch 114/200 125/125 [==============================] - 27s 215ms/step - loss: 0.1459 - accuracy: 0.9580 - val_loss: 0.8135 - val_accuracy: 0.8256 Epoch 115/200 125/125 [==============================] - 19s 154ms/step - loss: 0.1437 - accuracy: 0.9585 - val_loss: 0.8143 - val_accuracy: 0.8267 Epoch 116/200 125/125 [==============================] - 14s 113ms/step - loss: 0.1415 - accuracy: 0.9591 - val_loss: 0.8153 - val_accuracy: 0.8260 Epoch 117/200 125/125 [==============================] - 14s 109ms/step - loss: 0.1394 - accuracy: 0.9601 - val_loss: 0.8186 - val_accuracy: 0.8257 Epoch 118/200 125/125 
[==============================] - 14s 108ms/step - loss: 0.1377 - accuracy: 0.9602 - val_loss: 0.8334 - val_accuracy: 0.8253 Epoch 119/200 125/125 [==============================] - 14s 109ms/step - loss: 0.1355 - accuracy: 0.9612 - val_loss: 0.8374 - val_accuracy: 0.8251 Epoch 120/200 125/125 [==============================] - 14s 110ms/step - loss: 0.1335 - accuracy: 0.9619 - val_loss: 0.8370 - val_accuracy: 0.8250 Epoch 121/200 125/125 [==============================] - 16s 132ms/step - loss: 0.1317 - accuracy: 0.9620 - val_loss: 0.8438 - val_accuracy: 0.8245 Epoch 122/200 125/125 [==============================] - 17s 135ms/step - loss: 0.1302 - accuracy: 0.9626 - val_loss: 0.8457 - val_accuracy: 0.8246 Epoch 123/200 125/125 [==============================] - 14s 116ms/step - loss: 0.1287 - accuracy: 0.9628 - val_loss: 0.8478 - val_accuracy: 0.8246 Epoch 124/200 125/125 [==============================] - 12s 100ms/step - loss: 0.1265 - accuracy: 0.9636 - val_loss: 0.8581 - val_accuracy: 0.8247 Epoch 125/200 125/125 [==============================] - 12s 97ms/step - loss: 0.1247 - accuracy: 0.9642 - val_loss: 0.8579 - val_accuracy: 0.8244 Epoch 126/200 125/125 [==============================] - 14s 112ms/step - loss: 0.1237 - accuracy: 0.9643 - val_loss: 0.8601 - val_accuracy: 0.8256 Epoch 127/200 125/125 [==============================] - 17s 136ms/step - loss: 0.1226 - accuracy: 0.9646 - val_loss: 0.8719 - val_accuracy: 0.8234 Epoch 128/200 125/125 [==============================] - 19s 149ms/step - loss: 0.1203 - accuracy: 0.9655 - val_loss: 0.8740 - val_accuracy: 0.8233 Epoch 129/200 125/125 [==============================] - 18s 142ms/step - loss: 0.1183 - accuracy: 0.9662 - val_loss: 0.8812 - val_accuracy: 0.8233 Epoch 130/200 125/125 [==============================] - 15s 122ms/step - loss: 0.1162 - accuracy: 0.9669 - val_loss: 0.8817 - val_accuracy: 0.8233 Epoch 131/200 125/125 [==============================] - 16s 131ms/step - loss: 0.1148 - accuracy: 
0.9672 - val_loss: 0.8795 - val_accuracy: 0.8242 Epoch 132/200 125/125 [==============================] - 17s 139ms/step - loss: 0.1132 - accuracy: 0.9677 - val_loss: 0.8941 - val_accuracy: 0.8234 Epoch 133/200 125/125 [==============================] - 14s 116ms/step - loss: 0.1118 - accuracy: 0.9679 - val_loss: 0.8929 - val_accuracy: 0.8233 Epoch 134/200 125/125 [==============================] - 16s 126ms/step - loss: 0.1101 - accuracy: 0.9686 - val_loss: 0.8995 - val_accuracy: 0.8236 Epoch 135/200 125/125 [==============================] - 16s 126ms/step - loss: 0.1089 - accuracy: 0.9688 - val_loss: 0.9013 - val_accuracy: 0.8231 Epoch 136/200 125/125 [==============================] - 15s 116ms/step - loss: 0.1078 - accuracy: 0.9690 - val_loss: 0.9056 - val_accuracy: 0.8225 Epoch 137/200 125/125 [==============================] - 13s 106ms/step - loss: 0.1061 - accuracy: 0.9697 - val_loss: 0.9144 - val_accuracy: 0.8223 Epoch 138/200 125/125 [==============================] - 16s 128ms/step - loss: 0.1040 - accuracy: 0.9706 - val_loss: 0.9144 - val_accuracy: 0.8227 Epoch 139/200 125/125 [==============================] - 15s 120ms/step - loss: 0.1038 - accuracy: 0.9703 - val_loss: 0.9172 - val_accuracy: 0.8239 Epoch 140/200 125/125 [==============================] - 15s 117ms/step - loss: 0.1018 - accuracy: 0.9712 - val_loss: 0.9252 - val_accuracy: 0.8225 Epoch 141/200 125/125 [==============================] - 15s 116ms/step - loss: 0.1005 - accuracy: 0.9714 - val_loss: 0.9298 - val_accuracy: 0.8234 Epoch 142/200 125/125 [==============================] - 13s 105ms/step - loss: 0.0991 - accuracy: 0.9717 - val_loss: 0.9344 - val_accuracy: 0.8219 Epoch 143/200 125/125 [==============================] - 13s 102ms/step - loss: 0.0972 - accuracy: 0.9724 - val_loss: 0.9379 - val_accuracy: 0.8222 Epoch 144/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0964 - accuracy: 0.9727 - val_loss: 0.9478 - val_accuracy: 0.8226 Epoch 145/200 125/125 
[==============================] - 13s 101ms/step - loss: 0.0961 - accuracy: 0.9725 - val_loss: 0.9507 - val_accuracy: 0.8222 Epoch 146/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0942 - accuracy: 0.9731 - val_loss: 0.9532 - val_accuracy: 0.8219 Epoch 147/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0934 - accuracy: 0.9736 - val_loss: 0.9513 - val_accuracy: 0.8217 Epoch 148/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0917 - accuracy: 0.9740 - val_loss: 0.9572 - val_accuracy: 0.8225 Epoch 149/200 125/125 [==============================] - 13s 104ms/step - loss: 0.0900 - accuracy: 0.9744 - val_loss: 0.9587 - val_accuracy: 0.8223 Epoch 150/200 125/125 [==============================] - 13s 104ms/step - loss: 0.0892 - accuracy: 0.9749 - val_loss: 0.9657 - val_accuracy: 0.8225 Epoch 151/200 125/125 [==============================] - 13s 102ms/step - loss: 0.0876 - accuracy: 0.9750 - val_loss: 0.9735 - val_accuracy: 0.8203 Epoch 152/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0866 - accuracy: 0.9754 - val_loss: 0.9647 - val_accuracy: 0.8231 Epoch 153/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0852 - accuracy: 0.9759 - val_loss: 0.9787 - val_accuracy: 0.8225 Epoch 154/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0841 - accuracy: 0.9761 - val_loss: 0.9868 - val_accuracy: 0.8217 Epoch 155/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0828 - accuracy: 0.9765 - val_loss: 0.9908 - val_accuracy: 0.8215 Epoch 156/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0820 - accuracy: 0.9769 - val_loss: 0.9904 - val_accuracy: 0.8228 Epoch 157/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0813 - accuracy: 0.9769 - val_loss: 0.9992 - val_accuracy: 0.8212 Epoch 158/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0802 - accuracy: 
0.9772 - val_loss: 1.0037 - val_accuracy: 0.8204 Epoch 159/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0792 - accuracy: 0.9776 - val_loss: 1.0000 - val_accuracy: 0.8209 Epoch 160/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0789 - accuracy: 0.9777 - val_loss: 1.0039 - val_accuracy: 0.8212 Epoch 161/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0774 - accuracy: 0.9780 - val_loss: 1.0144 - val_accuracy: 0.8210 Epoch 162/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0767 - accuracy: 0.9782 - val_loss: 1.0230 - val_accuracy: 0.8202 Epoch 163/200 125/125 [==============================] - 12s 98ms/step - loss: 0.0751 - accuracy: 0.9787 - val_loss: 1.0155 - val_accuracy: 0.8201 Epoch 164/200 125/125 [==============================] - 12s 97ms/step - loss: 0.0749 - accuracy: 0.9786 - val_loss: 1.0207 - val_accuracy: 0.8214 Epoch 165/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0734 - accuracy: 0.9792 - val_loss: 1.0302 - val_accuracy: 0.8214 Epoch 166/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0725 - accuracy: 0.9796 - val_loss: 1.0320 - val_accuracy: 0.8208 Epoch 167/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0716 - accuracy: 0.9797 - val_loss: 1.0349 - val_accuracy: 0.8209 Epoch 168/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0706 - accuracy: 0.9799 - val_loss: 1.0365 - val_accuracy: 0.8204 Epoch 169/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0694 - accuracy: 0.9803 - val_loss: 1.0432 - val_accuracy: 0.8215 Epoch 170/200 125/125 [==============================] - 12s 98ms/step - loss: 0.0692 - accuracy: 0.9805 - val_loss: 1.0522 - val_accuracy: 0.8202 Epoch 171/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0683 - accuracy: 0.9806 - val_loss: 1.0662 - val_accuracy: 0.8193 Epoch 172/200 125/125 
[==============================] - 13s 102ms/step - loss: 0.0676 - accuracy: 0.9808 - val_loss: 1.0561 - val_accuracy: 0.8206 Epoch 173/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0670 - accuracy: 0.9808 - val_loss: 1.0581 - val_accuracy: 0.8207 Epoch 174/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0654 - accuracy: 0.9816 - val_loss: 1.0665 - val_accuracy: 0.8190 Epoch 175/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0643 - accuracy: 0.9816 - val_loss: 1.0628 - val_accuracy: 0.8205 Epoch 176/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0638 - accuracy: 0.9819 - val_loss: 1.0731 - val_accuracy: 0.8192 Epoch 177/200 125/125 [==============================] - 12s 98ms/step - loss: 0.0630 - accuracy: 0.9822 - val_loss: 1.0767 - val_accuracy: 0.8195 Epoch 178/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0631 - accuracy: 0.9821 - val_loss: 1.0787 - val_accuracy: 0.8204 Epoch 179/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0625 - accuracy: 0.9824 - val_loss: 1.0760 - val_accuracy: 0.8201 Epoch 180/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0610 - accuracy: 0.9826 - val_loss: 1.0882 - val_accuracy: 0.8198 Epoch 181/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0601 - accuracy: 0.9832 - val_loss: 1.0883 - val_accuracy: 0.8196 Epoch 182/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0596 - accuracy: 0.9830 - val_loss: 1.0883 - val_accuracy: 0.8208 Epoch 183/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0589 - accuracy: 0.9834 - val_loss: 1.0956 - val_accuracy: 0.8199 Epoch 184/200 125/125 [==============================] - 13s 102ms/step - loss: 0.0582 - accuracy: 0.9836 - val_loss: 1.0977 - val_accuracy: 0.8194 Epoch 185/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0582 - accuracy: 
0.9834 - val_loss: 1.0947 - val_accuracy: 0.8197 Epoch 186/200 125/125 [==============================] - 12s 99ms/step - loss: 0.0565 - accuracy: 0.9840 - val_loss: 1.1029 - val_accuracy: 0.8200 Epoch 187/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0560 - accuracy: 0.9842 - val_loss: 1.1064 - val_accuracy: 0.8202 Epoch 188/200 125/125 [==============================] - 13s 100ms/step - loss: 0.0552 - accuracy: 0.9844 - val_loss: 1.1135 - val_accuracy: 0.8193 Epoch 189/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0551 - accuracy: 0.9844 - val_loss: 1.1150 - val_accuracy: 0.8194 Epoch 190/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0538 - accuracy: 0.9847 - val_loss: 1.1186 - val_accuracy: 0.8200 Epoch 191/200 125/125 [==============================] - 13s 103ms/step - loss: 0.0534 - accuracy: 0.9849 - val_loss: 1.1152 - val_accuracy: 0.8191 Epoch 192/200 125/125 [==============================] - 12s 97ms/step - loss: 0.0530 - accuracy: 0.9850 - val_loss: 1.1191 - val_accuracy: 0.8188 Epoch 193/200 125/125 [==============================] - 12s 98ms/step - loss: 0.0530 - accuracy: 0.9849 - val_loss: 1.1290 - val_accuracy: 0.8188 Epoch 194/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0525 - accuracy: 0.9850 - val_loss: 1.1289 - val_accuracy: 0.8190 Epoch 195/200 125/125 [==============================] - 12s 97ms/step - loss: 0.0511 - accuracy: 0.9856 - val_loss: 1.1290 - val_accuracy: 0.8192 Epoch 196/200 125/125 [==============================] - 12s 98ms/step - loss: 0.0507 - accuracy: 0.9856 - val_loss: 1.1378 - val_accuracy: 0.8190 Epoch 197/200 125/125 [==============================] - 13s 101ms/step - loss: 0.0505 - accuracy: 0.9856 - val_loss: 1.1372 - val_accuracy: 0.8192 Epoch 198/200 125/125 [==============================] - 12s 97ms/step - loss: 0.0498 - accuracy: 0.9858 - val_loss: 1.1386 - val_accuracy: 0.8188 Epoch 199/200 125/125 
[==============================] - 12s 98ms/step - loss: 0.0484 - accuracy: 0.9864 - val_loss: 1.1423 - val_accuracy: 0.8189 Epoch 200/200 125/125 [==============================] - 12s 100ms/step - loss: 0.0484 - accuracy: 0.9864 - val_loss: 1.1581 - val_accuracy: 0.8176
Run inference
# Build standalone encoder/decoder models for inference by reusing the
# trained layers of `model` (the training-time seq2seq model defined in an
# earlier cell). NOTE(review): the layer indices below assume the exact
# layer order of that model — confirm if the architecture changes.
encoder_inputs = model.input[0]  # input_1: one-hot encoder input sequence
# The encoder LSTM's final hidden/cell states summarize the input sentence.
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
encoder_states = [state_h_enc, state_c_enc]
# Inference encoder: input sequence -> [h, c] state vectors.
encoder_model = keras.Model(encoder_inputs, encoder_states)
decoder_inputs = model.input[1]  # input_2: one-hot decoder input (one step)
# At inference time the decoder runs one step at a time, so its states are
# fed in explicitly instead of flowing through the training graph.
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
# Returned states are fed back in on the next decoding step.
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
# Inference decoder: ([prev char, h, c]) -> ([char probabilities, h, c]).
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)
# Invert the char -> index vocabularies so that sampled token indices can
# be mapped back to readable characters during decoding.
reverse_input_char_index = {idx: ch for ch, idx in input_token_index.items()}
reverse_target_char_index = {idx: ch for ch, idx in target_token_index.items()}
def decode_sequence(input_seq):
    """Greedily decode one encoded input sequence into a target string.

    Encodes `input_seq` once to obtain the initial decoder states, then
    feeds the decoder a single one-hot character per step, sampling the
    most probable next character, until the end marker "\n" is produced
    or the output grows past `max_decoder_seq_length`.
    """
    # Initial decoder states come from encoding the input sentence.
    states_value = encoder_model.predict(input_seq, verbose=0)

    # First decoder input: one-hot "start sequence" character ("\t").
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    decoded_sentence = ""
    while True:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value, verbose=0
        )

        # Greedy sampling: take the argmax over the output distribution.
        sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Stop on the end-of-sequence marker or a too-long output.
        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            break

        # Next input is the character just emitted; carry the states over.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0
        states_value = [h, c]

    return decoded_sentence
# Sanity-check the inference models by decoding the first 20 training
# sequences and printing them next to their source sentences.
for seq_index in range(20):
    sample = encoder_input_data[seq_index : seq_index + 1]
    translation = decode_sequence(sample)
    print("-")
    print("Input sentence:", input_texts[seq_index])
    print("Decoded sentence:", translation)
- Input sentence: Go. Decoded sentence: Idź. - Input sentence: Hi. Decoded sentence: Cześć. - Input sentence: Run! Decoded sentence: Uciekaj! - Input sentence: Run. Decoded sentence: Uciekaj. - Input sentence: Run. Decoded sentence: Uciekaj. - Input sentence: Who? Decoded sentence: Kto? - Input sentence: Wow! Decoded sentence: O, dziama? - Input sentence: Wow! Decoded sentence: O, dziama? - Input sentence: Duck! Decoded sentence: Unik! - Input sentence: Fire! Decoded sentence: Strzelaj! - Input sentence: Fire! Decoded sentence: Strzelaj! - Input sentence: Fire! Decoded sentence: Strzelaj! - Input sentence: Help! Decoded sentence: Pomocy! - Input sentence: Hide. Decoded sentence: Schowaj się. - Input sentence: Jump! Decoded sentence: Skacz! - Input sentence: Jump. Decoded sentence: Skok. - Input sentence: Stay. Decoded sentence: Zostań. - Input sentence: Stop! Decoded sentence: Zatrzymaj się! - Input sentence: Stop! Decoded sentence: Zatrzymaj się! - Input sentence: Wait! Decoded sentence: Zaczekajcie!
BLEU evaluation
import nltk.translate.bleu_score as bleu

# Character-level BLEU over the full dataset. Smoothing method 4 keeps
# very short sentences from scoring exactly zero.
smoothing_function = bleu.SmoothingFunction().method4

num_sequences = len(encoder_input_data)
# Print a running average roughly 10 times during the run; max(1, ...)
# avoids a ZeroDivisionError when there are fewer than 10 sequences.
report_every = max(1, num_sequences // 10)

bleu_scores = 0
for seq_index in range(num_sequences):
    input_seq = encoder_input_data[seq_index : seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    # Compare character lists (this is a character-level model). Strip the
    # reference as well: target_texts entries carry the "\t" start and "\n"
    # end markers, which the stripped candidate can never match.
    reference = [list(target_texts[seq_index].strip())]
    candidate = list(decoded_sentence.strip())
    bleu_score = bleu.sentence_bleu(
        reference, candidate, smoothing_function=smoothing_function
    )
    bleu_scores += bleu_score
    if seq_index % report_every == 0 and seq_index != 0:
        print(bleu_scores / seq_index)

# Guard the final average against an empty dataset.
bleu_final = bleu_scores / num_sequences if num_sequences else 0.0
print("BLEU score:", bleu_final)
0.5965545320892247 0.5981615012338071 0.5930356751147987 0.5942200927187505 0.5951351487924014 0.5919465093543784 0.590832790904886 0.5919919489002385 0.5452219773056319 BLEU score: 0.5095920718253371