trigram nn
This commit is contained in:
parent
6ede4a21d3
commit
464857d8dd
|
@ -130,11 +130,11 @@
|
|||
"execution_count": 7,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prediction(words, model) -> str:\n",
|
||||
"def prediction(words, model, top) -> str:\n",
|
||||
" words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
|
||||
" ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
|
||||
" out = model(ixs)\n",
|
||||
" top = torch.topk(out[0], 5)\n",
|
||||
" top = torch.topk(out[0], top)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
|
@ -158,14 +158,14 @@
|
|||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_outputs(folder_name, model):\n",
|
||||
"def create_outputs(folder_name, model, top):\n",
|
||||
" print(f'Creating outputs in {folder_name}')\n",
|
||||
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
||||
" with open(f'{folder_name}/out-EMBED_SIZE={embed_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||
" with open(f'{folder_name}/out-top={top}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||
" for line in fid:\n",
|
||||
" separated = line.split('\\t')\n",
|
||||
" prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
|
||||
" output_line = prediction(prefix, model)\n",
|
||||
" output_line = prediction(prefix, model, top)\n",
|
||||
" f.write(output_line + '\\n')"
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -177,10 +177,10 @@
|
|||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_model():\n",
|
||||
"def train_model(lr):\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
|
||||
" criterion = torch.nn.NLLLoss()\n",
|
||||
"\n",
|
||||
" model.train()\n",
|
||||
|
@ -215,16 +215,13 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"def with_hyperparams():\n",
|
||||
" for e_size in [200, 300]:\n",
|
||||
" global embed_size\n",
|
||||
" embed_size = e_size\n",
|
||||
" train_model()\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(path_to_model))\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" create_outputs('dev-0', model)\n",
|
||||
" create_outputs('test-A', model)"
|
||||
" train_model(lr=0.0001)\n",
|
||||
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(path_to_model))\n",
|
||||
" model.eval()\n",
|
||||
" for top in [200, 400, 600]:\n",
|
||||
" create_outputs('dev-0', model, top)\n",
|
||||
" create_outputs('test-A', model, top)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "kdjy-pX9dzWX"
|
||||
|
@ -314,9 +311,9 @@
|
|||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vocab_size = 30000\n",
|
||||
"embed_size = 200\n",
|
||||
"hidden_size = 500\n",
|
||||
"vocab_size = 25000\n",
|
||||
"embed_size = 300\n",
|
||||
"hidden_size = 150\n",
|
||||
"batch_size = 2000\n",
|
||||
"device = 'cuda'\n",
|
||||
"path_to_train = 'train/in.tsv.xz'\n",
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,10 +1,14 @@
|
|||
description: nn trigram multiple outs
|
||||
tags:
|
||||
- neural-network
|
||||
- ngram
|
||||
- trigram
|
||||
params:
|
||||
epochs: 1
|
||||
learning-rate: 0.0001
|
||||
vocab_size: 25000
|
||||
embed_size: 300
|
||||
hidden_size: 150
|
||||
batch_size: 2000
|
||||
unwanted-params:
|
||||
- model-file
|
||||
- vocab-file
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue