trigram nn
This commit is contained in:
parent
6ede4a21d3
commit
464857d8dd
@ -130,11 +130,11 @@
|
||||
"execution_count": 7,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prediction(words, model) -> str:\n",
|
||||
"def prediction(words, model, top) -> str:\n",
|
||||
" words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
|
||||
" ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
|
||||
" out = model(ixs)\n",
|
||||
" top = torch.topk(out[0], 5)\n",
|
||||
" top = torch.topk(out[0], top)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
@ -158,14 +158,14 @@
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_outputs(folder_name, model):\n",
|
||||
"def create_outputs(folder_name, model, top):\n",
|
||||
" print(f'Creating outputs in {folder_name}')\n",
|
||||
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
||||
" with open(f'{folder_name}/out-EMBED_SIZE={embed_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||
" with open(f'{folder_name}/out-top={top}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||
" for line in fid:\n",
|
||||
" separated = line.split('\\t')\n",
|
||||
" prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
|
||||
" output_line = prediction(prefix, model)\n",
|
||||
" output_line = prediction(prefix, model, top)\n",
|
||||
" f.write(output_line + '\\n')"
|
||||
],
|
||||
"metadata": {
|
||||
@ -177,10 +177,10 @@
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_model():\n",
|
||||
"def train_model(lr):\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
||||
" criterion = torch.nn.NLLLoss()\n",
|
||||
"\n",
|
||||
" model.train()\n",
|
||||
@ -215,16 +215,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def with_hyperparams():\n",
|
||||
" for e_size in [200, 300]:\n",
|
||||
" global embed_size\n",
|
||||
" embed_size = e_size\n",
|
||||
" train_model()\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(path_to_model))\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" create_outputs('dev-0', model)\n",
|
||||
" create_outputs('test-A', model)"
|
||||
" train_model(lr=0.0001)\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(path_to_model))\n",
|
||||
" model.eval()\n",
|
||||
" for top in [200, 400, 600]:\n",
|
||||
" create_outputs('dev-0', model, top)\n",
|
||||
" create_outputs('test-A', model, top)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kdjy-pX9dzWX"
|
||||
@ -314,9 +311,9 @@
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vocab_size = 30000\n",
|
||||
"embed_size = 200\n",
|
||||
"hidden_size = 500\n",
|
||||
"vocab_size = 25000\n",
|
||||
"embed_size = 300\n",
|
||||
"hidden_size = 150\n",
|
||||
"batch_size = 2000\n",
|
||||
"device = 'cuda'\n",
|
||||
"path_to_train = 'train/in.tsv.xz'\n",
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
10519
dev-0/out-top=200.tsv
Normal file
10519
dev-0/out-top=200.tsv
Normal file
File diff suppressed because one or more lines are too long
10519
dev-0/out-top=400.tsv
Normal file
10519
dev-0/out-top=400.tsv
Normal file
File diff suppressed because one or more lines are too long
10519
dev-0/out-top=600.tsv
Normal file
10519
dev-0/out-top=600.tsv
Normal file
File diff suppressed because one or more lines are too long
@ -1,10 +1,14 @@
|
||||
description: nn trigram multiple outs
|
||||
tags:
|
||||
- neural-network
|
||||
- ngram
|
||||
- trigram
|
||||
params:
|
||||
epochs: 1
|
||||
learning-rate: 0.0001
|
||||
vocab_size: 25000
|
||||
embed_size: 300
|
||||
hidden_size: 150
|
||||
batch_size: 2000
|
||||
unwanted-params:
|
||||
- model-file
|
||||
- vocab-file
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
7414
test-A/out-top=200.tsv
Normal file
7414
test-A/out-top=200.tsv
Normal file
File diff suppressed because one or more lines are too long
7414
test-A/out-top=400.tsv
Normal file
7414
test-A/out-top=400.tsv
Normal file
File diff suppressed because one or more lines are too long
7414
test-A/out-top=600.tsv
Normal file
7414
test-A/out-top=600.tsv
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user