fix correct version
This commit is contained in:
parent
1f224ccd28
commit
318b2ca8b3
@ -1,3 +1,3 @@
|
||||
## challenging-america-word-gap-prediction
|
||||
### using simple trigram nn
|
||||
calculated perplexity: 653.89
|
||||
calculated perplexity: 349.16
|
||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
15
run.py
15
run.py
@ -17,12 +17,12 @@ import torch
|
||||
|
||||
from tqdm.notebook import tqdm
|
||||
|
||||
embed_size = 30
|
||||
vocab_size = 10_000
|
||||
num_epochs = 2
|
||||
embed_size = 300
|
||||
vocab_size = 30_000
|
||||
num_epochs = 1
|
||||
device = 'cuda'
|
||||
batch_size = 8192
|
||||
train_file_path = 'train/nano.txt'
|
||||
train_file_path = 'train/train.txt'
|
||||
|
||||
with open(train_file_path, 'r', encoding='utf-8') as file:
|
||||
total = len(file.readlines())
|
||||
@ -177,8 +177,13 @@ def predictor(prefix):
|
||||
|
||||
|
||||
def generate_result(input_path, output_path='out.tsv'):
|
||||
lines = []
|
||||
with open(input_path, encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
for line in f:
|
||||
columns = line.split('\t')
|
||||
prefix = columns[6]
|
||||
suffix = columns[7]
|
||||
lines.append(prefix)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as output_file:
|
||||
for line in lines:
|
||||
|
@ -35,12 +35,12 @@
|
||||
"\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"\n",
|
||||
"embed_size = 30\n",
|
||||
"vocab_size = 10_000\n",
|
||||
"num_epochs = 2\n",
|
||||
"embed_size = 300\n",
|
||||
"vocab_size = 30_000\n",
|
||||
"num_epochs = 1\n",
|
||||
"device = 'cuda'\n",
|
||||
"batch_size = 8192\n",
|
||||
"train_file_path = 'train/nano.txt'\n",
|
||||
"train_file_path = 'train/train.txt'\n",
|
||||
"\n",
|
||||
"with open(train_file_path, 'r', encoding='utf-8') as file:\n",
|
||||
" total = len(file.readlines())"
|
||||
@ -147,7 +147,7 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5e4b6ce6edf94b90a70d415d75be7eb6",
|
||||
"model_id": "c3d8f9d5b178490899934860a55c2508",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@ -162,37 +162,19 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 tensor(9.2450, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "044c2ea05e344306881002e34d89bd54",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Train loop: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 tensor(6.2669, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
||||
"0 tensor(10.3631, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||||
"5000 tensor(5.7081, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||||
"10000 tensor(5.5925, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||||
"15000 tensor(5.5097, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"SimpleTrigramNeuralLanguageModel(\n",
|
||||
" (embedding): Embedding(10000, 30)\n",
|
||||
" (linear1): Linear(in_features=60, out_features=30, bias=True)\n",
|
||||
" (linear2): Linear(in_features=30, out_features=10000, bias=True)\n",
|
||||
" (embedding): Embedding(30000, 300)\n",
|
||||
" (linear1): Linear(in_features=600, out_features=300, bias=True)\n",
|
||||
" (linear2): Linear(in_features=300, out_features=30000, bias=True)\n",
|
||||
" (softmax): Softmax(dim=1)\n",
|
||||
")"
|
||||
]
|
||||
@ -281,8 +263,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generate_result(input_path, output_path='out.tsv'):\n",
|
||||
" lines = []\n",
|
||||
" with open(input_path, encoding='utf-8') as f:\n",
|
||||
" lines = f.readlines()\n",
|
||||
" for line in f:\n",
|
||||
" columns = line.split('\\t')\n",
|
||||
" prefix = columns[6]\n",
|
||||
" suffix = columns[7]\n",
|
||||
" lines.append(prefix)\n",
|
||||
"\n",
|
||||
" with open(output_path, 'w', encoding='utf-8') as output_file:\n",
|
||||
" for line in lines:\n",
|
||||
|
Loading…
Reference in New Issue
Block a user