trigram nn

Krystian Wasilewski 2023-05-07 22:56:23 +02:00
parent 6ede4a21d3
commit 464857d8dd
12 changed files with 53821 additions and 35887 deletions

@@ -130,11 +130,11 @@
 "execution_count": 7,
 "outputs": [],
 "source": [
-"def prediction(words, model) -> str:\n",
+"def prediction(words, model, top) -> str:\n",
 "    words_tensor = [train_dataset.vocab.forward([word]) for word in words]\n",
 "    ixs = torch.tensor(words_tensor).view(-1).to(device)\n",
 "    out = model(ixs)\n",
-"    top = torch.topk(out[0], 5)\n",
+"    top = torch.topk(out[0], top)\n",
 "    top_indices = top.indices.tolist()\n",
 "    top_probs = top.values.tolist()\n",
 "    top_words = vocab.lookup_tokens(top_indices)\n",
@@ -158,14 +158,14 @@
 "execution_count": 8,
 "outputs": [],
 "source": [
-"def create_outputs(folder_name, model):\n",
+"def create_outputs(folder_name, model, top):\n",
 "    print(f'Creating outputs in {folder_name}')\n",
 "    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
-"        with open(f'{folder_name}/out-EMBED_SIZE={embed_size}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
+"        with open(f'{folder_name}/out-top={top}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
 "            for line in fid:\n",
 "                separated = line.split('\\t')\n",
 "                prefix = separated[6].replace(r'\\n', ' ').split()[-2:]\n",
-"                output_line = prediction(prefix, model)\n",
+"                output_line = prediction(prefix, model, top)\n",
 "                f.write(output_line + '\\n')"
 ],
 "metadata": {
@@ -177,10 +177,10 @@
 "execution_count": 9,
 "outputs": [],
 "source": [
-"def train_model():\n",
+"def train_model(lr):\n",
 "    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
 "    data = DataLoader(train_dataset, batch_size=batch_size)\n",
-"    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
+"    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
 "    criterion = torch.nn.NLLLoss()\n",
 "\n",
 "    model.train()\n",
@@ -215,16 +215,13 @@
 "outputs": [],
 "source": [
 "def with_hyperparams():\n",
-"    for e_size in [200, 300]:\n",
-"        global embed_size\n",
-"        embed_size = e_size\n",
-"        train_model()\n",
-"        model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
-"        model.load_state_dict(torch.load(path_to_model))\n",
-"        model.eval()\n",
-"\n",
-"        create_outputs('dev-0', model)\n",
-"        create_outputs('test-A', model)"
+"    train_model(lr=0.0001)\n",
+"    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
+"    model.load_state_dict(torch.load(path_to_model))\n",
+"    model.eval()\n",
+"    for top in [200, 400, 600]:\n",
+"        create_outputs('dev-0', model, top)\n",
+"        create_outputs('test-A', model, top)"
 ],
 "metadata": {
 "id": "kdjy-pX9dzWX"
@@ -314,9 +311,9 @@
 "execution_count": 13,
 "outputs": [],
 "source": [
-"vocab_size = 30000\n",
-"embed_size = 200\n",
-"hidden_size = 500\n",
+"vocab_size = 25000\n",
+"embed_size = 300\n",
+"hidden_size = 150\n",
 "batch_size = 2000\n",
 "device = 'cuda'\n",
 "path_to_train = 'train/in.tsv.xz'\n",

File diff suppressed because it is too large

File diff suppressed because it is too large

dev-0/out-top=200.tsv (new file, 10519 lines; diff suppressed because one or more lines are too long)

dev-0/out-top=400.tsv (new file, 10519 lines; diff suppressed because one or more lines are too long)

dev-0/out-top=600.tsv (new file, 10519 lines; diff suppressed because one or more lines are too long)

@@ -1,10 +1,14 @@
 description: nn trigram multiple outs
 tags:
   - neural-network
-  - ngram
+  - trigram
 params:
   epochs: 1
   learning-rate: 0.0001
+  vocab_size: 25000
+  embed_size: 300
+  hidden_size: 150
+  batch_size: 2000
 unwanted-params:
   - model-file
   - vocab-file
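
Note: the config file's name is not shown in this diff, so 'config.yaml' below is a placeholder. The added params entries simply record the notebook's hyperparameters; if one wanted to load them back so the notebook and the metadata stay in sync, a sketch would be:

import yaml  # PyYAML

with open("config.yaml", encoding="utf-8") as fh:  # placeholder path, not taken from the diff
    params = yaml.safe_load(fh)["params"]

print(params["vocab_size"], params["embed_size"], params["hidden_size"], params["batch_size"])
# 25000 300 150 2000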

File diff suppressed because it is too large

File diff suppressed because it is too large

test-A/out-top=200.tsv (new file, 7414 lines; diff suppressed because one or more lines are too long)

test-A/out-top=400.tsv (new file, 7414 lines; diff suppressed because one or more lines are too long)

test-A/out-top=600.tsv (new file, 7414 lines; diff suppressed because one or more lines are too long)