solution v2

Łukasz Jędyk 2022-05-20 07:14:00 +02:00
parent ef5a97420d
commit eacc27a0d2
3 changed files with 53 additions and 17 deletions

.gitignore

@@ -1 +1,2 @@
.ipynb_checkpoints/
processed_train.txt

geval

Binary file not shown.


@@ -17,7 +17,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"id": "74f269c9",
"metadata": {},
"outputs": [],
@@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"id": "c939f600",
"metadata": {},
"outputs": [],
@@ -43,7 +43,10 @@
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
" \n",
" self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
" self.index_to_word[-1] = '<UNK>'\n",
" self.word_to_index['<UNK>'] = -1\n",
"\n",
" self.words_indexes = [self.word_to_index[w] if w in self.uniq_words else self.word_to_index['<UNK>'] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open(self.file_path, 'r') as f_in:\n",
@@ -54,7 +57,8 @@
" return text\n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words)\n",
" word_counts = Counter(self.words).most_common(250000)\n",
" word_counts = dict(word_counts)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
@@ -69,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"id": "9f178921",
"metadata": {},
"outputs": [],
@@ -118,8 +122,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_data = train_data[[6, 7]]\n",
"train_data = pd.concat([train_data, train_labels], axis=1)\n",
@@ -135,19 +139,21 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"id": "bb55ce42",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'dataset' is not defined",
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_14895/2199368365.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'processed_train.txt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'dataset' is not defined"
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mprocessed_train.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36mDataset.__init__\u001b[1;34m(self, sequence_length, file_path)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mword_to_index\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
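The KeyboardInterrupt above is raised inside the words_indexes comprehension: `w in self.uniq_words` is a linear scan of a list of up to 250,000 words, repeated once per token, so the cell effectively hangs on a large corpus. A sketch of the usual fix, testing membership against the dict instead:

    # Sketch: word_to_index is a dict, so .get() is O(1) per token instead of
    # an O(vocab) list scan; unknown words fall back to the <UNK> index.
    unk = self.word_to_index['<UNK>']
    self.words_indexes = [self.word_to_index.get(w, unk) for w in self.words]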
@@ -155,6 +161,17 @@
"data = Dataset(5, 'processed_train.txt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a6872ca",
"metadata": {},
"outputs": [],
"source": [
"data_vocab_size = len(data.uniq_words)\n",
"print(data_vocab_size)"
]
},
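One caveat on the new cell: '<UNK>' sits outside uniq_words (it only exists in the two mapping dicts), so len(data.uniq_words) undercounts the token indexes by one. If the Model sizes its layers from vocab_size and '<UNK>' gets a real index as sketched earlier, a safer count would be:

    # Assumption: <UNK> occupies index len(uniq_words), so one extra slot.
    data_vocab_size = len(data.uniq_words) + 1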
{
"cell_type": "code",
"execution_count": null,
@@ -162,7 +179,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(vocab_size = len(data.uniq_words)).to(device)"
"model = Model(vocab_size = data_vocab_size).to(device)"
]
},
{
@@ -225,8 +242,26 @@
"metadata": {},
"outputs": [],
"source": [
"#train(data, model, 1, 128)"
"train(data, model, 10, 128)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df052344",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), 'model.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e87339b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -245,7 +280,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.9.2"
}
},
"nbformat": 4,