2021-05-06 22:50:47 +02:00
|
|
|
import re
|
|
|
|
import string
|
|
|
|
from silence_tensorflow import silence_tensorflow
|
|
|
|
silence_tensorflow()
|
2021-04-25 22:14:32 +02:00
|
|
|
import tensorflow as tf
|
2021-05-06 22:50:47 +02:00
|
|
|
from tensorflow.keras import layers
|
|
|
|
from tensorflow.keras import losses
|
|
|
|
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
|
|
|
|
|
|
|
|
|
|
|
|
def vectorize_text(text, label):
    """Turn one raw ``(text, label)`` dataset element into ``(token_ids, label)``.

    A trailing axis of size 1 is appended to ``text`` before it is passed to
    the module-level ``vectorize_layer``; ``label`` is returned unchanged.
    """
    # NOTE(review): the extra axis presumably matches the input rank the
    # TextVectorization layer expects — confirm against the TF version in use.
    text_column = tf.expand_dims(text, axis=-1)
    token_ids = vectorize_layer(text_column)
    return token_ids, label
|
|
|
|
|
|
|
|
|
|
|
|
def custom_standardization(input_data):
    """Normalise raw review strings before tokenisation.

    Steps, in order:
      1. lower-case the text;
      2. replace each literal ``<br />`` tag with a single space;
      3. delete every character listed in ``string.punctuation``.

    Returns a string tensor of the same shape as ``input_data``.
    """
    # Character class matching any one punctuation character, escaped so that
    # regex metacharacters such as ``]`` or ``\`` are taken literally.
    punctuation_class = '[%s]' % re.escape(string.punctuation)

    normalised = tf.strings.lower(input_data)
    normalised = tf.strings.regex_replace(normalised, '<br />', ' ')
    return tf.strings.regex_replace(normalised, punctuation_class, '')
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = 32
seed = 42

# The training and validation subsets both come from aclImdb/train via
# validation_split. Both calls must use the same split fraction and seed so
# that they partition the same shuffled file list.
_split_kwargs = dict(batch_size=batch_size, validation_split=0.2, seed=seed)

raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', subset='training', **_split_kwargs
)
raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train', subset='validation', **_split_kwargs
)

# The held-out test set is a separate directory, so no split is applied.
raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/test', batch_size=batch_size
)
|
|
|
|
|
|
|
|
max_features = 10000
sequence_length = 250

# Text -> integer token ids. The vocabulary is capped at max_features tokens,
# and every example is emitted as exactly sequence_length ids.
vectorize_layer = TextVectorization(
    max_tokens=max_features,
    standardize=custom_standardization,
    output_mode='int',
    output_sequence_length=sequence_length,
)

# Learn the vocabulary from the training texts only; labels are dropped here.
train_text = raw_train_ds.map(lambda review, _label: review)
vectorize_layer.adapt(train_text)

# Apply the same vectorisation to all three splits.
train_ds, val_ds, test_ds = (
    split.map(vectorize_text) for split in (raw_train_ds, raw_val_ds, raw_test_ds)
)
|
2021-04-25 22:14:32 +02:00
|
|
|
|
2021-05-06 22:50:47 +02:00
|
|
|
AUTOTUNE = tf.data.AUTOTUNE

# Cache each split's vectorised batches after the first pass, and let tf.data
# choose the prefetch depth so input preparation overlaps with the model step.
train_ds, val_ds, test_ds = [
    split.cache().prefetch(buffer_size=AUTOTUNE)
    for split in (train_ds, val_ds, test_ds)
]
|
2021-04-25 22:14:32 +02:00
|
|
|
|
2021-05-06 22:50:47 +02:00
|
|
|
embedding_dim = 16

# Bag-of-embeddings classifier. The final Dense(1) has no activation, so it
# outputs a raw logit per review.
model = tf.keras.Sequential()
# NOTE(review): the embedding table has max_features + 1 rows, one more than
# the vectorizer's max_tokens; confirm whether the spare row is intentional.
model.add(layers.Embedding(max_features + 1, embedding_dim))
model.add(layers.Dropout(0.2))
model.add(layers.GlobalAveragePooling1D())
model.add(layers.Dropout(0.2))
model.add(layers.Dense(1))

model.summary()

# The loss works on logits (from_logits=True). A logit of 0.0 corresponds to
# probability 0.5, hence the 0.0 accuracy threshold.
model.compile(
    optimizer='adam',
    loss=losses.BinaryCrossentropy(from_logits=True),
    metrics=tf.metrics.BinaryAccuracy(threshold=0.0),
)
|
2021-04-25 22:14:32 +02:00
|
|
|
|
2021-05-06 22:50:47 +02:00
|
|
|
epochs = 10

# Train on the vectorised training split and report validation metrics each
# epoch. The returned History object is kept in `history`.
history = model.fit(x=train_ds, validation_data=val_ds, epochs=epochs)

# Score the trained logit model on the vectorised test split.
loss, accuracy = model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
|
2021-04-25 22:14:32 +02:00
|
|
|
|
2021-05-06 22:50:47 +02:00
|
|
|
# End-to-end model: raw strings are vectorised inside the model, and the
# logit is squashed to a probability with a sigmoid.
export_layers = [vectorize_layer, model, layers.Activation('sigmoid')]
export_model = tf.keras.Sequential(export_layers)

# The outputs are already probabilities, so the loss is configured with
# from_logits=False.
export_model.compile(
    optimizer="adam",
    loss=losses.BinaryCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Evaluate directly on the un-vectorised test set. This rebinds loss/accuracy.
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
|
2021-04-25 22:14:32 +02:00
|
|
|
|
2021-05-06 22:50:47 +02:00
|
|
|
# Write the most recent test metrics (`loss`, `accuracy` from the preceding
# evaluate call) to results.txt.
# `loss` and `accuracy` are numeric Keras metrics, not strings. The previous
# `'test loss: ' + loss` therefore raised TypeError, so they are formatted
# with an f-string instead. The `with` block guarantees the file is closed
# even if the write fails.
with open('results.txt', 'w', encoding='utf-8') as file:
    file.write(f'test loss: {loss}\ntest accuracy: {accuracy}\n')
|
2021-04-25 22:14:32 +02:00
|
|
|
|