solution v2

Łukasz Jędyk 2022-05-20 07:14:00 +02:00
parent ef5a97420d
commit eacc27a0d2
3 changed files with 53 additions and 17 deletions

.gitignore vendored

@@ -1 +1,2 @@
 .ipynb_checkpoints/
+processed_train.txt

geval (BIN)

Binary file not shown.


@@ -17,7 +17,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 2,
    "id": "74f269c9",
    "metadata": {},
    "outputs": [],
@@ -28,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 3,
    "id": "c939f600",
    "metadata": {},
    "outputs": [],
@@ -43,7 +43,10 @@
     "        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
     "        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
     "        \n",
-    "        self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
+    "        self.index_to_word[-1] = '<UNK>'\n",
+    "        self.word_to_index['<UNK>'] = -1\n",
+    "\n",
+    "        self.words_indexes = [self.word_to_index[w] if w in self.uniq_words else self.word_to_index['<UNK>'] for w in self.words]\n",
     "\n",
     "    def load(self):\n",
     "        with open(self.file_path, 'r') as f_in:\n",
@@ -54,7 +57,8 @@
     "        return text\n",
     "    \n",
     "    def get_uniq_words(self):\n",
-    "        word_counts = Counter(self.words)\n",
+    "        word_counts = Counter(self.words).most_common(250000)\n",
+    "        word_counts = dict(word_counts)\n",
     "        return sorted(word_counts, key=word_counts.get, reverse=True)\n",
     "\n",
     "    def __len__(self):\n",
@@ -69,7 +73,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 4,
    "id": "9f178921",
    "metadata": {},
    "outputs": [],
@@ -118,8 +122,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
-    "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
+    "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
+    "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
     "\n",
     "train_data = train_data[[6, 7]]\n",
     "train_data = pd.concat([train_data, train_labels], axis=1)\n",
@@ -135,19 +139,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 7,
    "id": "bb55ce42",
    "metadata": {},
    "outputs": [
     {
-     "ename": "NameError",
-     "evalue": "name 'dataset' is not defined",
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
      "output_type": "error",
      "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
-      "\u001b[0;32m/tmp/ipykernel_14895/2199368365.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'processed_train.txt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m: name 'dataset' is not defined"
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+      "Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mprocessed_train.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
+      "Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36mDataset.__init__\u001b[1;34m(self, sequence_length, file_path)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
+      "Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mword_to_index\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
+      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
      ]
     }
    ],
@@ -155,6 +161,17 @@
     "data = Dataset(5, 'processed_train.txt')"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a6872ca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_vocab_size = len(data.uniq_words)\n",
+    "print(data_vocab_size)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -162,7 +179,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = Model(vocab_size = len(data.uniq_words)).to(device)"
+    "model = Model(vocab_size = data_vocab_size).to(device)"
    ]
   },
   {
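One caveat with these two hunks (reviewer note, not part of the commit): `<UNK>` is mapped to index -1, which `len(data.uniq_words)` does not count, and PyTorch's `nn.Embedding` raises an index error on negative indices, so any batch containing an out-of-vocabulary word would fail at the embedding lookup. A common fix, sketched under the assumption that `words_indexes` feed an embedding layer, is to give `<UNK>` the next free non-negative slot:

    # Assign <UNK> the next free index instead of -1 and grow the
    # vocabulary size by one so the embedding table covers it.
    unk_index = len(data.uniq_words)
    data.word_to_index['<UNK>'] = unk_index
    data.index_to_word[unk_index] = '<UNK>'

    model = Model(vocab_size=len(data.uniq_words) + 1).to(device)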
@@ -225,8 +242,26 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#train(data, model, 1, 128)"
+    "train(data, model, 10, 128)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "df052344",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.save(model.state_dict(), 'model.pt')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6e87339b",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -245,7 +280,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.10"
+   "version": "3.9.2"
   }
  },
  "nbformat": 4,