"""Bert model.
|
|
|
|
Adapted from https://keras.io/examples/nlp/masked_language_modeling/
|
|
"""

import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras.integration_test.models.input_spec import InputSpec

SEQUENCE_LENGTH = 16
VOCAB_SIZE = 1000
EMBED_DIM = 64
NUM_HEAD = 2
FF_DIM = 32
NUM_LAYERS = 2


def get_data_spec(batch_size):
    return (
        InputSpec((batch_size,), dtype="string"),
        InputSpec((batch_size, SEQUENCE_LENGTH, VOCAB_SIZE)),
    )
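

# Illustrative note: get_data_spec(8) describes a batch of 8 raw strings
# plus 8 one-hot targets of shape (SEQUENCE_LENGTH, VOCAB_SIZE), i.e. one
# VOCAB_SIZE-way distribution per token position.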


def get_input_preprocessor():
    # Map raw strings to fixed-length sequences of integer token ids.
    input_vectorizer = keras.layers.TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_mode="int",
        output_sequence_length=SEQUENCE_LENGTH,
    )
    text_ds = tf.data.Dataset.from_tensor_slices(
        [
            "Lorem ipsum dolor sit amet",
            "consectetur adipiscing elit",
            "sed do eiusmod tempor incididunt ut",
            "labore et dolore magna aliqua.",
            "Ut enim ad minim veniam",
            "quis nostrud exercitation ullamco",
            "laboris nisi ut aliquip ex ea commodo consequat.",
        ]
    )
    # Build the vocabulary from the sample corpus.
    input_vectorizer.adapt(text_ds)
    return input_vectorizer
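

# Usage sketch (illustrative, not part of the original file):
#
#   preprocessor = get_input_preprocessor()
#   token_ids = preprocessor(tf.constant(["Lorem ipsum dolor sit amet"]))
#   # token_ids: shape (1, SEQUENCE_LENGTH), dtype int64, zero-padded.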


def bert_module(query, key, value, i):
    # Multi-head self-attention with a residual connection and layer
    # norm; `i` is the encoder layer index (unused in this
    # implementation).
    attention_output = keras.layers.MultiHeadAttention(
        num_heads=NUM_HEAD,
        key_dim=EMBED_DIM // NUM_HEAD,
    )(query, key, value)
    attention_output = keras.layers.Dropout(0.1)(attention_output)
    attention_output = keras.layers.LayerNormalization(epsilon=1e-6)(
        query + attention_output
    )

    # Position-wise feed-forward network, again with residual + norm.
    ffn = keras.Sequential(
        [
            keras.layers.Dense(FF_DIM, activation="relu"),
            keras.layers.Dense(EMBED_DIM),
        ],
    )
    ffn_output = ffn(attention_output)
    ffn_output = keras.layers.Dropout(0.1)(ffn_output)
    sequence_output = keras.layers.LayerNormalization(epsilon=1e-6)(
        attention_output + ffn_output
    )
    return sequence_output
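

# Shape sketch (illustrative): bert_module maps (batch, SEQUENCE_LENGTH,
# EMBED_DIM) to the same shape, which is what lets get_model() below
# stack it NUM_LAYERS times:
#
#   x = tf.random.uniform((2, SEQUENCE_LENGTH, EMBED_DIM))
#   y = bert_module(x, x, x, 0)
#   # y.shape == (2, SEQUENCE_LENGTH, EMBED_DIM)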


def get_pos_encoding_matrix(max_len, d_emb):
    # Sinusoidal position encodings from "Attention Is All You Need":
    #   PE[pos, 2j]     = sin(pos / 10000**(2j / d_emb))
    #   PE[pos, 2j + 1] = cos(pos / 10000**(2j / d_emb))
    # Position 0 is kept as all zeros.
    pos_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
            if pos != 0
            else np.zeros(d_emb)
            for pos in range(max_len)
        ]
    )
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # even dims: sine
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # odd dims: cosine
    return pos_enc
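

# Worked example (illustrative): with d_emb = 4, the row for pos = 1 is
# [sin(1), cos(1), sin(1 / 10000**0.5), cos(1 / 10000**0.5)]; row 0 is
# all zeros by construction.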


# Loss and metric are module-level so the custom train_step below can
# use them directly.
loss_fn = keras.losses.CategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")


class MaskedLanguageModel(keras.Model):
    def train_step(self, inputs):
        if len(inputs) == 3:
            features, labels, sample_weight = inputs
        else:
            features, labels = inputs
            sample_weight = None

        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = loss_fn(labels, predictions, sample_weight=sample_weight)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it between epochs.
        return [loss_tracker]
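

# Note (illustrative): fit(x, y) reaches train_step as a 2-tuple and
# fit(x, y, sample_weight=w) as a 3-tuple, which is why train_step
# branches on len(inputs).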


def get_model(
    build=False, compile=False, jit_compile=False, include_preprocessing=True
):
    if include_preprocessing:
        inputs = keras.layers.Input((), dtype="string")
        x = get_input_preprocessor()(inputs)
    else:
        inputs = keras.layers.Input((SEQUENCE_LENGTH,), dtype=tf.int64)
        x = inputs
    # Token embeddings plus fixed sinusoidal position embeddings.
    word_embeddings = keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(x)
    position_embeddings = keras.layers.Embedding(
        input_dim=SEQUENCE_LENGTH,
        output_dim=EMBED_DIM,
        weights=[get_pos_encoding_matrix(SEQUENCE_LENGTH, EMBED_DIM)],
        trainable=False,
    )(tf.range(start=0, limit=SEQUENCE_LENGTH, delta=1))
    embeddings = word_embeddings + position_embeddings

    # Stack of transformer encoder blocks.
    encoder_output = embeddings
    for i in range(NUM_LAYERS):
        encoder_output = bert_module(
            encoder_output, encoder_output, encoder_output, i
        )

    # Per-token softmax over the vocabulary for masked-token prediction.
    mlm_output = keras.layers.Dense(
        VOCAB_SIZE, name="mlm_cls", activation="softmax"
    )(encoder_output)
    model = MaskedLanguageModel(inputs, mlm_output)

    if compile:
        optimizer = keras.optimizers.Adam()
        # No loss argument: MaskedLanguageModel.train_step computes the
        # loss itself.
        model.compile(optimizer=optimizer, jit_compile=jit_compile)
    return model


def get_custom_objects():
    # Needed to reload saved models that use the custom class.
    return {
        "MaskedLanguageModel": MaskedLanguageModel,
    }
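

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original module): build
    # the model, fabricate a tiny random batch, and run one epoch of the
    # custom train step. Shapes follow get_data_spec(); the data itself
    # is random and purely illustrative.
    model = get_model(compile=True)
    batch_size = 4
    texts = tf.constant(["Lorem ipsum dolor sit amet"] * batch_size)
    labels = tf.one_hot(
        tf.random.uniform(
            (batch_size, SEQUENCE_LENGTH), maxval=VOCAB_SIZE, dtype=tf.int32
        ),
        depth=VOCAB_SIZE,
    )
    model.fit(texts, labels, batch_size=batch_size, epochs=1)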