{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text classification with FastText document vectors and a feed-forward network\n",
    "\n",
    "Lower-cased, whitespace-tokenised Polish texts from `train/train.tsv.gz` (columns `Ball` = binary label, `Text`) are embedded as the element-wise max of their FastText word vectors and classified by a small PyTorch network trained with BCE loss.\n",
    "\n",
    "* Artifacts are cached and reused when present: `train/train.tsv`, `fasttext.model`, `nn.model` (delete them to retrain).\n",
    "* Last recorded `dev-0` result (cached models): BCE loss 0.458, accuracy 0.769.\n",
    "* Hard 0/1 predictions are written to `dev-0/out.tsv` and `test-A/out.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "import gzip\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from gensim.test.utils import common_texts\n",
    "from gensim.models import FastText\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = 100  # FastText vector size, also the network's input size\n",
    "batch_size = 16\n",
    "SEED = 42\n",
    "criterion = torch.nn.BCELoss()\n",
    "\n",
    "# Seed torch so network initialisation and the per-epoch shuffling are reproducible.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Decompress the training data only once; re-running the notebook reuses the .tsv.\n",
    "if not os.path.isfile('train/train.tsv'):\n",
    "    with gzip.open('train/train.tsv.gz', 'rb') as f_in:\n",
    "        with open('train/train.tsv', 'wb') as f_out:\n",
    "            shutil.copyfileobj(f_in, f_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0 [mindaugas, budzinauskas, wierzy, w, odbudowę,...\n",
       "1 [przyjmujący, reprezentacji, polski, wrócił, d...\n",
       "2 [fen, 9:, zapowiedź, walki, róża, gumienna, vs...\n",
       "3 [aleksander, filipiak:, czuję, się, dobrze, w,...\n",
       "4 [victoria, carl, i, aleksiej, czerwotkin, mist...\n",
       " ... \n",
       "98127 [kamil, syprzak, zaczyna, kolekcjonować, trofe...\n",
       "98128 [holandia:, dwa, gole, piotra, parzyszka, piot...\n",
       "98129 [sparingowo:, korona, gorsza, od, stali., lett...\n",
       "98130 [vive, -, wisła., ośmiu, debiutantów, w, tegor...\n",
       "98131 [wta, miami:, timea, bacsinszky, pokonana,, sw...\n",
       "Name: Text, Length: 98132, dtype: object"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv('train/train.tsv', sep='\\t', names=[\"Ball\",\"Text\"])\n",
    "data[\"Text\"] = data[\"Text\"].str.lower().str.split()\n",
    "data[\"Text\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FastText document embeddings\n",
    "\n",
    "The FastText model is cached in `fasttext.model`; delete the file to retrain it. Each document is represented by the element-wise max of its word vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('fasttext.model'):\n",
    "    # `size` is the gensim 3.x keyword (gensim >= 4 renamed it to `vector_size`).\n",
    "    ft_model = FastText(size=features, window=3, min_count=1, seed=SEED)\n",
    "    ft_model.build_vocab(sentences=data[\"Text\"])\n",
    "    ft_model.train(data[\"Text\"], total_examples=len(data[\"Text\"]), epochs=10)\n",
    "    ft_model.save(\"fasttext.model\")\n",
    "else:\n",
    "    ft_model = FastText.load(\"fasttext.model\")\n",
    "\n",
    "\n",
    "def document_vector(doc):\n",
    "    \"\"\"Embed a tokenised document as the element-wise max of its FastText word vectors.\n",
    "\n",
    "    `doc` is the token list produced by `.str.lower().str.split()`. A missing\n",
    "    document (NaN, which `.str.split()` passes through) or an empty one gets an\n",
    "    all-zero vector instead of raising.\n",
    "    \"\"\"\n",
    "    if not isinstance(doc, list) or not doc:\n",
    "        return np.zeros(ft_model.wv.vector_size, dtype=np.float32)\n",
    "    return np.max(ft_model.wv[doc], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (n_docs, vector_size) float32 array: slicing it yields ready-made batches and\n",
    "# avoids torch's very slow conversion of a Python list of ndarrays.\n",
    "X = np.stack([document_vector(x) for x in data[\"Text\"]])\n",
    "Y = data[\"Ball\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifier and evaluation helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralNetworkModel(torch.nn.Module):\n",
    "    \"\"\"Binary classifier: features -> 200 (ReLU) -> 150 (sigmoid) -> 1 (sigmoid).\n",
    "\n",
    "    Outputs a probability in (0, 1), which is what `torch.nn.BCELoss` expects.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(NeuralNetworkModel, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(features, 200)\n",
    "        self.fc2 = torch.nn.Linear(200, 150)\n",
    "        self.fc3 = torch.nn.Linear(150, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.sigmoid(self.fc2(x))\n",
    "        return torch.sigmoid(self.fc3(x))\n",
    "\n",
    "\n",
    "def get_loss_acc(model, X_dataset, Y_dataset):\n",
    "    \"\"\"Return (mean BCE loss, accuracy) of `model` on a dataset, evaluated in mini-batches.\n",
    "\n",
    "    X_dataset: array or list of document vectors; Y_dataset: 0/1 labels (Series or array).\n",
    "    Runs without autograd and leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    loss_score = 0.0\n",
    "    acc_score = 0\n",
    "    items_total = 0\n",
    "    model.eval()\n",
    "    with torch.no_grad():  # evaluation only: don't build autograd graphs\n",
    "        for i in range(0, len(Y_dataset), batch_size):\n",
    "            x = torch.tensor(np.asarray(X_dataset[i:i+batch_size], dtype=np.float32))\n",
    "            y = torch.tensor(np.asarray(Y_dataset[i:i+batch_size], dtype=np.float32)).reshape(-1, 1)\n",
    "            y_predictions = model(x)\n",
    "            acc_score += torch.sum((y_predictions >= 0.5) == y).item()\n",
    "            # criterion averages over the batch; re-weight so the result is a per-item mean.\n",
    "            loss_score += criterion(y_predictions, y).item() * y.shape[0]\n",
    "            items_total += y.shape[0]\n",
    "    return loss_score / items_total, acc_score / items_total\n",
    "\n",
    "\n",
    "def predict_labels(model, X_dataset, threshold=0.5):\n",
    "    \"\"\"Return hard 0/1 predictions for `X_dataset` as a 1-D int array.\n",
    "\n",
    "    Uses the same `>= threshold` rule as the accuracy in `get_loss_acc`;\n",
    "    Python's `round()` would instead send exactly 0.5 to 0 (banker's rounding).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        probabilities = model(torch.tensor(np.asarray(X_dataset, dtype=np.float32)))\n",
    "    return (probabilities >= threshold).int().flatten().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train (or load) the classifier\n",
    "\n",
    "The trained weights are cached in `nn.model`; delete the file to retrain. When training, (loss, accuracy) on the training set is displayed before the first epoch and after every epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_path = 'nn.model'\n",
    "nn_model = NeuralNetworkModel()\n",
    "\n",
    "if not os.path.isfile(model_path):\n",
    "    optimizer = optim.SGD(nn_model.parameters(), lr=0.1)\n",
    "    X_train = torch.from_numpy(X)\n",
    "    Y_train = torch.tensor(np.asarray(Y, dtype=np.float32)).reshape(-1, 1)\n",
    "\n",
    "    display(get_loss_acc(nn_model, X, Y))\n",
    "    for epoch in range(5):\n",
    "        nn_model.train()\n",
    "        # Visit the training set in a fresh (seeded) random order every epoch.\n",
    "        permutation = torch.randperm(len(X_train))\n",
    "        for i in range(0, len(X_train), batch_size):\n",
    "            batch = permutation[i:i+batch_size]\n",
    "            y_predictions = nn_model(X_train[batch])\n",
    "            loss = criterion(y_predictions, Y_train[batch])\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        display(get_loss_acc(nn_model, X, Y))\n",
    "    torch.save(nn_model.state_dict(), model_path)\n",
    "else:\n",
    "    nn_model.load_state_dict(torch.load(model_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate on dev-0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_texts = pd.read_csv('dev-0/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
    "y_dev = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=[\"Ball\"])[\"Ball\"]\n",
    "x_dev = np.stack([document_vector(x) for x in dev_texts.str.lower().str.split()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.45761072419184756, 0.7694424064563463)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_loss_acc(nn_model, x_dev, y_dev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_dev_prediction = predict_labels(nn_model, x_dev)\n",
    "np.savetxt(\"dev-0/out.tsv\", y_dev_prediction, fmt='%d')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_texts = pd.read_csv('test-A/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
    "x_test = np.stack([document_vector(x) for x in test_texts.str.lower().str.split()])\n",
    "y_test_prediction = predict_labels(nn_model, x_test)\n",
    "np.savetxt(\"test-A/out.tsv\", y_test_prediction, fmt='%d')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}