DL_SEQ2SEQ/seq2seq-keras-falied-attempt.ipynb

Imports

import numpy as np
import keras
import os
import tensorflow as tf
from keras.layers import LSTM, Dense, Input
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

Check GPU

physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))
Num GPUs Available:  1
print(tf.config.list_physical_devices('GPU'))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
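
By default TensorFlow reserves the whole GPU up front; if that is a problem next to other processes, memory growth can be enabled right after listing the devices (a sketch using the standard API, which must run before the GPU is first used):

if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)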

Define the model parameters

batch_size = 16
epochs = 50
latent_dim = 256
num_samples = 40000
data_path = "pol-eng/pol.txt"
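
The file appears to follow the Tatoeba/Anki export layout: each line carries an English sentence, its Polish translation, and an attribution field, separated by tabs (the parsing below assumes exactly three fields). A quick peek at the raw data, as a sketch:

with open(data_path, encoding="utf-8") as f:
    for _ in range(3):
        print(f.readline().rstrip("\n"))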

Read the data

input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split("\t")  # the third field is attribution

    # "\t" marks the start of a target sequence and "\n" marks its end.
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    input_characters.update(input_text)
    target_characters.update(target_text)

input_characters = sorted(input_characters)
target_characters = sorted(target_characters)
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print("Number of samples:", len(input_texts))
print("Number of unique input tokens:", num_encoder_tokens)
print("Number of unique output tokens:", num_decoder_tokens)
print("Max sequence length for inputs:", max_encoder_seq_length)
print("Max sequence length for outputs:", max_decoder_seq_length)
Number of samples: 40000
Number of unique input tokens: 76
Number of unique output tokens: 99
Max sequence length for inputs: 39
Max sequence length for outputs: 68

Define the input and target data

input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}

# One-hot tensors of shape (num_samples, max_seq_length, num_tokens).
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype="float32",
)
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.0
    # Pad the rest of the encoder sequence with spaces.
    encoder_input_data[i, t + 1 :, input_token_index[" "]] = 1.0
    for t, char in enumerate(target_text):
        decoder_input_data[i, t, target_token_index[char]] = 1.0
        if t > 0:
            # Teacher forcing: the target is the decoder input shifted left by one step.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
    # Pad both decoder tensors with spaces past the end of the sentence.
    decoder_input_data[i, t + 1 :, target_token_index[" "]] = 1.0
    decoder_target_data[i, t:, target_token_index[" "]] = 1.0
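
As a sanity check, the one-hot tensors can be decoded back to text to confirm the one-step offset between decoder input and target. A minimal sketch (onehot_to_text and the reverse lookup dicts are introduced here only for illustration):

def onehot_to_text(rows, reverse_index):
    # argmax over the token axis recovers the character index at each timestep
    return "".join(reverse_index[int(i)] for i in rows.argmax(axis=-1))

reverse_input = {i: char for char, i in input_token_index.items()}
reverse_target = {i: char for char, i in target_token_index.items()}

print(repr(onehot_to_text(encoder_input_data[0], reverse_input)))
print(repr(onehot_to_text(decoder_input_data[0], reverse_target)))   # begins with "\t"
print(repr(onehot_to_text(decoder_target_data[0], reverse_target)))  # same text shifted left by one step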

Define the model

# The encoder reads the input sequence; only its final hidden and cell states
# are kept and handed to the decoder as its initial state.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

encoder_states = [state_h, state_c]

# The decoder is trained with teacher forcing: it receives the target sequence
# as input and learns to predict the same sequence offset by one timestep.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)

Train the model

model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
)

model.save("s2s_model.keras")
Epoch 1/50
2000/2000 [==============================] - 46s 19ms/step - loss: 0.7914 - accuracy: 0.7804 - val_loss: 0.9906 - val_accuracy: 0.7111
Epoch 2/50
2000/2000 [==============================] - 34s 17ms/step - loss: 0.5680 - accuracy: 0.8332 - val_loss: 0.8399 - val_accuracy: 0.7529
Epoch 3/50
2000/2000 [==============================] - 35s 17ms/step - loss: 0.4870 - accuracy: 0.8568 - val_loss: 0.7561 - val_accuracy: 0.7769
Epoch 4/50
2000/2000 [==============================] - 35s 17ms/step - loss: 0.4405 - accuracy: 0.8702 - val_loss: 0.7148 - val_accuracy: 0.7894
Epoch 5/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.4091 - accuracy: 0.8794 - val_loss: 0.6904 - val_accuracy: 0.7963
Epoch 6/50
2000/2000 [==============================] - 35s 18ms/step - loss: 0.3858 - accuracy: 0.8860 - val_loss: 0.6719 - val_accuracy: 0.8025
Epoch 7/50
2000/2000 [==============================] - 35s 18ms/step - loss: 0.3671 - accuracy: 0.8912 - val_loss: 0.6604 - val_accuracy: 0.8064
Epoch 8/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.3519 - accuracy: 0.8957 - val_loss: 0.6514 - val_accuracy: 0.8100
Epoch 9/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.3390 - accuracy: 0.8994 - val_loss: 0.6477 - val_accuracy: 0.8112
Epoch 10/50
2000/2000 [==============================] - 35s 17ms/step - loss: 0.3282 - accuracy: 0.9025 - val_loss: 0.6460 - val_accuracy: 0.8125
Epoch 11/50
2000/2000 [==============================] - 35s 17ms/step - loss: 0.3180 - accuracy: 0.9052 - val_loss: 0.6436 - val_accuracy: 0.8137
Epoch 12/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.3092 - accuracy: 0.9079 - val_loss: 0.6467 - val_accuracy: 0.8140
Epoch 13/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.3013 - accuracy: 0.9100 - val_loss: 0.6468 - val_accuracy: 0.8143
Epoch 14/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2941 - accuracy: 0.9124 - val_loss: 0.6478 - val_accuracy: 0.8149
Epoch 15/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2876 - accuracy: 0.9140 - val_loss: 0.6510 - val_accuracy: 0.8154
Epoch 16/50
2000/2000 [==============================] - 38s 19ms/step - loss: 0.2814 - accuracy: 0.9159 - val_loss: 0.6575 - val_accuracy: 0.8156
Epoch 17/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2759 - accuracy: 0.9175 - val_loss: 0.6627 - val_accuracy: 0.8146
Epoch 18/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.2702 - accuracy: 0.9190 - val_loss: 0.6649 - val_accuracy: 0.8149
Epoch 19/50
2000/2000 [==============================] - 39s 19ms/step - loss: 0.2653 - accuracy: 0.9204 - val_loss: 0.6731 - val_accuracy: 0.8143
Epoch 20/50
2000/2000 [==============================] - 37s 19ms/step - loss: 0.2608 - accuracy: 0.9217 - val_loss: 0.6772 - val_accuracy: 0.8135
Epoch 21/50
2000/2000 [==============================] - 38s 19ms/step - loss: 0.2562 - accuracy: 0.9230 - val_loss: 0.6812 - val_accuracy: 0.8139
Epoch 22/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2524 - accuracy: 0.9242 - val_loss: 0.6815 - val_accuracy: 0.8143
Epoch 23/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2481 - accuracy: 0.9253 - val_loss: 0.6875 - val_accuracy: 0.8135
Epoch 24/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2448 - accuracy: 0.9264 - val_loss: 0.6945 - val_accuracy: 0.8142
Epoch 25/50
2000/2000 [==============================] - 38s 19ms/step - loss: 0.2413 - accuracy: 0.9273 - val_loss: 0.6989 - val_accuracy: 0.8127
Epoch 26/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2379 - accuracy: 0.9282 - val_loss: 0.7044 - val_accuracy: 0.8129
Epoch 27/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2349 - accuracy: 0.9291 - val_loss: 0.7080 - val_accuracy: 0.8122
Epoch 28/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2319 - accuracy: 0.9298 - val_loss: 0.7118 - val_accuracy: 0.8127
Epoch 29/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2289 - accuracy: 0.9309 - val_loss: 0.7144 - val_accuracy: 0.8124
Epoch 30/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.2263 - accuracy: 0.9315 - val_loss: 0.7215 - val_accuracy: 0.8119
Epoch 31/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2236 - accuracy: 0.9322 - val_loss: 0.7254 - val_accuracy: 0.8123
Epoch 32/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2208 - accuracy: 0.9331 - val_loss: 0.7319 - val_accuracy: 0.8111
Epoch 33/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2188 - accuracy: 0.9334 - val_loss: 0.7390 - val_accuracy: 0.8099
Epoch 34/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2162 - accuracy: 0.9341 - val_loss: 0.7406 - val_accuracy: 0.8107
Epoch 35/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2142 - accuracy: 0.9348 - val_loss: 0.7477 - val_accuracy: 0.8101
Epoch 36/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2122 - accuracy: 0.9356 - val_loss: 0.7504 - val_accuracy: 0.8099
Epoch 37/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.2101 - accuracy: 0.9360 - val_loss: 0.7528 - val_accuracy: 0.8098
Epoch 38/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2079 - accuracy: 0.9366 - val_loss: 0.7568 - val_accuracy: 0.8103
Epoch 39/50
2000/2000 [==============================] - 37s 19ms/step - loss: 0.2063 - accuracy: 0.9369 - val_loss: 0.7615 - val_accuracy: 0.8090
Epoch 40/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2045 - accuracy: 0.9376 - val_loss: 0.7615 - val_accuracy: 0.8101
Epoch 41/50
2000/2000 [==============================] - 37s 19ms/step - loss: 0.2027 - accuracy: 0.9379 - val_loss: 0.7684 - val_accuracy: 0.8089
Epoch 42/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.2010 - accuracy: 0.9385 - val_loss: 0.7784 - val_accuracy: 0.8080
Epoch 43/50
2000/2000 [==============================] - 39s 20ms/step - loss: 0.1993 - accuracy: 0.9389 - val_loss: 0.7801 - val_accuracy: 0.8083
Epoch 44/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.1978 - accuracy: 0.9393 - val_loss: 0.7780 - val_accuracy: 0.8084
Epoch 45/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.1961 - accuracy: 0.9397 - val_loss: 0.7928 - val_accuracy: 0.8078
Epoch 46/50
2000/2000 [==============================] - 42s 21ms/step - loss: 0.1944 - accuracy: 0.9404 - val_loss: 0.7901 - val_accuracy: 0.8083
Epoch 47/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.1930 - accuracy: 0.9408 - val_loss: 0.7916 - val_accuracy: 0.8080
Epoch 48/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.1918 - accuracy: 0.9410 - val_loss: 0.7999 - val_accuracy: 0.8072
Epoch 49/50
2000/2000 [==============================] - 37s 18ms/step - loss: 0.1902 - accuracy: 0.9413 - val_loss: 0.8061 - val_accuracy: 0.8070
Epoch 50/50
2000/2000 [==============================] - 36s 18ms/step - loss: 0.1894 - accuracy: 0.9416 - val_loss: 0.8051 - val_accuracy: 0.8069
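
The log shows val_loss bottoming out around epoch 11 and climbing afterwards while the training loss keeps falling, so the model overfits for most of the 50 epochs. A minimal sketch of cutting the run short with an EarlyStopping callback (restore_best_weights keeps the weights from the best validation epoch):

early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
    callbacks=[early_stop],
)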

Load model

model = keras.models.load_model("s2s_model.keras")

Encoder model

# Rebuild a standalone encoder from the trained layers: inputs -> final LSTM states.
encoder_inputs = model.input[0]
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(encoder_inputs, encoder_states)

Decoder model

# Rebuild the decoder for single-step inference: it takes the previous character
# plus the current LSTM states, and returns the next-character distribution
# together with the updated states.
decoder_inputs = model.input[1]
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)
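
A quick shape check (sketch) confirms that the inference encoder now returns just the two state vectors:

h, c = encoder_model.predict(encoder_input_data[:1], verbose=0)
print(h.shape, c.shape)  # expected: (1, 256) (1, 256) with latent_dim = 256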

Define the decoding function

reverse_input_char_index = {i: char for char, i in input_token_index.items()}
reverse_target_char_index = {i: char for char, i in target_token_index.items()}


def decode_sequence(input_seq):
    # Encode the input sentence into the decoder's initial LSTM states.
    states_value = encoder_model.predict(input_seq, verbose=0)

    # Seed the decoder with the start-of-sequence token "\t".
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    # Greedy decoding: sample one character at a time, feeding each prediction
    # and the updated states back in, until "\n" or the length limit.
    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value, verbose=0
        )

        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        states_value = [h, c]
    return decoded_sentence
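
Each generated character costs a full Model.predict call here, and predict rebuilds its data-handling pipeline on every call (the interrupted traceback below bottoms out in exactly that machinery); at batch size 1 this overhead dominates. A hedged sketch of the usual workaround, calling the models directly instead of through predict (decode_sequence_fast is a name introduced here):

def decode_sequence_fast(input_seq):
    # Direct __call__ skips predict's per-call data-handler setup.
    h, c = encoder_model(input_seq, training=False)
    states_value = [h, c]
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0
    decoded_sentence = ""
    while True:
        output_tokens, h, c = decoder_model(
            [target_seq] + states_value, training=False
        )
        sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
        decoded_sentence += reverse_target_char_index[sampled_token_index]
        if decoded_sentence.endswith("\n") or len(decoded_sentence) > max_decoder_seq_length:
            return decoded_sentence
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0
        states_value = [h, c]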

Sample predictions

for seq_index in range(20):
    input_seq = encoder_input_data[seq_index : seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print("-")
    print("Input sentence:", input_texts[seq_index])
    print("Decoded sentence:", decoded_sentence)
-
Input sentence: Go.
Decoded sentence: Idź się.

-
Input sentence: Hi.
Decoded sentence: Cześć.

-
Input sentence: Run!
Decoded sentence: Uciekaj!

-
Input sentence: Run.
Decoded sentence: Biegnij.

-
Input sentence: Run.
Decoded sentence: Biegnij.

-
Input sentence: Who?
Decoded sentence: Kto?

-
Input sentence: Wow!
Decoded sentence: Jak sprawdzi!

-
Input sentence: Wow!
Decoded sentence: Jak sprawdzi!

-
Input sentence: Duck!
Decoded sentence: Unik!

-
Input sentence: Fire!
Decoded sentence: Staktowa!

-
Input sentence: Fire!
Decoded sentence: Staktowa!

-
Input sentence: Fire!
Decoded sentence: Staktowa!

-
Input sentence: Help!
Decoded sentence: Pomocy!

-
Input sentence: Hide.
Decoded sentence: Utekty to.

-
Input sentence: Jump!
Decoded sentence: Samujesz!

-
Input sentence: Jump.
Decoded sentence: Skok.

-
Input sentence: Stay.
Decoded sentence: Zostań.

-
Input sentence: Stop!
Decoded sentence: Zaczej się!

-
Input sentence: Stop!
Decoded sentence: Zaczej się!

-
Input sentence: Wait!
Decoded sentence: Czekajcie!

Compute BLEU score

def preprocess_text(text):
    return text.strip().split()

def compute_bleu_score(target_texts, range_limit=1000):
    candidate_corpus = []
    references_corpus = []

    for seq_index in tqdm(range(range_limit), desc="Calculating BLEU scores"):
        input_seq = encoder_input_data[seq_index : seq_index + 1]
        decoded_sentence = decode_sequence(input_seq)
        candidate_corpus.append(preprocess_text(decoded_sentence))
        references_corpus.append([preprocess_text(target_texts[seq_index])])

    # Smoothing avoids zero scores on short sentences that have no
    # higher-order n-gram overlap; per-sentence scores are then averaged.
    smoothie = SmoothingFunction().method4
    bleu_scores = [
        sentence_bleu(ref, cand, smoothing_function=smoothie)
        for ref, cand in zip(references_corpus, candidate_corpus)
    ]
    average_bleu = sum(bleu_scores) / len(bleu_scores)

    print("BLEU score:", average_bleu)
compute_bleu_score(target_texts, 100)
Calculating BLEU scores: 100%|██████████| 100/100 [01:12<00:00,  1.39it/s]
BLEU score: 0.2236331802292942
compute_bleu_score(target_texts, 1000)
Calculating BLEU scores: 100%|██████████| 1000/1000 [14:39<00:00,  1.14it/s]
BLEU score: 0.152040789734918
compute_bleu_score(target_texts, 10000)
Calculating BLEU scores:   8%|▊         | 761/10000 [13:31<2:44:18,  1.07s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[19], line 1
----> 1 compute_bleu_score(target_texts, 10000)

Cell In[15], line 13, in compute_bleu_score(target_texts, range_limit)
     11 for seq_index in tqdm(range(range_limit), desc="Calculating BLEU scores"):
     12     input_seq = encoder_input_data[seq_index : seq_index + 1]
---> 13     decoded_sentence = decode_sequence(input_seq)
     14     candidate_corpus.append(preprocess_text(decoded_sentence))
     15     references_corpus.append([preprocess_text(target_texts[seq_index])])

Cell In[14], line 6, in decode_sequence(input_seq)
      5 def decode_sequence(input_seq):
----> 6     states_value = encoder_model.predict(input_seq, verbose=0)
      8     target_seq = np.zeros((1, 1, num_decoder_tokens))
      9     target_seq[0, 0, target_token_index["\t"]] = 1.0

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\keras\utils\traceback_utils.py:65, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     63 filtered_tb = None
     64 try:
---> 65     return fn(*args, **kwargs)
     66 except Exception as e:
     67     filtered_tb = _process_traceback_frames(e.__traceback__)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\keras\engine\training.py:2220, in Model.predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   2211     except ValueError:
   2212         warnings.warn(
   2213             "Using Model.predict with MultiWorkerMirroredStrategy "
   2214             "or TPUStrategy and AutoShardPolicy.FILE might lead to "
   (...)
   2217             stacklevel=2,
   2218         )
-> 2220 data_handler = data_adapter.get_data_handler(
   2221     x=x,
   2222     batch_size=batch_size,
   2223     steps_per_epoch=steps,
   2224     initial_epoch=0,
   2225     epochs=1,
   2226     max_queue_size=max_queue_size,
   2227     workers=workers,
   2228     use_multiprocessing=use_multiprocessing,
   2229     model=self,
   2230     steps_per_execution=self._steps_per_execution,
   2231 )
   2233 # Container that configures and calls `tf.keras.Callback`s.
   2234 if not isinstance(callbacks, callbacks_module.CallbackList):

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\keras\engine\data_adapter.py:1582, in get_data_handler(*args, **kwargs)
   1580 if getattr(kwargs["model"], "_cluster_coordinator", None):
   1581     return _ClusterCoordinatorDataHandler(*args, **kwargs)
-> 1582 return DataHandler(*args, **kwargs)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\keras\engine\data_adapter.py:1262, in DataHandler.__init__(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution, distribute)
   1259     self._steps_per_execution = steps_per_execution
   1261 adapter_cls = select_data_adapter(x, y)
-> 1262 self._adapter = adapter_cls(
   1263     x,
   1264     y,
   1265     batch_size=batch_size,
   1266     steps=steps_per_epoch,
   1267     epochs=epochs - initial_epoch,
   1268     sample_weights=sample_weight,
   1269     shuffle=shuffle,
   1270     max_queue_size=max_queue_size,
   1271     workers=workers,
   1272     use_multiprocessing=use_multiprocessing,
   1273     distribution_strategy=tf.distribute.get_strategy(),
   1274     model=model,
   1275 )
   1277 strategy = tf.distribute.get_strategy()
   1279 self._current_step = 0

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\keras\engine\data_adapter.py:347, in TensorLikeDataAdapter.__init__(self, x, y, sample_weights, sample_weight_modes, batch_size, epochs, steps, shuffle, **kwargs)
    344         flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
    345     return flat_dataset
--> 347 indices_dataset = indices_dataset.flat_map(slice_batch_indices)
    349 dataset = self.slice_inputs(indices_dataset, inputs)
    351 if shuffle == "batch":

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py:2245, in DatasetV2.flat_map(self, map_func, name)
   2212 def flat_map(self, map_func, name=None):
   2213   """Maps `map_func` across this dataset and flattens the result.
   2214 
   2215   The type signature is:
   (...)
   2243     Dataset: A `Dataset`.
   2244   """
-> 2245   return FlatMapDataset(self, map_func, name=name)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py:5484, in FlatMapDataset.__init__(self, input_dataset, map_func, name)
   5482 """See `Dataset.flat_map()` for details."""
   5483 self._input_dataset = input_dataset
-> 5484 self._map_func = structured_function.StructuredFunctionWrapper(
   5485     map_func, self._transformation_name(), dataset=input_dataset)
   5486 if not isinstance(self._map_func.output_structure, DatasetSpec):
   5487   raise TypeError(
   5488       "The `map_func` argument must return a `Dataset` object. Got "
   5489       f"{_get_type(self._map_func.output_structure)!r}.")

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\data\ops\structured_function.py:271, in StructuredFunctionWrapper.__init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
    264       warnings.warn(
    265           "Even though the `tf.config.experimental_run_functions_eagerly` "
    266           "option is set, this option does not apply to tf.data functions. "
    267           "To force eager execution of tf.data functions, please use "
    268           "`tf.data.experimental.enable_debug_mode()`.")
    269     fn_factory = trace_tf_function(defun_kwargs)
--> 271 self._function = fn_factory()
    272 # There is no graph to add in eager mode.
    273 add_to_graph &= not context.executing_eagerly()

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\eager\function.py:2610, in Function.get_concrete_function(self, *args, **kwargs)
   2601 def get_concrete_function(self, *args, **kwargs):
   2602   """Returns a `ConcreteFunction` specialized to inputs and execution context.
   2603 
   2604   Args:
   (...)
   2608        or `tf.Tensor` or `tf.TensorSpec`.
   2609   """
-> 2610   graph_function = self._get_concrete_function_garbage_collected(
   2611       *args, **kwargs)
   2612   graph_function._garbage_collector.release()  # pylint: disable=protected-access
   2613   return graph_function

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\eager\function.py:2576, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   2574   args, kwargs = None, None
   2575 with self._lock:
-> 2576   graph_function, _ = self._maybe_define_function(args, kwargs)
   2577   seen_names = set()
   2578   captured = object_identity.ObjectIdentitySet(
   2579       graph_function.graph.internal_captures)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\eager\function.py:2760, in Function._maybe_define_function(self, args, kwargs)
   2758   # Only get placeholders for arguments, not captures
   2759   args, kwargs = placeholder_dict["args"]
-> 2760 graph_function = self._create_graph_function(args, kwargs)
   2762 graph_capture_container = graph_function.graph._capture_func_lib  # pylint: disable=protected-access
   2763 # Maintain the list of all captures

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\eager\function.py:2670, in Function._create_graph_function(self, args, kwargs)
   2665 missing_arg_names = [
   2666     "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
   2667 ]
   2668 arg_names = base_arg_names + missing_arg_names
   2669 graph_function = ConcreteFunction(
-> 2670     func_graph_module.func_graph_from_py_func(
   2671         self._name,
   2672         self._python_function,
   2673         args,
   2674         kwargs,
   2675         self.input_signature,
   2676         autograph=self._autograph,
   2677         autograph_options=self._autograph_options,
   2678         arg_names=arg_names,
   2679         capture_by_value=self._capture_by_value),
   2680     self._function_attributes,
   2681     spec=self.function_spec,
   2682     # Tell the ConcreteFunction to clean up its graph once it goes out of
   2683     # scope. This is not the default behavior since it gets used in some
   2684     # places (like Keras) where the FuncGraph lives longer than the
   2685     # ConcreteFunction.
   2686     shared_func_graph=False)
   2687 return graph_function

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\func_graph.py:1251, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, acd_record_initial_resource_uses)
   1247   func_outputs = python_func(*func_args, **func_kwargs)
   1249   # invariant: `func_outputs` contains only Tensors, CompositeTensors,
   1250   # TensorArrays and `None`s.
-> 1251   func_outputs = nest.map_structure(
   1252       convert, func_outputs, expand_composites=True)
   1254   check_func_mutation(func_args_before, func_kwargs_before, func_args,
   1255                       func_kwargs, original_func)
   1256 finally:

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\util\nest.py:917, in map_structure(func, *structure, **kwargs)
    913 flat_structure = (flatten(s, expand_composites) for s in structure)
    914 entries = zip(*flat_structure)
    916 return pack_sequence_as(
--> 917     structure[0], [func(*x) for x in entries],
    918     expand_composites=expand_composites)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\util\nest.py:917, in <listcomp>(.0)
    913 flat_structure = (flatten(s, expand_composites) for s in structure)
    914 entries = zip(*flat_structure)
    916 return pack_sequence_as(
--> 917     structure[0], [func(*x) for x in entries],
    918     expand_composites=expand_composites)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\func_graph.py:1210, in func_graph_from_py_func.<locals>.convert(x)
   1203     raise TypeError(
   1204         "To be compatible with tf.function, Python functions "
   1205         "must return zero or more Tensors or ExtensionTypes or None "
   1206         f"values; in compilation of {str(python_func)}, found return "
   1207         f"value of type {type(x).__name__}, which is not a Tensor or "
   1208         "ExtensionType.")
   1209 if add_control_dependencies:
-> 1210   x = deps_ctx.mark_as_return(x)
   1211 return x

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\auto_control_deps.py:249, in AutomaticControlDependencies.mark_as_return(self, tensor)
    244   return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    245 # We want to make the return values depend on the stateful operations, but
    246 # we don't want to introduce a cycle, so we make the return value the result
    247 # of a new identity operation that the stateful operations definitely don't
    248 # depend on.
--> 249 tensor = array_ops.identity(tensor)
    250 self._returned_tensors.add(tensor)
    251 return tensor

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\util\traceback_utils.py:150, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    148 filtered_tb = None
    149 try:
--> 150   return fn(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\util\dispatch.py:1176, in add_dispatch_support.<locals>.decorator.<locals>.op_dispatch_handler(*args, **kwargs)
   1174 # Fallback dispatch system (dispatch v1):
   1175 try:
-> 1176   return dispatch_target(*args, **kwargs)
   1177 except (TypeError, ValueError):
   1178   # Note: convert_to_eager_tensor currently raises a ValueError, not a
   1179   # TypeError, when given unexpected types.  So we need to catch both.
   1180   result = dispatch(op_dispatch_handler, args, kwargs)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\ops\array_ops.py:294, in identity(input, name)
    290 if context.executing_eagerly() and not hasattr(input, "graph"):
    291   # Make sure we get an input with handle data attached from resource
    292   # variables. Variables have correct handle data when graph building.
    293   input = ops.convert_to_tensor(input)
--> 294 ret = gen_array_ops.identity(input, name=name)
    295 # Propagate handle data for happier shape inference for resource variables.
    296 if hasattr(input, "_handle_data"):

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\ops\gen_array_ops.py:4081, in identity(input, name)
   4079     pass  # Add nodes to the TensorFlow graph.
   4080 # Add nodes to the TensorFlow graph.
-> 4081 _, _, _op, _outputs = _op_def_library._apply_op_helper(
   4082       "Identity", input=input, name=name)
   4083 _result = _outputs[:]
   4084 if _execute.must_record_gradient():

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\op_def_library.py:797, in _apply_op_helper(op_type_name, name, **keywords)
    792 must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
    793                         if arg.is_ref]
    794 with _MaybeColocateWith(must_colocate_inputs):
    795   # Add Op to graph
    796   # pylint: disable=protected-access
--> 797   op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    798                              name=scope, input_types=input_types,
    799                              attrs=attr_protos, op_def=op_def)
    801 # `outputs` is returned as a separate return value so that the output
    802 # tensors can the `op` per se can be decoupled so that the
    803 # `op_callbacks` can function properly. See framework/op_callbacks.py
    804 # for more details.
    805 outputs = op.outputs

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\func_graph.py:735, in FuncGraph._create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
    733   inp = self.capture(inp)
    734   captured_inputs.append(inp)
--> 735 return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    736     op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
    737     compute_device)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\ops.py:3800, in Graph._create_op_internal(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)
   3797 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
   3798 # Session.run call cannot occur between creating and mutating the op.
   3799 with self._mutation_lock():
-> 3800   ret = Operation(
   3801       node_def,
   3802       self,
   3803       inputs=inputs,
   3804       output_types=dtypes,
   3805       control_inputs=control_inputs,
   3806       input_types=input_types,
   3807       original_op=self._default_original_op,
   3808       op_def=op_def)
   3809   self._create_op_helper(ret, compute_device=compute_device)
   3810 return ret

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\ops.py:2108, in Operation.__init__(***failed resolving arguments***)
   2105     control_input_ops.append(control_op)
   2107 # Initialize c_op from node_def and other inputs
-> 2108 c_op = _create_c_op(g, node_def, inputs, control_input_ops, op_def=op_def)
   2109 self._init_from_c_op(c_op=c_op, g=g)
   2111 self._original_op = original_op

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\util\traceback_utils.py:150, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    148 filtered_tb = None
    149 try:
--> 150   return fn(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)

File ~\PycharmProjects\pythonProject\venv\lib\site-packages\tensorflow\python\framework\ops.py:1966, in _create_c_op(graph, node_def, inputs, control_inputs, op_def, extract_traceback)
   1962   pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
   1963                                          serialized)
   1965 try:
-> 1966   c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
   1967 except errors.InvalidArgumentError as e:
   1968   # Convert to ValueError for backwards compatibility.
   1969   raise ValueError(e.message)

KeyboardInterrupt: 

Given the runtimes of the individual steps (several hours to train the model, and over 8 hours of BLEU computation on the full set without finishing), I abandoned further attempts and switched to a PyTorch implementation, which ran dramatically faster (both on CUDA).