trigram nn

Krystian Wasilewski 2023-05-07 22:56:23 +02:00
parent 6ede4a21d3
commit 464857d8dd
12 changed files with 53821 additions and 35887 deletions

View File

@@ -130,11 +130,11 @@
"execution_count": 7,
"outputs": [],
"source": [
"def prediction(words, model) -> str:\n",
"def prediction(words, model, top) -> str:\n",
" words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
" ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
" out = model(ixs)\n",
" top = torch.topk(out[0], 5)\n",
" top = torch.topk(out[0], top)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
@@ -158,14 +158,14 @@
"execution_count": 8,
"outputs": [],
"source": [
"def create_outputs(folder_name, model):\n",
"def create_outputs(folder_name, model, top):\n",
" print(f'Creating outputs in {folder_name}')\n",
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
" with open(f'{folder_name}/out-EMBED_SIZE={embed_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
" with open(f'{folder_name}/out-top={top}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
" for line in fid:\n",
" separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
" output_line = prediction(prefix, model)\n",
" output_line = prediction(prefix, model, top)\n",
" f.write(output_line + '\\n')"
],
"metadata": {
@@ -177,10 +177,10 @@
"execution_count": 9,
"outputs": [],
"source": [
"def train_model():\n",
"def train_model(lr):\n",
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
" criterion = torch.nn.NLLLoss()\n",
"\n",
" model.train()\n",
@@ -215,16 +215,13 @@
"outputs": [],
"source": [
"def with_hyperparams():\n",
" for e_size in [200, 300]:\n",
" global embed_size\n",
" embed_size = e_size\n",
" train_model()\n",
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
" model.load_state_dict(torch.load(path_to_model))\n",
" model.eval()\n",
"\n",
" create_outputs('dev-0', model)\n",
" create_outputs('test-A', model)"
" train_model(lr=0.0001)\n",
" model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
" model.load_state_dict(torch.load(path_to_model))\n",
" model.eval()\n",
" for top in [200, 400, 600]:\n",
" create_outputs('dev-0', model, top)\n",
" create_outputs('test-A', model, top)"
],
"metadata": {
"id": "kdjy-pX9dzWX"
@@ -314,9 +311,9 @@
"execution_count": 13,
"outputs": [],
"source": [
"vocab_size = 30000\n",
"embed_size = 200\n",
"hidden_size = 500\n",
"vocab_size = 25000\n",
"embed_size = 300\n",
"hidden_size = 150\n",
"batch_size = 2000\n",
"device = 'cuda'\n",
"path_to_train = 'train/in.tsv.xz'\n",

File diff suppressed because it is too large

File diff suppressed because it is too large

dev-0/out-top=200.tsv (normal file, 10519 additions)

File diff suppressed because one or more lines are too long

dev-0/out-top=400.tsv (normal file, 10519 additions)

File diff suppressed because one or more lines are too long

dev-0/out-top=600.tsv (normal file, 10519 additions)

File diff suppressed because one or more lines are too long

View File

@@ -1,10 +1,14 @@
description: nn trigram multiple outs
tags:
- neural-network
- ngram
- trigram
params:
epochs: 1
learning-rate: 0.0001
vocab_size: 25000
embed_size: 300
hidden_size: 150
batch_size: 2000
unwanted-params:
- model-file
- vocab-file
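
The run metadata now records the model hyperparameters next to the training parameters. A hypothetical sanity check that keeps the notebook constants in sync with this file; params.yaml is a placeholder name (the real file name is not shown in this diff) and PyYAML is assumed:

```python
import yaml

with open('params.yaml', encoding='utf-8') as f:   # placeholder name for this metadata file
    meta = yaml.safe_load(f)

params = meta['params']
assert params['vocab_size'] == 25000
assert params['embed_size'] == 300
assert params['hidden_size'] == 150
assert params['batch_size'] == 2000
```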

File diff suppressed because it is too large

File diff suppressed because it is too large

test-A/out-top=200.tsv (normal file, 7414 additions)

File diff suppressed because one or more lines are too long

test-A/out-top=400.tsv (normal file, 7414 additions)

File diff suppressed because one or more lines are too long

test-A/out-top=600.tsv (normal file, 7414 additions)

File diff suppressed because one or more lines are too long