skończony projekt
This commit is contained in:
parent
ca9cd56b86
commit
c824c34f8c
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,3 +1,5 @@
|
|||||||
train/train.tsv
|
train/train.tsv
|
||||||
.ipynb_checkpoints*
|
.ipynb_checkpoints/*
|
||||||
word2vec.model
|
fasttext.model*
|
||||||
|
nn.model
|
||||||
|
geval
|
@ -1,204 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 19,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Requirement already satisfied: gensim in c:\\users\\annad\\anaconda3\\lib\\site-packages (3.8.3)\n",
|
|
||||||
"Requirement already satisfied: smart-open>=1.8.1 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from gensim) (5.0.0)\n",
|
|
||||||
"Requirement already satisfied: six>=1.5.0 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from gensim) (1.15.0)\n",
|
|
||||||
"Requirement already satisfied: numpy>=1.11.3 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from gensim) (1.19.2)\n",
|
|
||||||
"Requirement already satisfied: scipy>=0.18.1 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from gensim) (1.5.2)\n",
|
|
||||||
"Requirement already satisfied: Cython==0.29.14 in c:\\users\\annad\\anaconda3\\lib\\site-packages (from gensim) (0.29.14)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!pip install gensim"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from gensim.test.utils import common_texts\n",
|
|
||||||
"from gensim.models import Word2Vec\n",
|
|
||||||
"import os.path"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import gzip\n",
|
|
||||||
"import shutil\n",
|
|
||||||
"with gzip.open('train/train.tsv.gz', 'rb') as f_in:\n",
|
|
||||||
" with open('train/train.tsv', 'wb') as f_out:\n",
|
|
||||||
" shutil.copyfileobj(f_in, f_out)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 18,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>Ball</th>\n",
|
|
||||||
" <th>Text</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>0</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Mindaugas Budzinauskas wierzy w odbudowę formy...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>1</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Przyjmujący reprezentacji Polski wrócił do PGE...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>2</th>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" <td>FEN 9: Zapowiedź walki Róża Gumienna vs Katarz...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>3</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Aleksander Filipiak: Czuję się dobrze w nowym ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>4</th>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" <td>Victoria Carl i Aleksiej Czerwotkin mistrzami ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>...</th>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98127</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Kamil Syprzak zaczyna kolekcjonować trofea. FC...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98128</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Holandia: dwa gole Piotra Parzyszka Piotr Parz...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98129</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Sparingowo: Korona gorsza od Stali. Lettieri s...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98130</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Vive - Wisła. Ośmiu debiutantów w tegorocznej ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98131</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>WTA Miami: Timea Bacsinszky pokonana, Swietłan...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"<p>98132 rows × 2 columns</p>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
" Ball Text\n",
|
|
||||||
"0 1 Mindaugas Budzinauskas wierzy w odbudowę formy...\n",
|
|
||||||
"1 1 Przyjmujący reprezentacji Polski wrócił do PGE...\n",
|
|
||||||
"2 0 FEN 9: Zapowiedź walki Róża Gumienna vs Katarz...\n",
|
|
||||||
"3 1 Aleksander Filipiak: Czuję się dobrze w nowym ...\n",
|
|
||||||
"4 0 Victoria Carl i Aleksiej Czerwotkin mistrzami ...\n",
|
|
||||||
"... ... ...\n",
|
|
||||||
"98127 1 Kamil Syprzak zaczyna kolekcjonować trofea. FC...\n",
|
|
||||||
"98128 1 Holandia: dwa gole Piotra Parzyszka Piotr Parz...\n",
|
|
||||||
"98129 1 Sparingowo: Korona gorsza od Stali. Lettieri s...\n",
|
|
||||||
"98130 1 Vive - Wisła. Ośmiu debiutantów w tegorocznej ...\n",
|
|
||||||
"98131 1 WTA Miami: Timea Bacsinszky pokonana, Swietłan...\n",
|
|
||||||
"\n",
|
|
||||||
"[98132 rows x 2 columns]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 18,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"data = pd.read_csv('train/train.tsv', sep='\\t', names=[\"Ball\",\"Text\"])\n",
|
|
||||||
"data"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 21,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = None\n",
|
|
||||||
"if not os.path.isfile('word2vec.model'): \n",
|
|
||||||
" model = Word2Vec(sentences=data[\"Text\"], window=5, min_count=1, workers=5)\n",
|
|
||||||
" model.save(\"word2vec.model\")\n",
|
|
||||||
"else:"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
|
5452
dev-0/out.tsv
Normal file
5452
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,8 +9,12 @@
|
|||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"from gensim.test.utils import common_texts\n",
|
"from gensim.test.utils import common_texts\n",
|
||||||
"from gensim.models import Word2Vec\n",
|
"from gensim.models import FastText\n",
|
||||||
"import os.path"
|
"import os.path\n",
|
||||||
|
"import gzip\n",
|
||||||
|
"import shutil\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.optim as optim"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -19,8 +23,10 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import gzip\n",
|
"features = 100\n",
|
||||||
"import shutil\n",
|
"batch_size = 16\n",
|
||||||
|
"criterion = torch.nn.BCELoss()\n",
|
||||||
|
"\n",
|
||||||
"with gzip.open('train/train.tsv.gz', 'rb') as f_in:\n",
|
"with gzip.open('train/train.tsv.gz', 'rb') as f_in:\n",
|
||||||
" with open('train/train.tsv', 'wb') as f_out:\n",
|
" with open('train/train.tsv', 'wb') as f_out:\n",
|
||||||
" shutil.copyfileobj(f_in, f_out)"
|
" shutil.copyfileobj(f_in, f_out)"
|
||||||
@ -33,105 +39,19 @@
|
|||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>Ball</th>\n",
|
|
||||||
" <th>Text</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>0</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Mindaugas Budzinauskas wierzy w odbudowę formy...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>1</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Przyjmujący reprezentacji Polski wrócił do PGE...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>2</th>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" <td>FEN 9: Zapowiedź walki Róża Gumienna vs Katarz...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>3</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Aleksander Filipiak: Czuję się dobrze w nowym ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>4</th>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" <td>Victoria Carl i Aleksiej Czerwotkin mistrzami ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>...</th>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98127</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Kamil Syprzak zaczyna kolekcjonować trofea. FC...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98128</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Holandia: dwa gole Piotra Parzyszka Piotr Parz...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98129</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Sparingowo: Korona gorsza od Stali. Lettieri s...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98130</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>Vive - Wisła. Ośmiu debiutantów w tegorocznej ...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>98131</th>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" <td>WTA Miami: Timea Bacsinszky pokonana, Swietłan...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"<p>98132 rows × 2 columns</p>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
" Ball Text\n",
|
"0 [mindaugas, budzinauskas, wierzy, w, odbudowę,...\n",
|
||||||
"0 1 Mindaugas Budzinauskas wierzy w odbudowę formy...\n",
|
"1 [przyjmujący, reprezentacji, polski, wrócił, d...\n",
|
||||||
"1 1 Przyjmujący reprezentacji Polski wrócił do PGE...\n",
|
"2 [fen, 9:, zapowiedź, walki, róża, gumienna, vs...\n",
|
||||||
"2 0 FEN 9: Zapowiedź walki Róża Gumienna vs Katarz...\n",
|
"3 [aleksander, filipiak:, czuję, się, dobrze, w,...\n",
|
||||||
"3 1 Aleksander Filipiak: Czuję się dobrze w nowym ...\n",
|
"4 [victoria, carl, i, aleksiej, czerwotkin, mist...\n",
|
||||||
"4 0 Victoria Carl i Aleksiej Czerwotkin mistrzami ...\n",
|
" ... \n",
|
||||||
"... ... ...\n",
|
"98127 [kamil, syprzak, zaczyna, kolekcjonować, trofe...\n",
|
||||||
"98127 1 Kamil Syprzak zaczyna kolekcjonować trofea. FC...\n",
|
"98128 [holandia:, dwa, gole, piotra, parzyszka, piot...\n",
|
||||||
"98128 1 Holandia: dwa gole Piotra Parzyszka Piotr Parz...\n",
|
"98129 [sparingowo:, korona, gorsza, od, stali., lett...\n",
|
||||||
"98129 1 Sparingowo: Korona gorsza od Stali. Lettieri s...\n",
|
"98130 [vive, -, wisła., ośmiu, debiutantów, w, tegor...\n",
|
||||||
"98130 1 Vive - Wisła. Ośmiu debiutantów w tegorocznej ...\n",
|
"98131 [wta, miami:, timea, bacsinszky, pokonana,, sw...\n",
|
||||||
"98131 1 WTA Miami: Timea Bacsinszky pokonana, Swietłan...\n",
|
"Name: Text, Length: 98132, dtype: object"
|
||||||
"\n",
|
|
||||||
"[98132 rows x 2 columns]"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
@ -141,7 +61,8 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"data = pd.read_csv('train/train.tsv', sep='\\t', names=[\"Ball\",\"Text\"])\n",
|
"data = pd.read_csv('train/train.tsv', sep='\\t', names=[\"Ball\",\"Text\"])\n",
|
||||||
"data"
|
"data[\"Text\"] = data[\"Text\"].str.lower().str.split()\n",
|
||||||
|
"data[\"Text\"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -150,42 +71,162 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model = None\n",
|
"ft_model = None\n",
|
||||||
"sentences = [x.split() for x in data[\"Text\"]]\n",
|
"if not os.path.isfile('fasttext.model'):\n",
|
||||||
"if not os.path.isfile('word2vec.model'):\n",
|
" ft_model = FastText(size=features, window=3, min_count=1)\n",
|
||||||
" model = Word2Vec(sentences=data[\"Text\"])\n",
|
" ft_model.build_vocab(sentences=data[\"Text\"])\n",
|
||||||
" model.save(\"word2vec.model\")\n",
|
" ft_model.train(data[\"Text\"], total_examples=len(data[\"Text\"]), epochs=10)\n",
|
||||||
" model.train(sentences, total_examples=len(sentences), epochs=10)\n",
|
" ft_model.save(\"fasttext.model\")\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" model = Word2Vec.load(\"word2vec.model\")"
|
" ft_model = FastText.load(\"fasttext.model\")\n",
|
||||||
|
" \n",
|
||||||
|
"def document_vector(doc):\n",
|
||||||
|
" result = ft_model.wv[doc]\n",
|
||||||
|
" return np.max(result, axis=0)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X = [document_vector(x) for x in data[\"Text\"]]\n",
|
||||||
|
"Y = data[\"Ball\"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class NeuralNetworkModel(torch.nn.Module):\n",
|
||||||
|
" def __init__(self):\n",
|
||||||
|
" super(NeuralNetworkModel, self).__init__()\n",
|
||||||
|
" self.fc1 = torch.nn.Linear(features,200)\n",
|
||||||
|
" self.fc2 = torch.nn.Linear(200,150)\n",
|
||||||
|
" self.fc3 = torch.nn.Linear(150,1)\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x = self.fc1(x)\n",
|
||||||
|
" x = torch.relu(x)\n",
|
||||||
|
" x = self.fc2(x)\n",
|
||||||
|
" x = torch.sigmoid(x)\n",
|
||||||
|
" x = self.fc3(x)\n",
|
||||||
|
" x = torch.sigmoid(x)\n",
|
||||||
|
" return x\n",
|
||||||
|
"\n",
|
||||||
|
" \n",
|
||||||
|
"def get_loss_acc(model, X_dataset, Y_dataset):\n",
|
||||||
|
" loss_score = 0\n",
|
||||||
|
" acc_score = 0\n",
|
||||||
|
" items_total = 0\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" for i in range(0, Y_dataset.shape[0], batch_size):\n",
|
||||||
|
" x = X_dataset[i:i+batch_size]\n",
|
||||||
|
" x = torch.tensor(x)\n",
|
||||||
|
" y = Y_dataset[i:i+batch_size]\n",
|
||||||
|
" y = torch.tensor(y.astype(np.float32).to_numpy()).reshape(-1,1)\n",
|
||||||
|
" y_predictions = model(x)\n",
|
||||||
|
" acc_score += torch.sum((y_predictions >= 0.5) == y).item()\n",
|
||||||
|
" items_total += y.shape[0] \n",
|
||||||
|
"\n",
|
||||||
|
" loss = criterion(y_predictions, y)\n",
|
||||||
|
"\n",
|
||||||
|
" loss_score += loss.item() * y.shape[0] \n",
|
||||||
|
" return (loss_score / items_total), (acc_score / items_total)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_path = 'nn.model'\n",
|
||||||
|
"nn_model = NeuralNetworkModel()\n",
|
||||||
|
" \n",
|
||||||
|
"if not os.path.isfile(model_path):\n",
|
||||||
|
" optimizer = optim.SGD(nn_model.parameters(), lr=0.1)\n",
|
||||||
|
"\n",
|
||||||
|
" display(get_loss_acc(nn_model, X, Y))\n",
|
||||||
|
" for epoch in range(5):\n",
|
||||||
|
" nn_model.train()\n",
|
||||||
|
" for i in range(0, len(X), batch_size):\n",
|
||||||
|
" x = X[i:i+batch_size]\n",
|
||||||
|
" x = torch.tensor(x)\n",
|
||||||
|
"\n",
|
||||||
|
" y = Y[i:i+batch_size]\n",
|
||||||
|
" y = torch.tensor(y.astype(np.float32).to_numpy()).reshape(-1,1)\n",
|
||||||
|
"\n",
|
||||||
|
" y_predictions = nn_model(x)\n",
|
||||||
|
" loss = criterion(y_predictions, y)\n",
|
||||||
|
"\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" display(get_loss_acc(nn_model, X, Y))\n",
|
||||||
|
" torch.save(nn_model.state_dict(), model_path)\n",
|
||||||
|
"else:\n",
|
||||||
|
" nn_model.load_state_dict(torch.load(model_path))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"x_dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
|
||||||
|
"y_dev = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=[\"Ball\"])[\"Ball\"]\n",
|
||||||
|
"x_dev = [document_vector(x) for x in x_dev.str.lower().str.split()]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "KeyError",
|
"data": {
|
||||||
"evalue": "\"word 'Mindaugas' not in vocabulary\"",
|
"text/plain": [
|
||||||
"output_type": "error",
|
"(0.45761072419184756, 0.7694424064563463)"
|
||||||
"traceback": [
|
|
||||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"\u001b[1;32m<ipython-input-6-dec2e93bf676>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mprepared_training_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'Text'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mprepared_training_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'Text'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwv\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
||||||
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\pandas\\core\\series.py\u001b[0m in \u001b[0;36mapply\u001b[1;34m(self, func, convert_dtype, args, **kwds)\u001b[0m\n\u001b[0;32m 4198\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4199\u001b[0m \u001b[0mvalues\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobject\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 4200\u001b[1;33m \u001b[0mmapped\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmap_infer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconvert\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconvert_dtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4201\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4202\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmapped\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmapped\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mSeries\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[1;32mpandas\\_libs\\lib.pyx\u001b[0m in \u001b[0;36mpandas._libs.lib.map_infer\u001b[1;34m()\u001b[0m\n",
|
|
||||||
"\u001b[1;32m<ipython-input-6-dec2e93bf676>\u001b[0m in \u001b[0;36m<lambda>\u001b[1;34m(x)\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mprepared_training_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'Text'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mprepared_training_data\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'Text'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwv\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
||||||
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\gensim\\models\\keyedvectors.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, entities)\u001b[0m\n\u001b[0;32m 353\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mentities\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 354\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 355\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mentity\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mentity\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mentities\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 356\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 357\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__contains__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mentity\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\gensim\\models\\keyedvectors.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 353\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mentities\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 354\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 355\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mentity\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mentity\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mentities\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 356\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 357\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__contains__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mentity\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\gensim\\models\\keyedvectors.py\u001b[0m in \u001b[0;36mget_vector\u001b[1;34m(self, word)\u001b[0m\n\u001b[0;32m 469\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 470\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mword\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 471\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mword_vec\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mword\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 472\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 473\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwords_closer_than\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\gensim\\models\\keyedvectors.py\u001b[0m in \u001b[0;36mword_vec\u001b[1;34m(self, word, use_norm)\u001b[0m\n\u001b[0;32m 466\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 467\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 468\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"word '%s' not in vocabulary\"\u001b[0m \u001b[1;33m%\u001b[0m \u001b[0mword\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 469\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 470\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mget_vector\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mword\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[1;31mKeyError\u001b[0m: \"word 'Mindaugas' not in vocabulary\""
|
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"prepared_training_data['Text'] = prepared_training_data['Text'].apply(lambda x: model.wv[x.split()])"
|
"get_loss_acc(nn_model, x_dev, y_dev)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"y_dev_prediction = nn_model(torch.tensor(x_dev))\n",
|
||||||
|
"y_dev_prediction = np.array([round(y) for y in y_dev_prediction.flatten().tolist()])\n",
|
||||||
|
"np.savetxt(\"dev-0/out.tsv\", y_dev_prediction, fmt='%d')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"x_test = pd.read_csv('test-A/in.tsv', sep='\\t', names=[\"Text\"])[\"Text\"]\n",
|
||||||
|
"x_test = [document_vector(x) for x in x_test.str.lower().str.split()]\n",
|
||||||
|
"y_test_prediction = nn_model(torch.tensor(x_test))\n",
|
||||||
|
"y_test_prediction = np.array([round(y) for y in y_test_prediction.flatten().tolist()])\n",
|
||||||
|
"np.savetxt(\"test-A/out.tsv\", y_test_prediction, fmt='%d')"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
5447
test-A/out.tsv
Normal file
5447
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user