sport-text-classification-ball/solution.ipynb

217 lines
6.1 KiB
Plaintext
Raw Permalink Normal View History

2024-05-17 16:29:25 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import csv\n",
"from gensim.models import Word2Vec\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"train_documents = []\n",
"train_classes = []\n",
"\n",
"# Each line of train.tsv is `<label>\\t<text>`. Split on the first tab only so a\n",
"# tab inside the text cannot truncate the document, drop the line terminator,\n",
"# and skip blank lines (e.g. a trailing empty line) instead of failing on int('').\n",
"with open('train/train.tsv', 'r', encoding='utf-8') as file:\n",
"    for line in file:\n",
"        line = line.rstrip('\\n')\n",
"        if not line.strip():\n",
"            continue\n",
"        label, text = line.split('\\t', 1)\n",
"        train_classes.append(int(label))\n",
"        train_documents.append(text.lower())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Learn 100-dimensional word embeddings from the whitespace-tokenised training\n",
"# texts (min_count=1 keeps every token) and save them to word2vec.model.\n",
"# NOTE: with workers > 1 gensim training is not bit-for-bit reproducible.\n",
"model = Word2Vec(\n",
"    sentences=[doc.split() for doc in train_documents],\n",
"    vector_size=100,\n",
"    window=5,\n",
"    min_count=1,\n",
"    workers=4,\n",
")\n",
"model.save(\"word2vec.model\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def get_test_data(path):\n",
"    \"\"\"Read ``path`` and return the first tab-separated field of each line.\n",
"\n",
"    Every line is stripped of surrounding whitespace (including the newline)\n",
"    before it is split, so the returned strings carry no line terminator.\n",
"    \"\"\"\n",
"    with open(path, 'r', encoding='utf-8') as handle:\n",
"        return [row.strip().split('\\t')[0] for row in handle.readlines()]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Dev/test texts get the same lower-casing as the training documents;\n",
"# dev-0 labels are parsed to int to match train_classes.\n",
"dev0_documents = list(map(str.lower, get_test_data('dev-0/in.tsv')))\n",
"dev0_classes = list(map(int, get_test_data('dev-0/expected.tsv')))\n",
"a_documents = list(map(str.lower, get_test_data('test-A/in.tsv')))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def document_to_word2vec(doc, w2v_model=None):\n",
"    \"\"\"Embed ``doc`` as the mean of its in-vocabulary Word2Vec vectors.\n",
"\n",
"    Args:\n",
"        doc: Text to embed; tokenised with ``str.split()`` (whitespace).\n",
"        w2v_model: Trained gensim ``Word2Vec`` used for lookups. Defaults to the\n",
"            notebook-level ``model`` so existing one-argument calls still work.\n",
"\n",
"    Returns:\n",
"        A float64 numpy array of length ``vector_size``. Out-of-vocabulary words\n",
"        are skipped; if no word is known the zero vector is returned.\n",
"    \"\"\"\n",
"    w2v = model if w2v_model is None else w2v_model\n",
"    vectors = [w2v.wv.get_vector(word) for word in doc.split() if word in w2v.wv]\n",
"    if not vectors:\n",
"        return np.zeros(w2v.vector_size)\n",
"    # Average in float64 so the dtype matches the zero-vector fallback above.\n",
"    return np.mean(vectors, axis=0, dtype=np.float64)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Represent every document by its averaged word vector. document_to_word2vec\n",
"# looks words up in the in-memory `model` trained above, so no reload of\n",
"# word2vec.model is needed here.\n",
"train_documents_word2vec = [document_to_word2vec(doc) for doc in train_documents]\n",
"dev0_documents_word2vec = [document_to_word2vec(doc) for doc in dev0_documents]\n",
"a_documents_word2vec = [document_to_word2vec(doc) for doc in a_documents]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.gaussian_process.kernels import RBF\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.svm import SVC\n",
"from sklearn.naive_bayes import GaussianNB\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy for classifier Linear SVM: 0.9745047688921497\n",
"Test accuracy for classifier Naive Bayes: 0.892516507703595\n",
"Test accuracy for classifier Random Forest: 0.960564930300807\n",
"Test accuracy for classifier QDA: 0.923881144534116\n"
]
}
],
"source": [
"# (display name, estimator) pairs are kept together so a label can never drift\n",
"# out of sync with the classifier it describes. Each estimator is preceded by a\n",
"# StandardScaler inside a pipeline.\n",
"candidates = [\n",
"    (\"Neural Net\", MLPClassifier(alpha=1, max_iter=1000, random_state=42)),\n",
"    (\"Naive Bayes\", GaussianNB()),\n",
"    (\"Random Forest\", RandomForestClassifier(random_state=42)),\n",
"    (\"QDA\", QuadraticDiscriminantAnalysis()),\n",
"]\n",
"\n",
"# Start below any achievable accuracy so the first candidate is always kept and\n",
"# the best_* names below are always bound for the output cells.\n",
"best_accuracy = -1.0\n",
"best_classifier_name = \"\"\n",
"best_pipeline = None\n",
"best_dev0_predictions = None\n",
"\n",
"for name, estimator in candidates:\n",
"    pipeline = make_pipeline(StandardScaler(), estimator)\n",
"    pipeline.fit(train_documents_word2vec, train_classes)\n",
"\n",
"    dev0_predictions = pipeline.predict(dev0_documents_word2vec)\n",
"    dev0_accuracy = accuracy_score(dev0_classes, dev0_predictions)\n",
"    print(\"Dev-0 accuracy for classifier \" + name + \":\", dev0_accuracy)\n",
"\n",
"    # Strict '>' keeps the earliest candidate on ties.\n",
"    if dev0_accuracy > best_accuracy:\n",
"        best_accuracy = dev0_accuracy\n",
"        best_classifier_name = name\n",
"        best_pipeline = pipeline\n",
"        best_dev0_predictions = dev0_predictions\n",
"\n",
"# Only the winning model is needed for the unlabeled test-A split.\n",
"best_a_predictions = best_pipeline.predict(a_documents_word2vec)\n",
"print(\"Best classifier on dev-0:\", best_classifier_name, best_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# One prediction per line. lineterminator='\\n' overrides csv.writer's default\n",
"# '\\r\\n' so out.tsv gets plain LF line endings; mode 'w' suffices because the\n",
"# file is never read back here.\n",
"with open('dev-0/out.tsv', 'w', newline='', encoding='utf-8') as file:\n",
"    writer = csv.writer(file, delimiter='\\t', lineterminator='\\n')\n",
"    writer.writerows([prediction] for prediction in best_dev0_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# One prediction per line. lineterminator='\\n' overrides csv.writer's default\n",
"# '\\r\\n' so out.tsv gets plain LF line endings; mode 'w' suffices because the\n",
"# file is never read back here.\n",
"with open('test-A/out.tsv', 'w', newline='', encoding='utf-8') as file:\n",
"    writer = csv.writer(file, delimiter='\\t', lineterminator='\\n')\n",
"    writer.writerows([prediction] for prediction in best_a_predictions)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}