fix: correct version

This commit is contained in:
kpierzynski 2024-05-22 03:25:24 +02:00
parent 1f224ccd28
commit 318b2ca8b3
4 changed files with 10548 additions and 10556 deletions

View File

@@ -1,3 +1,3 @@
## challenging-america-word-gap-prediction ## challenging-america-word-gap-prediction
### using simple trigram nn ### using simple trigram nn
calculated perplexity: 653.89 calculated perplexity: 349.16

File diff suppressed because it is too large Load Diff

15
run.py
View File

@@ -17,12 +17,12 @@ import torch
from tqdm.notebook import tqdm from tqdm.notebook import tqdm
embed_size = 30 embed_size = 300
vocab_size = 10_000 vocab_size = 30_000
num_epochs = 2 num_epochs = 1
device = 'cuda' device = 'cuda'
batch_size = 8192 batch_size = 8192
train_file_path = 'train/nano.txt' train_file_path = 'train/train.txt'
with open(train_file_path, 'r', encoding='utf-8') as file: with open(train_file_path, 'r', encoding='utf-8') as file:
total = len(file.readlines()) total = len(file.readlines())
@@ -177,8 +177,13 @@ def predictor(prefix):
def generate_result(input_path, output_path='out.tsv'): def generate_result(input_path, output_path='out.tsv'):
lines = []
with open(input_path, encoding='utf-8') as f: with open(input_path, encoding='utf-8') as f:
lines = f.readlines() for line in f:
columns = line.split('\t')
prefix = columns[6]
suffix = columns[7]
lines.append(prefix)
with open(output_path, 'w', encoding='utf-8') as output_file: with open(output_path, 'w', encoding='utf-8') as output_file:
for line in lines: for line in lines:

View File

@@ -35,12 +35,12 @@
"\n", "\n",
"from tqdm.notebook import tqdm\n", "from tqdm.notebook import tqdm\n",
"\n", "\n",
"embed_size = 30\n", "embed_size = 300\n",
"vocab_size = 10_000\n", "vocab_size = 30_000\n",
"num_epochs = 2\n", "num_epochs = 1\n",
"device = 'cuda'\n", "device = 'cuda'\n",
"batch_size = 8192\n", "batch_size = 8192\n",
"train_file_path = 'train/nano.txt'\n", "train_file_path = 'train/train.txt'\n",
"\n", "\n",
"with open(train_file_path, 'r', encoding='utf-8') as file:\n", "with open(train_file_path, 'r', encoding='utf-8') as file:\n",
" total = len(file.readlines())" " total = len(file.readlines())"
@@ -147,7 +147,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "5e4b6ce6edf94b90a70d415d75be7eb6", "model_id": "c3d8f9d5b178490899934860a55c2508",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@@ -162,37 +162,19 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"0 tensor(9.2450, device='cuda:0', grad_fn=<NllLossBackward0>)\n" "0 tensor(10.3631, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
] "5000 tensor(5.7081, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
}, "10000 tensor(5.5925, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
{ "15000 tensor(5.5097, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "044c2ea05e344306881002e34d89bd54",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(6.2669, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"SimpleTrigramNeuralLanguageModel(\n", "SimpleTrigramNeuralLanguageModel(\n",
" (embedding): Embedding(10000, 30)\n", " (embedding): Embedding(30000, 300)\n",
" (linear1): Linear(in_features=60, out_features=30, bias=True)\n", " (linear1): Linear(in_features=600, out_features=300, bias=True)\n",
" (linear2): Linear(in_features=30, out_features=10000, bias=True)\n", " (linear2): Linear(in_features=300, out_features=30000, bias=True)\n",
" (softmax): Softmax(dim=1)\n", " (softmax): Softmax(dim=1)\n",
")" ")"
] ]
@@ -281,8 +263,13 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def generate_result(input_path, output_path='out.tsv'):\n", "def generate_result(input_path, output_path='out.tsv'):\n",
" lines = []\n",
" with open(input_path, encoding='utf-8') as f:\n", " with open(input_path, encoding='utf-8') as f:\n",
" lines = f.readlines()\n", " for line in f:\n",
" columns = line.split('\\t')\n",
" prefix = columns[6]\n",
" suffix = columns[7]\n",
" lines.append(prefix)\n",
"\n", "\n",
" with open(output_path, 'w', encoding='utf-8') as output_file:\n", " with open(output_path, 'w', encoding='utf-8') as output_file:\n",
" for line in lines:\n", " for line in lines:\n",