trigram nn
This commit is contained in:
parent
6ede4a21d3
commit
464857d8dd
@ -130,11 +130,11 @@
|
|||||||
"execution_count": 7,
|
"execution_count": 7,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def prediction(words, model) -> str:\n",
|
"def prediction(words, model, top) -> str:\n",
|
||||||
" words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
|
" words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
|
||||||
" ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
|
" ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
|
||||||
" out = model(ixs)\n",
|
" out = model(ixs)\n",
|
||||||
" top = torch.topk(out[0], 5)\n",
|
" top = torch.topk(out[0], top)\n",
|
||||||
" top_indices = top.indices.tolist()\n",
|
" top_indices = top.indices.tolist()\n",
|
||||||
" top_probs = top.values.tolist()\n",
|
" top_probs = top.values.tolist()\n",
|
||||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||||
@ -158,14 +158,14 @@
|
|||||||
"execution_count": 8,
|
"execution_count": 8,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def create_outputs(folder_name, model):\n",
|
"def create_outputs(folder_name, model, top):\n",
|
||||||
" print(f'Creating outputs in {folder_name}')\n",
|
" print(f'Creating outputs in {folder_name}')\n",
|
||||||
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
||||||
" with open(f'{folder_name}/out-EMBED_SIZE={embed_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
" with open(f'{folder_name}/out-top={top}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||||
" for line in fid:\n",
|
" for line in fid:\n",
|
||||||
" separated = line.split('\\t')\n",
|
" separated = line.split('\\t')\n",
|
||||||
" prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
|
" prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
|
||||||
" output_line = prediction(prefix, model)\n",
|
" output_line = prediction(prefix, model, top)\n",
|
||||||
" f.write(output_line + '\\n')"
|
" f.write(output_line + '\\n')"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -177,10 +177,10 @@
|
|||||||
"execution_count": 9,
|
"execution_count": 9,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def train_model():\n",
|
"def train_model(lr):\n",
|
||||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||||
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||||
" optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
|
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
||||||
" criterion = torch.nn.NLLLoss()\n",
|
" criterion = torch.nn.NLLLoss()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" model.train()\n",
|
" model.train()\n",
|
||||||
@ -215,16 +215,13 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def with_hyperparams():\n",
|
"def with_hyperparams():\n",
|
||||||
" for e_size in [200, 300]:\n",
|
" train_model(lr=0.0001)\n",
|
||||||
" global embed_size\n",
|
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||||
" embed_size = e_size\n",
|
" model.load_state_dict(torch.load(path_to_model))\n",
|
||||||
" train_model()\n",
|
" model.eval()\n",
|
||||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
" for top in [200, 400, 600]:\n",
|
||||||
" model.load_state_dict(torch.load(path_to_model))\n",
|
" create_outputs('dev-0', model, top)\n",
|
||||||
" model.eval()\n",
|
" create_outputs('test-A', model, top)"
|
||||||
"\n",
|
|
||||||
" create_outputs('dev-0', model)\n",
|
|
||||||
" create_outputs('test-A', model)"
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "kdjy-pX9dzWX"
|
"id": "kdjy-pX9dzWX"
|
||||||
@ -314,9 +311,9 @@
|
|||||||
"execution_count": 13,
|
"execution_count": 13,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"vocab_size = 30000\n",
|
"vocab_size = 25000\n",
|
||||||
"embed_size = 200\n",
|
"embed_size = 300\n",
|
||||||
"hidden_size = 500\n",
|
"hidden_size = 150\n",
|
||||||
"batch_size = 2000\n",
|
"batch_size = 2000\n",
|
||||||
"device = 'cuda'\n",
|
"device = 'cuda'\n",
|
||||||
"path_to_train = 'train/in.tsv.xz'\n",
|
"path_to_train = 'train/in.tsv.xz'\n",
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
10519
dev-0/out-top=200.tsv
Normal file
10519
dev-0/out-top=200.tsv
Normal file
File diff suppressed because one or more lines are too long
10519
dev-0/out-top=400.tsv
Normal file
10519
dev-0/out-top=400.tsv
Normal file
File diff suppressed because one or more lines are too long
10519
dev-0/out-top=600.tsv
Normal file
10519
dev-0/out-top=600.tsv
Normal file
File diff suppressed because one or more lines are too long
@ -1,10 +1,14 @@
|
|||||||
description: nn trigram multiple outs
|
description: nn trigram multiple outs
|
||||||
tags:
|
tags:
|
||||||
- neural-network
|
- neural-network
|
||||||
- ngram
|
- trigram
|
||||||
params:
|
params:
|
||||||
epochs: 1
|
epochs: 1
|
||||||
learning-rate: 0.0001
|
learning-rate: 0.0001
|
||||||
|
vocab_size: 25000
|
||||||
|
embed_size: 300
|
||||||
|
hidden_size: 150
|
||||||
|
batch_size: 2000
|
||||||
unwanted-params:
|
unwanted-params:
|
||||||
- model-file
|
- model-file
|
||||||
- vocab-file
|
- vocab-file
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
7414
test-A/out-top=200.tsv
Normal file
7414
test-A/out-top=200.tsv
Normal file
File diff suppressed because one or more lines are too long
7414
test-A/out-top=400.tsv
Normal file
7414
test-A/out-top=400.tsv
Normal file
File diff suppressed because one or more lines are too long
7414
test-A/out-top=600.tsv
Normal file
7414
test-A/out-top=600.tsv
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user