
fastText → Feedforward Neural Network

import pandas as pd
import numpy as np

train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")
valid = pd.read_csv("valid.csv")

# Remap review_score from {-1, 1} to {0, 1} so it can be used as a binary label
train.loc[train["review_score"] == -1, "review_score"] = 0
test.loc[test["review_score"] == -1, "review_score"] = 0
valid.loc[valid["review_score"] == -1, "review_score"] = 0

Loading the file with pre-trained English fastText embeddings (https://fasttext.cc/docs/en/crawl-vectors.html):

from gensim.models.fasttext import load_facebook_model

vectors_file = "cc.en.300.bin"
fb_model = load_facebook_model(vectors_file)
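
As a lighter-weight alternative (a sketch only, not used later in this notebook), gensim can load just the vectors without the full trainable model, which is enough for get_sentence_vector; this assumes gensim 4.x:

# Optional: load only the (sub)word vectors, skipping the model's training state
# (assumption: gensim 4.x, where FastTextKeyedVectors has get_sentence_vector)
from gensim.models.fasttext import load_facebook_vectors
wv = load_facebook_vectors(vectors_file)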

fb_model.wv.get_sentence_vector("Good game")
array([ 1.78246237e-02, -1.06189400e-01, -2.44729444e-02, ...,
        1.03261217e-01, -3.35561559e-02, -6.76982710e-03],
      dtype=float32)
from numpy.linalg import norm

def cosine_similarity(a, b):
    # Cosine of the angle between two vectors
    return np.dot(a, b) / (norm(a) * norm(b))

def get_sentence_similarity(sent_a, sent_b):
    # Similarity of two sentences via their averaged fastText vectors
    vec1 = fb_model.wv.get_sentence_vector(sent_a)
    vec2 = fb_model.wv.get_sentence_vector(sent_b)
    return cosine_similarity(vec1, vec2)

print(get_sentence_similarity("Good game", "Amazing game, I love it!"))
print(get_sentence_similarity("Good game", "Horrible game. A buggy mess."))
0.9303907
0.94947904
train["vectorized"] = train["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
train.iloc[0]["vectorized"]
array([ 8.71886499e-03, -9.33922902e-02, -3.61323059e-02, ...,
        8.94187242e-02, -3.53905819e-02,  2.35818932e-03],
      dtype=float32)
test["vectorized"] = test["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
valid["vectorized"] = valid["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam

def create_model():
    # Feedforward classifier on top of the 300-d fastText sentence vectors
    inputs = keras.Input(shape=(300,))
    dense1 = layers.Dense(256, activation="relu")(inputs)
    dense2 = layers.Dense(128, activation="relu")(dense1)
    dense3 = layers.Dense(64, activation="relu")(dense2)
    output = layers.Dense(1, activation="sigmoid")(dense3)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model
# Stack the per-review vectors into (n_samples, 300) matrices for Keras
train_x = np.stack(train["vectorized"].values)
train_y = np.stack(train["review_score"].values)

valid_x = np.stack(valid["vectorized"].values)
valid_y = np.stack(valid["review_score"].values)

test_x = np.stack(test["vectorized"].values)
test_y = np.stack(test["review_score"].values)
callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5, restore_best_weights=True)
model = create_model()
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=25, callbacks=[callback])
Epoch 1/25
1351/1351 [==============================] - 3s 1ms/step - loss: 0.6815 - accuracy: 0.5627 - val_loss: 0.6221 - val_accuracy: 0.7079
Epoch 2/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6564 - accuracy: 0.6112 - val_loss: 0.6649 - val_accuracy: 0.5497
Epoch 3/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6492 - accuracy: 0.6204 - val_loss: 0.6619 - val_accuracy: 0.5553
Epoch 4/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6439 - accuracy: 0.6280 - val_loss: 0.6552 - val_accuracy: 0.5702
Epoch 5/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6405 - accuracy: 0.6333 - val_loss: 0.6135 - val_accuracy: 0.6346
Epoch 6/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6374 - accuracy: 0.6352 - val_loss: 0.7130 - val_accuracy: 0.4843
Epoch 7/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6348 - accuracy: 0.6409 - val_loss: 0.6140 - val_accuracy: 0.6315
Epoch 8/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6307 - accuracy: 0.6445 - val_loss: 0.5606 - val_accuracy: 0.7186
Epoch 9/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6297 - accuracy: 0.6449 - val_loss: 0.6899 - val_accuracy: 0.5398
Epoch 10/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6267 - accuracy: 0.6480 - val_loss: 0.6795 - val_accuracy: 0.5511
Epoch 11/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6251 - accuracy: 0.6486 - val_loss: 0.6778 - val_accuracy: 0.5430
Epoch 12/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6222 - accuracy: 0.6530 - val_loss: 0.5447 - val_accuracy: 0.7297
Epoch 13/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6206 - accuracy: 0.6541 - val_loss: 0.5630 - val_accuracy: 0.7080
Epoch 14/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6189 - accuracy: 0.6563 - val_loss: 0.5919 - val_accuracy: 0.6746
Epoch 15/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6168 - accuracy: 0.6588 - val_loss: 0.6887 - val_accuracy: 0.5485
Epoch 16/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6160 - accuracy: 0.6576 - val_loss: 0.5550 - val_accuracy: 0.7132
Epoch 17/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6132 - accuracy: 0.6589 - val_loss: 0.6529 - val_accuracy: 0.5999
def create_model():
    # A wider variant of the same architecture
    inputs = keras.Input(shape=(300,))
    dense1 = layers.Dense(1024, activation="relu")(inputs)
    dense2 = layers.Dense(512, activation="relu")(dense1)
    dense3 = layers.Dense(128, activation="relu")(dense2)
    output = layers.Dense(1, activation="sigmoid")(dense3)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model
callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5, restore_best_weights=True)
model = create_model()
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=25, callbacks=[callback])
Epoch 1/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6707 - accuracy: 0.5788 - val_loss: 0.6569 - val_accuracy: 0.5608
Epoch 2/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6493 - accuracy: 0.6180 - val_loss: 0.8048 - val_accuracy: 0.3736
Epoch 3/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6426 - accuracy: 0.6281 - val_loss: 0.6285 - val_accuracy: 0.6160
Epoch 4/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6376 - accuracy: 0.6348 - val_loss: 0.6405 - val_accuracy: 0.5911
Epoch 5/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6328 - accuracy: 0.6406 - val_loss: 0.6520 - val_accuracy: 0.5826
Epoch 6/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6283 - accuracy: 0.6458 - val_loss: 0.6654 - val_accuracy: 0.5612
Epoch 7/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6251 - accuracy: 0.6484 - val_loss: 0.6713 - val_accuracy: 0.5577
Epoch 8/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6230 - accuracy: 0.6512 - val_loss: 0.5540 - val_accuracy: 0.7145
Epoch 9/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6208 - accuracy: 0.6529 - val_loss: 0.6870 - val_accuracy: 0.5254
Epoch 10/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6182 - accuracy: 0.6544 - val_loss: 0.5915 - val_accuracy: 0.6618
Epoch 11/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6161 - accuracy: 0.6565 - val_loss: 0.6600 - val_accuracy: 0.5738
Epoch 12/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6139 - accuracy: 0.6587 - val_loss: 0.7102 - val_accuracy: 0.5054
Epoch 13/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6130 - accuracy: 0.6582 - val_loss: 0.5946 - val_accuracy: 0.6568
from matplotlib import pyplot as plt
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Training and validation loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'valid'], loc='upper left')
[figure: training vs. validation loss per epoch]
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'valid'], loc='upper left')
[figure: training vs. validation accuracy per epoch]
model.save("fasttext_model.keras")
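
The saved model can later be restored without retraining (assuming the same Keras version is available):

# Reload the classifier saved above
reloaded_model = keras.models.load_model("fasttext_model.keras")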

Despite several attempts at tuning the parameters, the model shows relatively low prediction quality, even though it also appears overfit (the validation accuracy/loss jumps around from epoch to epoch).
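
One direction that might help (a sketch only; the dropout rate and placement are guesses, not something evaluated here) is regularizing the network, e.g. with dropout between the dense layers:

# Hypothetical regularized variant; the 0.3 dropout rate is an assumption
def create_model_dropout():
    inputs = keras.Input(shape=(300,))
    x = layers.Dense(256, activation="relu")(inputs)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    output = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model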

def test_review_text(sentence):
    # Embed the sentence, run it through the trained classifier,
    # and interpret the sigmoid output as a positive/negative label
    vectorized = fb_model.wv.get_sentence_vector(sentence)
    reshaped = tf.reshape(vectorized, shape=(1, 300))
    score = float(model(reshaped))
    print(score)
    if round(score) == 0:
        print("Negative review")
    else:
        print("Positive review")
test_review_text("A buggy, uninspired mess")
0.7820833921432495
Positive review
test_review_text("This game is bad")
0.44321733713150024
Negative review
test_review_text("This game destroyed my life")
0.8973167538642883
Positive review
test_review_text("Best game I've ever played")
0.8987871408462524
Positive review
test_review_text("Fun cooperative play with scalable difficulty. Rapid path to get into a game with friends or open public games. ")
0.5772996544837952
Positive review
test_review_text("Deliriously buggy. Fun if/when it works properly. Wait and see if they actually QA the next few patches before you play.")
0.6418458819389343
Positive review
test["model_predictions"] = model(np.stack(test["vectorized"].values))
test["model_predictions"] = test["model_predictions"].apply(lambda x : round(float(x)))
def get_metrics():
    df = test
    predictions = df["model_predictions"].to_numpy()
    true_values = df["review_score"].to_numpy()
    accuracy = np.sum(predictions == true_values) / len(true_values)
    # Confusion-matrix counts
    TN_count = len(df.query("`review_score`==0 and `model_predictions`==0").index)
    TP_count = len(df.query("`review_score`==1 and `model_predictions`==1").index)
    FP_count = len(df.query("`review_score`==0 and `model_predictions`==1").index)
    FN_count = len(df.query("`review_score`==1 and `model_predictions`==0").index)
    precision = TP_count / (TP_count + FP_count)
    recall = TP_count / (TP_count + FN_count)
    F1_score = (2 * precision * recall) / (precision + recall)
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {F1_score:.2f}")
get_metrics()
Accuracy: 0.72
Precision: 0.89
Recall: 0.76
F1 Score: 0.82
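
These hand-computed metrics could be cross-checked against scikit-learn (an optional sketch; assumes scikit-learn is installed):

# Cross-check with scikit-learn's implementations (assumption: sklearn available)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = test["review_score"].to_numpy()
y_pred = test["model_predictions"].to_numpy()
print(f"Accuracy: {accuracy_score(y_true, y_pred):.2f}")
print(f"Precision: {precision_score(y_true, y_pred):.2f}")
print(f"Recall: {recall_score(y_true, y_pred):.2f}")
print(f"F1 Score: {f1_score(y_true, y_pred):.2f}")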

The model might achieve better results if the embeddings had been trained on the game reviews themselves rather than taken from off-the-shelf vectors for general English.
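
A minimal sketch of what that could look like with gensim's FastText trainer (the hyperparameters are assumptions, and the resulting vectors would still need evaluating):

# Hypothetical: train 300-d fastText embeddings on the review corpus itself
from gensim.models import FastText
from gensim.utils import simple_preprocess

corpus = [simple_preprocess(text) for text in train["review_text"]]
domain_model = FastText(sentences=corpus, vector_size=300, window=5,
                        min_count=5, epochs=10)
domain_model.wv.get_sentence_vector("Good game")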