uczenie-glebokie-projekt/Projekt.ipynb

402 lines
12 KiB
Plaintext
Raw Permalink Normal View History

2024-06-10 06:37:29 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "LHtKZx0myNWa"
},
"source": [
"### Import bibliotek"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "ZTlYCCtCyNWc"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.svm import SVC\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from gensim.models import Word2Vec\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"import numpy as np\n",
"import re"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v16vUmROyNWc"
},
"source": [
"### Przygotowanie danych"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def get_str_cleaned(str_dirty):\n",
" punctuation = '!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~'\n",
" new_str = str_dirty.lower()\n",
" new_str = re.sub(' +', ' ', new_str)\n",
" for char in punctuation:\n",
" new_str = new_str.replace(char, '')\n",
" return new_str"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d4Kuyx7JyNWd",
"outputId": "0c9de8ef-4e90-44fd-9af4-d5e5833994aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" review sentiment\n",
"0 One of the other reviewers has mentioned that ... positive\n",
"1 A wonderful little production. <br /><br />The... positive\n",
"2 I thought this was a wonderful way to spend ti... positive\n",
"3 Basically there's a family where a little boy ... negative\n",
"4 Petter Mattei's \"Love in the Time of Money\" is... positive\n",
" review sentiment \\\n",
"0 One of the other reviewers has mentioned that ... 1 \n",
"1 A wonderful little production. <br /><br />The... 1 \n",
"2 I thought this was a wonderful way to spend ti... 1 \n",
"3 Basically there's a family where a little boy ... 0 \n",
"4 Petter Mattei's \"Love in the Time of Money\" is... 1 \n",
"\n",
" cleaned_review \n",
"0 one of the other reviewers has mentioned that ... \n",
"1 a wonderful little production br br the filmin... \n",
"2 i thought this was a wonderful way to spend ti... \n",
"3 basically theres a family where a little boy j... \n",
"4 petter matteis love in the time of money is a ... \n"
]
}
],
"source": [
"# Source: https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews\n",
"data = pd.read_csv('IMDB_reviews.csv')\n",
"print(data.head())\n",
"\n",
"# Czyszczenie danych\n",
"data['cleaned_review'] = data['review'].apply(get_str_cleaned)\n",
"\n",
"# Przekształcenie etykiet na format numeryczny\n",
"label_encoder = LabelEncoder()\n",
"data['sentiment'] = label_encoder.fit_transform(data['sentiment'])\n",
"\n",
"print(data.head())\n",
"\n",
"# Podział danych na zbiór treningowy i testowy\n",
"X = data['cleaned_review']\n",
"y = data['sentiment']\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Lz-Y4ZCyNWd"
},
"source": [
"### TF-IDF + SVM"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "ES_5Q4BEyNWd"
},
"outputs": [],
"source": [
"tfidf_svm_pipeline = Pipeline([\n",
" ('tfidf', TfidfVectorizer(max_features=200)),\n",
" ('svm', SVC(kernel='linear'))\n",
"])\n",
"tfidf_svm_pipeline.fit(X_train, y_train)\n",
"y_pred_tfidf_svm = tfidf_svm_pipeline.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fadLd3cEyNWd"
},
"source": [
"### TF-IDF + RandomForest"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "xUq30-FryNWe"
},
"outputs": [],
"source": [
"tfidf_rf_pipeline = Pipeline([\n",
" ('tfidf', TfidfVectorizer(max_features=200)),\n",
" ('rf', RandomForestClassifier(n_estimators=100))\n",
"])\n",
"tfidf_rf_pipeline.fit(X_train, y_train)\n",
"y_pred_tfidf_rf = tfidf_rf_pipeline.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d08OJrCnyNWe"
},
"source": [
"### Model Word2Vec i transformator dokumentów do postaci wektorowej"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "J5agaWJFyNWe"
},
"outputs": [],
"source": [
"w2v_model = Word2Vec(sentences=[doc.split() for doc in X_train], vector_size=200, window=5, min_count=5, workers=4)\n",
"class Word2VecTransformer(BaseEstimator, TransformerMixin):\n",
" def __init__(self, w2v_model):\n",
" self.w2v_model = w2v_model\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X):\n",
" return np.array([\n",
" np.mean([self.w2v_model.wv[word] for word in doc.split() if word in self.w2v_model.wv]\n",
" or [np.zeros(self.w2v_model.vector_size)], axis=0)\n",
" for doc in X\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbKeeZBdyNWe"
},
"source": [
"### Word2Vec + SVM"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "FPBL7g75yNWe"
},
"outputs": [],
"source": [
"w2v_svm_pipeline = Pipeline([\n",
" ('w2v_transform', Word2VecTransformer(w2v_model)),\n",
" ('svm', SVC(kernel='linear'))\n",
"])\n",
"w2v_svm_pipeline.fit(X_train, y_train)\n",
"y_pred_w2v_svm = w2v_svm_pipeline.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KT-Cnwx7yNWe"
},
"source": [
"### Word2Vec + RandomForest"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "t9mCasDmyNWe"
},
"outputs": [],
"source": [
"w2v_rf_pipeline = Pipeline([\n",
" ('w2v_transform', Word2VecTransformer(w2v_model)),\n",
" ('rf', RandomForestClassifier(n_estimators=100))\n",
"])\n",
"w2v_rf_pipeline.fit(X_train, y_train)\n",
"y_pred_w2v_rf = w2v_rf_pipeline.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lkFzZ1MjyNWf"
},
"source": [
"### Wyświetlanie metryk"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" # Funkcja zwraca trafność, precyzję, pokrycie i F1\n",
" acc_score = 0\n",
" acc_total = 0\n",
" tp = 0\n",
" fp = 0\n",
" selected_items = 0\n",
" relevant_items = 0\n",
"\n",
" for p, t in zip(y_pred, y_true):\n",
" acc_total += 1\n",
"\n",
" if p == t:\n",
" acc_score += 1\n",
"\n",
" if p > 0 and p == t:\n",
" tp += 1\n",
"\n",
" if p > 0:\n",
" selected_items += 1\n",
"\n",
" if t > 0:\n",
" relevant_items += 1\n",
"\n",
" accuracy = acc_score / acc_total\n",
"\n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
"\n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
"\n",
" if precision + recall == 0.0:\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
" return accuracy, precision, recall, f1"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"id": "W8RJkm0CyNWf"
},
"outputs": [],
"source": [
"def print_metrics(y_true, y_pred, model_name):\n",
" accuracy, precision, recall, f1 = get_scores(y_true, y_pred)\n",
" print(f'{model_name} Accuracy: {accuracy:.4f}')\n",
" print(f'{model_name} Precision: {precision:.4f}')\n",
" print(f'{model_name} Recall: {recall:.4f}')\n",
" print(f'{model_name} F1-Score: {f1:.4f}')\n",
" print('-' * 30)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tPPkR8MOyNWf",
"outputId": "ceae2217-10b0-4533-9f43-3c7add2d19b4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TF-IDF + SVM Accuracy: 0.7764\n",
"TF-IDF + SVM Precision: 0.7719\n",
"TF-IDF + SVM Recall: 0.7896\n",
"TF-IDF + SVM F1-Score: 0.7807\n",
"------------------------------\n",
"TF-IDF + Random Forest Accuracy: 0.7500\n",
"TF-IDF + Random Forest Precision: 0.7626\n",
"TF-IDF + Random Forest Recall: 0.7317\n",
"TF-IDF + Random Forest F1-Score: 0.7468\n",
"------------------------------\n",
"Word2Vec + SVM Accuracy: 0.8584\n",
"Word2Vec + SVM Precision: 0.8522\n",
"Word2Vec + SVM Recall: 0.8698\n",
"Word2Vec + SVM F1-Score: 0.8609\n",
"------------------------------\n",
"Word2Vec + Random Forest Accuracy: 0.8137\n",
"Word2Vec + Random Forest Precision: 0.8106\n",
"Word2Vec + Random Forest Recall: 0.8224\n",
"Word2Vec + Random Forest F1-Score: 0.8165\n",
"------------------------------\n"
]
}
],
"source": [
"# Ocena modelu TF-IDF + SVM\n",
"print_metrics(y_test, y_pred_tfidf_svm, 'TF-IDF + SVM')\n",
"\n",
"# Ocena modelu TF-IDF + Random Forest\n",
"print_metrics(y_test, y_pred_tfidf_rf, 'TF-IDF + Random Forest')\n",
"\n",
"# Ocena modelu Word2Vec + SVM\n",
"print_metrics(y_test, y_pred_w2v_svm, 'Word2Vec + SVM')\n",
"\n",
"# Ocena modelu Word2Vec + Random Forest\n",
"print_metrics(y_test, y_pred_w2v_rf, 'Word2Vec + Random Forest')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}