paranormal-or-skeptic - pytorch NNet
This commit is contained in:
parent
2fc8abbc87
commit
42ede5e2c7
16
Net.py
Normal file
16
Net.py
Normal file
@ -0,0 +1,16 @@
|
||||
import torch.nn as nn
|
||||
from torch import relu, sigmoid
|
||||
|
||||
|
||||
class NNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(NNet, self).__init__()
|
||||
self.ll1 = nn.Linear(100, 1000)
|
||||
self.ll2 = nn.Linear(1000, 400)
|
||||
self.ll3 = nn.Linear(400, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = relu(self.ll1(x))
|
||||
x = relu(self.ll2(x))
|
||||
x = sigmoid(self.ll3(x))
|
||||
return x
|
236
dev-0/out.tsv
236
dev-0/out.tsv
@ -34,7 +34,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -64,7 +64,7 @@
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
@ -162,7 +162,7 @@
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -204,7 +204,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -255,7 +255,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -334,7 +334,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -371,7 +371,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -410,7 +410,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
@ -447,11 +447,11 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -534,7 +534,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
@ -568,7 +568,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -687,12 +687,12 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -737,7 +737,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -804,7 +804,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -812,7 +812,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -870,7 +870,7 @@
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -959,7 +959,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -969,7 +969,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1119,7 +1119,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -1241,7 +1241,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1261,14 +1261,14 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1485,14 +1485,14 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
@ -1504,7 +1504,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -1570,13 +1570,13 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1741,7 +1741,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1780,7 +1780,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1832,7 +1832,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1855,7 +1855,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -1899,7 +1899,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -1932,7 +1932,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -1970,12 +1970,12 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -2005,7 +2005,6 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -2014,6 +2013,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -2054,7 +2054,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2201,7 +2201,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2254,7 +2254,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2305,8 +2305,8 @@
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2345,7 +2345,7 @@
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -2373,7 +2373,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2416,20 +2416,20 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2443,7 +2443,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2452,7 +2452,7 @@
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -2462,7 +2462,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2494,7 +2494,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2576,7 +2576,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -2591,7 +2591,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -2623,7 +2623,7 @@
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -2651,7 +2651,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -2743,7 +2743,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -2771,7 +2771,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2786,7 +2786,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2839,7 +2839,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -2902,7 +2902,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2931,7 +2931,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -2972,13 +2972,13 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3050,7 +3050,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -3074,9 +3074,9 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3088,7 +3088,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3120,7 +3120,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3139,7 +3139,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3218,7 +3218,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3286,7 +3286,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -3400,7 +3400,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3452,7 +3452,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3537,7 +3537,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3556,7 +3556,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3570,7 +3570,7 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3585,14 +3585,14 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3643,7 +3643,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
@ -3658,7 +3658,7 @@
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3669,7 +3669,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3690,10 +3690,10 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3704,7 +3704,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -3738,12 +3738,12 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3808,7 +3808,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3839,7 +3839,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
@ -3860,7 +3860,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -3959,7 +3959,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -3972,16 +3972,16 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -4046,7 +4046,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -4063,7 +4063,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4088,7 +4088,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -4122,7 +4122,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4139,7 +4139,7 @@
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4269,7 +4269,7 @@
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -4354,7 +4354,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4366,7 +4366,7 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -4374,7 +4374,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4412,7 +4412,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -4503,7 +4503,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -4515,7 +4515,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -4524,14 +4524,14 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4625,7 +4625,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4789,7 +4789,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4854,7 +4854,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4871,7 +4871,7 @@
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -4967,7 +4967,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
@ -4982,7 +4982,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -4990,7 +4990,7 @@
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -5028,7 +5028,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
@ -5072,7 +5072,7 @@
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
@ -5087,7 +5087,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -5118,7 +5118,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
@ -5145,7 +5145,7 @@
|
||||
0
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -5244,7 +5244,7 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
@ -5260,7 +5260,7 @@
|
||||
0
|
||||
0
|
||||
0
|
||||
1
|
||||
0
|
||||
0
|
||||
0
|
||||
0
|
||||
|
|
97
run.py
Normal file
97
run.py
Normal file
@ -0,0 +1,97 @@
|
||||
import gensim.downloader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
from Net import NNet
|
||||
#from timeit import default_timer as timer
|
||||
|
||||
|
||||
def read_data(folder_name):
|
||||
with open(f'{folder_name}/in.tsv', encoding='utf-8') as file:
|
||||
x = [line.lower().split()[:-2] for line in file.readlines()]
|
||||
|
||||
with open(f'{folder_name}/expected.tsv', encoding='utf-8') as file:
|
||||
y = [int(line.split()[0]) for line in file.readlines()]
|
||||
return x, y
|
||||
|
||||
|
||||
def process_data(data, word2vec):
|
||||
processed_data = []
|
||||
for reddit in data:
|
||||
words_sim = [word2vec[word] for word in reddit if word in word2vec]
|
||||
processed_data.append(np.mean(words_sim or [np.zeros(100)], axis=0))
|
||||
return processed_data
|
||||
|
||||
|
||||
def predict(folder_name, model, word_vec):
|
||||
with open(f'{folder_name}/in.tsv', encoding='utf-8') as file:
|
||||
x_data = [line.lower().split()[:-2] for line in file.readlines()]
|
||||
|
||||
x_train = process_data(x_data, word_vec)
|
||||
|
||||
y_predictions = []
|
||||
with torch.no_grad():
|
||||
for i, inputs in enumerate(x_train):
|
||||
inputs = torch.tensor(inputs.astype(np.float32)).to(device)
|
||||
|
||||
y_predicted = model(inputs)
|
||||
y_predictions.append(y_predicted > 0.5)
|
||||
return y_predictions
|
||||
|
||||
|
||||
def save_predictions(folder_name, predicted_labels):
|
||||
predictions = []
|
||||
for pred in predicted_labels:
|
||||
predictions.append(pred.int()[0].item())
|
||||
|
||||
with open(f"{folder_name}/out.tsv", "w", encoding="UTF-8") as file_out:
|
||||
for pred in predictions:
|
||||
file_out.writelines(f"{str(pred)}\n")
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(device) #gpu is a bit faster here
|
||||
|
||||
word_vectors = gensim.downloader.load("glove-wiki-gigaword-100")
|
||||
|
||||
x_data, y_train = read_data('train')
|
||||
x_train = process_data(x_data, word_vectors)
|
||||
|
||||
model = NNet().to(device)
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.005) #, momentum=0.9)
|
||||
|
||||
for epoch in range(2):
|
||||
running_loss = 0.0
|
||||
correct = 0.
|
||||
total = 0.
|
||||
for i, (inputs, label) in enumerate(zip(x_train, y_train)):
|
||||
inputs = torch.tensor(inputs.astype(np.float32)).to(device)
|
||||
label = torch.tensor(np.array(label).astype(np.float32)).reshape(1).to(device)
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
y_predicted = model(inputs)
|
||||
loss = criterion(y_predicted, label)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
running_loss += loss.item()
|
||||
correct += ((y_predicted > 0.5) == label).type(torch.float).sum().item()
|
||||
total += label.size(0)
|
||||
|
||||
if i % 10000 == 9999: # print every 10000 mini-batches
|
||||
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10000:.3f}')
|
||||
print(f'Accuracy score: {100 * correct / total} %')
|
||||
running_loss = 0.0
|
||||
|
||||
predicted = predict('dev-0', model, word_vectors)
|
||||
save_predictions('dev-0', predicted)
|
||||
|
||||
predicted = predict('test-A', model, word_vectors)
|
||||
save_predictions('test-A', predicted)
|
258
test-A/out.tsv
258
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user