sport-text-classification-b.../sport text classification.ipynb
2021-05-25 22:38:13 +02:00

255 lines
7.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from gensim.test.utils import common_texts\n",
"from gensim.models import FastText\n",
"import os.path\n",
"import gzip\n",
"import shutil\n",
"import torch\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Shared configuration used by the cells below.\n",
"features = 100    # FastText vector size, which is also the network's input size\n",
"batch_size = 16   # mini-batch size for training and evaluation\n",
"criterion = torch.nn.BCELoss()  # expects probabilities, i.e. sigmoid outputs\n",
"\n",
"# Seed torch so the network's random weight initialisation is reproducible.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Decompress the training archive; pandas reads the plain TSV below.\n",
"with gzip.open('train/train.tsv.gz', 'rb') as f_in:\n",
"    with open('train/train.tsv', 'wb') as f_out:\n",
"        shutil.copyfileobj(f_in, f_out)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 [mindaugas, budzinauskas, wierzy, w, odbudowę,...\n",
"1 [przyjmujący, reprezentacji, polski, wrócił, d...\n",
"2 [fen, 9:, zapowiedź, walki, róża, gumienna, vs...\n",
"3 [aleksander, filipiak:, czuję, się, dobrze, w,...\n",
"4 [victoria, carl, i, aleksiej, czerwotkin, mist...\n",
" ... \n",
"98127 [kamil, syprzak, zaczyna, kolekcjonować, trofe...\n",
"98128 [holandia:, dwa, gole, piotra, parzyszka, piot...\n",
"98129 [sparingowo:, korona, gorsza, od, stali., lett...\n",
"98130 [vive, -, wisła., ośmiu, debiutantów, w, tegor...\n",
"98131 [wta, miami:, timea, bacsinszky, pokonana,, sw...\n",
"Name: Text, Length: 98132, dtype: object"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the labelled training data (label, text) and tokenise the text:\n",
"# lower-case it, then split on whitespace.\n",
"data = (\n",
"    pd.read_csv('train/train.tsv', sep='\\t', names=[\"Ball\", \"Text\"])\n",
"    .assign(Text=lambda frame: frame[\"Text\"].str.lower().str.split())\n",
")\n",
"data[\"Text\"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Train FastText embeddings on the training tokens once and cache them on disk;\n",
"# later runs load the cached model instead of retraining.\n",
"if os.path.isfile('fasttext.model'):\n",
"    ft_model = FastText.load(\"fasttext.model\")\n",
"else:\n",
"    corpus = data[\"Text\"]\n",
"    ft_model = FastText(size=features, window=3, min_count=1)\n",
"    ft_model.build_vocab(sentences=corpus)\n",
"    ft_model.train(corpus, total_examples=len(corpus), epochs=10)\n",
"    ft_model.save(\"fasttext.model\")\n",
" \n",
"def document_vector(doc):\n",
"    \"\"\"Embed a tokenised document as the element-wise max of its FastText word vectors.\n",
"\n",
"    Args:\n",
"        doc: sequence of tokens (as produced by ``str.lower().str.split()``).\n",
"\n",
"    Returns:\n",
"        A 1-D array of length ``ft_model.wv.vector_size``. A document with no tokens\n",
"        (an empty list, or NaN from an empty TSV field) maps to a float32 zero vector,\n",
"        because max-pooling over zero word vectors is undefined.\n",
"    \"\"\"\n",
"    # pandas yields NaN (a float) instead of a token list when the text field is empty.\n",
"    if doc is None or isinstance(doc, float) or len(doc) == 0:\n",
"        return np.zeros(ft_model.wv.vector_size, dtype=np.float32)\n",
"    result = ft_model.wv[doc]\n",
"    return np.max(result, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Features: one max-pooled FastText vector per training document. Labels: the Ball column.\n",
"X = list(map(document_vector, data[\"Text\"]))\n",
"Y = data[\"Ball\"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class NeuralNetworkModel(torch.nn.Module):\n",
"    \"\"\"Feed-forward binary classifier over pooled FastText document vectors.\n",
"\n",
"    Linear(features, 200) -> ReLU -> Linear(200, 150) -> sigmoid\n",
"    -> Linear(150, 1) -> sigmoid, so each output is a probability in [0, 1]\n",
"    as ``torch.nn.BCELoss`` expects. The layer attribute names (fc1..fc3) are the\n",
"    keys of the saved state dict, so they must stay stable.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.fc1 = torch.nn.Linear(features, 200)\n",
"        self.fc2 = torch.nn.Linear(200, 150)\n",
"        self.fc3 = torch.nn.Linear(150, 1)\n",
"\n",
"    def forward(self, x):\n",
"        hidden = torch.relu(self.fc1(x))\n",
"        hidden = torch.sigmoid(self.fc2(hidden))\n",
"        return torch.sigmoid(self.fc3(hidden))\n",
"\n",
" \n",
"def get_loss_acc(model, X_dataset, Y_dataset):\n",
"    \"\"\"Return (mean loss, accuracy) of ``model`` over a dataset, in mini-batches.\n",
"\n",
"    Args:\n",
"        model: network mapping a (batch, features) float tensor to (batch, 1) probabilities.\n",
"        X_dataset: sequence of document vectors (list of 1-D arrays or a 2-D array).\n",
"        Y_dataset: pandas Series of 0/1 labels, aligned by position with ``X_dataset``.\n",
"\n",
"    Returns:\n",
"        ``(loss, accuracy)``: ``criterion`` averaged per item, and the fraction of items\n",
"        whose predicted probability thresholded at 0.5 equals the label.\n",
"\n",
"    Raises:\n",
"        ValueError: if the dataset is empty (both averages would be undefined).\n",
"\n",
"    Leaves ``model`` in eval mode; call ``model.train()`` before training again.\n",
"    \"\"\"\n",
"    n_items = Y_dataset.shape[0]\n",
"    if n_items == 0:\n",
"        raise ValueError(\"get_loss_acc needs a non-empty dataset\")\n",
"    loss_score = 0.0\n",
"    acc_score = 0\n",
"    model.eval()\n",
"    # Evaluation only: no gradients are needed, so skip autograd bookkeeping.\n",
"    with torch.no_grad():\n",
"        for i in range(0, n_items, batch_size):\n",
"            # Stack the batch into one float32 array first; torch builds tensors from a\n",
"            # list of ndarrays element by element, which is very slow.\n",
"            x = torch.tensor(np.asarray(X_dataset[i:i+batch_size], dtype=np.float32))\n",
"            # .iloc keeps the slice positional regardless of the Series' index labels.\n",
"            y = torch.tensor(Y_dataset.iloc[i:i+batch_size].to_numpy(dtype=np.float32)).reshape(-1, 1)\n",
"            y_predictions = model(x)\n",
"            acc_score += torch.sum((y_predictions >= 0.5) == y).item()\n",
"            # criterion returns the batch mean; weight it by batch size for a per-item mean.\n",
"            loss_score += criterion(y_predictions, y).item() * y.shape[0]\n",
"    return (loss_score / n_items), (acc_score / n_items)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Train the classifier with mini-batch SGD and cache its weights in nn.model,\n",
"# or load the cached weights when they already exist.\n",
"model_path = 'nn.model'\n",
"nn_model = NeuralNetworkModel()\n",
"\n",
"if not os.path.isfile(model_path):\n",
"    optimizer = optim.SGD(nn_model.parameters(), lr=0.1)\n",
"    # Convert the training set to tensors once, not on every batch of every epoch.\n",
"    X_train = torch.tensor(np.asarray(X, dtype=np.float32))\n",
"    Y_train = torch.tensor(Y.to_numpy(dtype=np.float32)).reshape(-1, 1)\n",
"\n",
"    display(get_loss_acc(nn_model, X, Y))\n",
"    for epoch in range(5):\n",
"        nn_model.train()\n",
"        # Visit the examples in a fresh random order each epoch so the mini-batches\n",
"        # do not depend on the order of the training file.\n",
"        order = torch.randperm(X_train.shape[0])\n",
"        for i in range(0, X_train.shape[0], batch_size):\n",
"            batch = order[i:i+batch_size]\n",
"            y_predictions = nn_model(X_train[batch])\n",
"            loss = criterion(y_predictions, Y_train[batch])\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        display(get_loss_acc(nn_model, X, Y))\n",
"    torch.save(nn_model.state_dict(), model_path)\n",
"else:\n",
"    nn_model.load_state_dict(torch.load(model_path))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Dev set: tokenise and embed the texts exactly as the training texts were.\n",
"dev_texts = pd.read_csv('dev-0/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
"y_dev = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=[\"Ball\"])[\"Ball\"]\n",
"x_dev = [document_vector(tokens) for tokens in dev_texts.str.lower().str.split()]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.45761072419184756, 0.7694424064563463)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (loss, accuracy) on the held-out dev set.\n",
"dev_loss, dev_accuracy = get_loss_acc(nn_model, x_dev, y_dev)\n",
"dev_loss, dev_accuracy"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Predict dev labels using the same >= 0.5 threshold that get_loss_acc uses for accuracy.\n",
"# (Python's round() rounds 0.5 down to 0, which disagreed with that threshold.)\n",
"nn_model.eval()\n",
"with torch.no_grad():\n",
"    y_dev_prediction = nn_model(torch.tensor(np.asarray(x_dev, dtype=np.float32)))\n",
"y_dev_prediction = (y_dev_prediction.flatten() >= 0.5).numpy().astype(int)\n",
"np.savetxt(\"dev-0/out.tsv\", y_dev_prediction, fmt='%d')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Test set: embed like the training data, then apply the >= 0.5 decision threshold\n",
"# that get_loss_acc uses for accuracy (round() would send exactly 0.5 to 0).\n",
"test_texts = pd.read_csv('test-A/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
"x_test = [document_vector(tokens) for tokens in test_texts.str.lower().str.split()]\n",
"nn_model.eval()\n",
"with torch.no_grad():\n",
"    y_test_prediction = nn_model(torch.tensor(np.asarray(x_test, dtype=np.float32)))\n",
"y_test_prediction = (y_test_prediction.flatten() >= 0.5).numpy().astype(int)\n",
"np.savetxt(\"test-A/out.tsv\", y_test_prediction, fmt='%d')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}