solution v2
This commit is contained in:
parent
ef5a97420d
commit
eacc27a0d2
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
.ipynb_checkpoints/
|
||||
processed_train.txt
|
69
run.ipynb
69
run.ipynb
@ -17,7 +17,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"id": "74f269c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -28,7 +28,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 3,
|
||||
"id": "c939f600",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -42,8 +42,11 @@
|
||||
"\n",
|
||||
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
|
||||
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
|
||||
" \n",
|
||||
" self.index_to_word[-1] = '<UNK>'\n",
|
||||
" self.word_to_index['<UNK>'] = -1\n",
|
||||
"\n",
|
||||
" self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
|
||||
" self.words_indexes = [self.word_to_index[w] if w in self.uniq_words else self.word_to_index['<UNK>'] for w in self.words]\n",
|
||||
"\n",
|
||||
" def load(self):\n",
|
||||
" with open(self.file_path, 'r') as f_in:\n",
|
||||
@ -54,7 +57,8 @@
|
||||
" return text\n",
|
||||
" \n",
|
||||
" def get_uniq_words(self):\n",
|
||||
" word_counts = Counter(self.words)\n",
|
||||
" word_counts = Counter(self.words).most_common(250000)\n",
|
||||
" word_counts = dict(word_counts)\n",
|
||||
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
@ -69,7 +73,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 4,
|
||||
"id": "9f178921",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -118,8 +122,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
|
||||
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
|
||||
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
|
||||
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
|
||||
"\n",
|
||||
"train_data = train_data[[6, 7]]\n",
|
||||
"train_data = pd.concat([train_data, train_labels], axis=1)\n",
|
||||
@ -135,19 +139,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 7,
|
||||
"id": "bb55ce42",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'dataset' is not defined",
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_14895/2199368365.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'processed_train.txt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'dataset' is not defined"
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mprocessed_train.txt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36mDataset.__init__\u001b[1;34m(self, sequence_length, file_path)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
|
||||
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindex_to_word[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords_indexes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mword_to_index[w] \u001b[38;5;28;01mif\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muniq_words \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mword_to_index\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m<UNK>\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwords]\n",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -155,6 +161,17 @@
|
||||
"data = Dataset(5, 'processed_train.txt')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7a6872ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_vocab_size = len(data.uniq_words)\n",
|
||||
"print(data_vocab_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -162,7 +179,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = Model(vocab_size = len(data.uniq_words)).to(device)"
|
||||
"model = Model(vocab_size = data_vocab_size).to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -225,8 +242,26 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#train(data, model, 1, 128)"
|
||||
"train(data, model, 10, 128)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "df052344",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.save(model.state_dict(), 'model.pt')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6e87339b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -245,7 +280,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.9.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
Reference in New Issue
Block a user