FastText → Feedforward Neural Network
import pandas as pd
import numpy as np

train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")
valid = pd.read_csv("valid.csv")

# Remap the negative label from -1 to 0 so it works with a sigmoid output and binary cross-entropy
train.loc[train["review_score"]==-1, "review_score"] = 0
test.loc[test["review_score"]==-1, "review_score"] = 0
valid.loc[valid["review_score"]==-1, "review_score"] = 0
Loading the file with pretrained English fastText embeddings (https://fasttext.cc/docs/en/crawl-vectors.html):
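If cc.en.300.bin is not already on disk, it can be fetched and unpacked first. A minimal sketch, assuming the download URL linked from the fastText page above (the compressed file is several gigabytes):

import gzip
import shutil
import urllib.request

# Assumed URL of the compressed English vectors, as linked from fasttext.cc
url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz"
urllib.request.urlretrieve(url, "cc.en.300.bin.gz")

# Decompress into the cc.en.300.bin file loaded below
with gzip.open("cc.en.300.bin.gz", "rb") as f_in, open("cc.en.300.bin", "wb") as f_out:
    shutil.copyfileobj(f_in, f_out)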
from gensim.models.fasttext import load_facebook_model
vectors_file = "cc.en.300.bin"
fb_model = load_facebook_model(vectors_file)
fb_model.wv.get_sentence_vector("Good game")
array([ 1.78246237e-02, -1.06189400e-01, -2.44729444e-02, ..., 1.03261217e-01, -3.35561559e-02, -6.76982710e-03], dtype=float32)
(300-dimensional sentence vector; full output truncated)
from numpy.linalg import norm

def cosine_similarity(a, b):
    return np.dot(a, b) / (norm(a) * norm(b))

def get_sentence_similarity(sent_a, sent_b):
    vec1 = fb_model.wv.get_sentence_vector(sent_a)
    vec2 = fb_model.wv.get_sentence_vector(sent_b)
    return cosine_similarity(vec1, vec2)
print(get_sentence_similarity("Good game", "Amazing game, I love it!"))
print(get_sentence_similarity("Good game", "Horrible game. A buggy mess."))
0.9303907
0.94947904
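get_sentence_similarity computes the standard cosine similarity of the two averaged sentence vectors, $\mathrm{sim}(a, b) = \frac{a \cdot b}{\lVert a \rVert\,\lVert b \rVert}$. Notably, both similarities are above 0.93 even though the two reviews have opposite sentiment: averaged fastText sentence vectors mostly capture topical overlap ("game") rather than polarity, which foreshadows how hard the classification below turns out to be.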
train["vectorized"] = train["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
train.iloc[0]["vectorized"]
array([ 8.71886499e-03, -9.33922902e-02, -3.61323059e-02, ..., 8.94187242e-02, -3.53905819e-02, 2.35818932e-03], dtype=float32)
(300-dimensional sentence vector; full output truncated)
test["vectorized"] = test["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
valid["vectorized"] = valid["review_text"].apply(lambda x : fb_model.wv.get_sentence_vector(x))
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
def create_model():
    inputs = keras.Input(shape=(300,))
    dense1 = layers.Dense(256, activation="relu")(inputs)
    dense2 = layers.Dense(128, activation="relu")(dense1)
    dense3 = layers.Dense(64, activation="relu")(dense2)
    output = layers.Dense(1, activation="sigmoid")(dense3)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model
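For scale, this 300→256→128→64→1 stack has 77,056 + 32,896 + 8,256 + 65 = 118,273 trainable parameters (each Dense layer contributes in_features*units weights plus units biases), which can be confirmed with:

# Print the layer shapes and parameter counts of a fresh model instance
create_model().summary()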
train_x = np.stack(train["vectorized"].values)
train_y = np.stack(train["review_score"].values)
valid_x = np.stack(valid["vectorized"].values)
valid_y = np.stack(valid["review_score"].values)
test_x = np.stack(test["vectorized"].values)
test_y = np.stack(test["review_score"].values)
callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5, restore_best_weights=True)
model = create_model()
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=25, callbacks=[callback])
Epoch 1/25
1351/1351 [==============================] - 3s 1ms/step - loss: 0.6815 - accuracy: 0.5627 - val_loss: 0.6221 - val_accuracy: 0.7079
Epoch 2/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6564 - accuracy: 0.6112 - val_loss: 0.6649 - val_accuracy: 0.5497
Epoch 3/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6492 - accuracy: 0.6204 - val_loss: 0.6619 - val_accuracy: 0.5553
Epoch 4/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6439 - accuracy: 0.6280 - val_loss: 0.6552 - val_accuracy: 0.5702
Epoch 5/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6405 - accuracy: 0.6333 - val_loss: 0.6135 - val_accuracy: 0.6346
Epoch 6/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6374 - accuracy: 0.6352 - val_loss: 0.7130 - val_accuracy: 0.4843
Epoch 7/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6348 - accuracy: 0.6409 - val_loss: 0.6140 - val_accuracy: 0.6315
Epoch 8/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6307 - accuracy: 0.6445 - val_loss: 0.5606 - val_accuracy: 0.7186
Epoch 9/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6297 - accuracy: 0.6449 - val_loss: 0.6899 - val_accuracy: 0.5398
Epoch 10/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6267 - accuracy: 0.6480 - val_loss: 0.6795 - val_accuracy: 0.5511
Epoch 11/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6251 - accuracy: 0.6486 - val_loss: 0.6778 - val_accuracy: 0.5430
Epoch 12/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6222 - accuracy: 0.6530 - val_loss: 0.5447 - val_accuracy: 0.7297
Epoch 13/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6206 - accuracy: 0.6541 - val_loss: 0.5630 - val_accuracy: 0.7080
Epoch 14/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6189 - accuracy: 0.6563 - val_loss: 0.5919 - val_accuracy: 0.6746
Epoch 15/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6168 - accuracy: 0.6588 - val_loss: 0.6887 - val_accuracy: 0.5485
Epoch 16/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6160 - accuracy: 0.6576 - val_loss: 0.5550 - val_accuracy: 0.7132
Epoch 17/25
1351/1351 [==============================] - 2s 1ms/step - loss: 0.6132 - accuracy: 0.6589 - val_loss: 0.6529 - val_accuracy: 0.5999
Training stopped early after epoch 17: with patience=5 and restore_best_weights=True, the weights from epoch 12 (val_loss 0.5447) were restored.
The same setup was retried with a wider network:

def create_model():
    inputs = keras.Input(shape=(300,))
    dense1 = layers.Dense(1024, activation="relu")(inputs)
    dense2 = layers.Dense(512, activation="relu")(dense1)
    dense3 = layers.Dense(128, activation="relu")(dense2)
    output = layers.Dense(1, activation="sigmoid")(dense3)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model
callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5, restore_best_weights=True)
model = create_model()
history = model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=25, callbacks=[callback])
Epoch 1/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6707 - accuracy: 0.5788 - val_loss: 0.6569 - val_accuracy: 0.5608
Epoch 2/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6493 - accuracy: 0.6180 - val_loss: 0.8048 - val_accuracy: 0.3736
Epoch 3/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6426 - accuracy: 0.6281 - val_loss: 0.6285 - val_accuracy: 0.6160
Epoch 4/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6376 - accuracy: 0.6348 - val_loss: 0.6405 - val_accuracy: 0.5911
Epoch 5/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6328 - accuracy: 0.6406 - val_loss: 0.6520 - val_accuracy: 0.5826
Epoch 6/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6283 - accuracy: 0.6458 - val_loss: 0.6654 - val_accuracy: 0.5612
Epoch 7/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6251 - accuracy: 0.6484 - val_loss: 0.6713 - val_accuracy: 0.5577
Epoch 8/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6230 - accuracy: 0.6512 - val_loss: 0.5540 - val_accuracy: 0.7145
Epoch 9/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6208 - accuracy: 0.6529 - val_loss: 0.6870 - val_accuracy: 0.5254
Epoch 10/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6182 - accuracy: 0.6544 - val_loss: 0.5915 - val_accuracy: 0.6618
Epoch 11/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6161 - accuracy: 0.6565 - val_loss: 0.6600 - val_accuracy: 0.5738
Epoch 12/25
1351/1351 [==============================] - 3s 2ms/step - loss: 0.6139 - accuracy: 0.6587 - val_loss: 0.7102 - val_accuracy: 0.5054
Epoch 13/25
1351/1351 [==============================] - 2s 2ms/step - loss: 0.6130 - accuracy: 0.6582 - val_loss: 0.5946 - val_accuracy: 0.6568
Training again stopped early, after epoch 13; the best weights (epoch 8, val_loss 0.5540) were restored.
from matplotlib import pyplot as plt
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss during training')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'valid'], loc='upper left')
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'valid'], loc='upper left')
model.save("fasttext_model.keras")
Despite several attempts at tuning the hyperparameters, the model's predictions remain of relatively low quality, even though it also shows signs of overfitting (the validation accuracy/loss jumps around between epochs).
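One common mitigation for such unstable validation behaviour, not tried in the runs above, is to regularize the network with dropout. A minimal sketch of that variant:

def create_model_with_dropout():
    # Same architecture as the first model, with dropout after each hidden layer
    inputs = keras.Input(shape=(300,))
    x = layers.Dense(256, activation="relu")(inputs)
    x = layers.Dropout(0.3)(x)  # randomly zero 30% of activations during training
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation="relu")(x)
    output = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs=inputs, outputs=output)
    model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4), metrics=['accuracy'])
    return model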
def test_review_text(sentence):
    # Embed the review, run the classifier, and threshold the sigmoid output at 0.5
    vectorized = fb_model.wv.get_sentence_vector(sentence)
    reshaped = tf.reshape(vectorized, shape=(1, 300))
    score = float(model(reshaped))
    score_rounded = round(score)
    print(score)
    if score_rounded == 0:
        print("Negative review")
    else:
        print("Positive review")
test_review_text("A buggy, uninspired mess")
0.7820833921432495 Positive review
test_review_text("This game is bad")
0.44321733713150024 Negative review
test_review_text("This game destroyed my life")
0.8973167538642883 Positive review
test_review_text("Best game I've ever played")
0.8987871408462524 Positive review
test_review_text("Fun cooperative play with scalable difficulty. Rapid path to get into a game with friends or open public games. ")
0.5772996544837952 Positive review
test_review_text("Deliriously buggy. Fun if/when it works properly. Wait and see if they actually QA the next few patches before you play.")
0.6418458819389343 Positive review
test["model_predictions"] = model(np.stack(test["vectorized"].values))
test["model_predictions"] = test["model_predictions"].apply(lambda x : round(float(x)))
def get_metrics():
    df = test
    predictions = df["model_predictions"].to_numpy()
    true_values = df["review_score"].to_numpy()
    accuracy = np.sum(np.rint(predictions) == true_values) / len(true_values)
    # Confusion-matrix counts
    TN_count = len(df.query("`review_score`==0 and `model_predictions`==0").index)
    TP_count = len(df.query("`review_score`==1 and `model_predictions`==1").index)
    FP_count = len(df.query("`review_score`==0 and `model_predictions`==1").index)
    FN_count = len(df.query("`review_score`==1 and `model_predictions`==0").index)
    precision = TP_count / (TP_count + FP_count)
    recall = TP_count / (TP_count + FN_count)
    F1_score = (2 * precision * recall) / (precision + recall)
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {F1_score:.2f}")
get_metrics()
Accuracy: 0.72
Precision: 0.89
Recall: 0.76
F1 Score: 0.82
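As a sanity check, the same metrics can be reproduced with scikit-learn (assuming it is available in the environment):

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = test["review_score"].to_numpy()
y_pred = test["model_predictions"].to_numpy()
print(f"Accuracy: {accuracy_score(y_true, y_pred):.2f}")
print(f"Precision: {precision_score(y_true, y_pred):.2f}")
print(f"Recall: {recall_score(y_true, y_pred):.2f}")
print(f"F1 Score: {f1_score(y_true, y_pred):.2f}")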
The model might achieve better results if the embeddings were trained on the game reviews themselves rather than taken from ready-made vectors for general English, as sketched below.
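A minimal sketch of that idea with gensim, training fastText embeddings on the review corpus alone (the tokenizer and hyperparameters here are illustrative assumptions, not tuned values):

from gensim.models import FastText
from gensim.utils import simple_preprocess

# Tokenize the training reviews (illustrative; a better tokenizer may help)
corpus = [simple_preprocess(text) for text in train["review_text"]]

# Train 300-dimensional domain-specific fastText embeddings
domain_model = FastText(sentences=corpus, vector_size=300, window=5, min_count=5, epochs=10)

# These vectors could then replace fb_model.wv in the pipeline above
domain_model.wv.get_sentence_vector("Good game")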