paranormal-or-skeptic - pytorch NNet

Cezary 2022-05-25 23:06:35 +02:00
parent 2fc8abbc87
commit 42ede5e2c7
4 changed files with 360 additions and 247 deletions

Net.py (new file, 16 lines)

@ -0,0 +1,16 @@
import torch.nn as nn
from torch import relu, sigmoid


class NNet(nn.Module):
    def __init__(self):
        super(NNet, self).__init__()
        self.ll1 = nn.Linear(100, 1000)
        self.ll2 = nn.Linear(1000, 400)
        self.ll3 = nn.Linear(400, 1)

    def forward(self, x):
        x = relu(self.ll1(x))
        x = relu(self.ll2(x))
        x = sigmoid(self.ll3(x))
        return x
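
A minimal sanity check for the class above (a sketch, not part of the diff): NNet expects a single 100-dimensional vector, the same shape as the averaged GloVe-100 embeddings that run.py feeds it, and the forward pass returns one sigmoid probability in (0, 1).

import torch
from Net import NNet

model = NNet()
dummy = torch.randn(100)        # stands in for one averaged 100-d embedding
prob = model(dummy)             # relu -> relu -> sigmoid
print(prob.shape, float(prob))  # torch.Size([1]), a value in (0, 1)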

out.tsv (changed file)

@ -34,7 +34,7 @@
0
0
0
1
0
0
1
0
@ -64,7 +64,7 @@
1
0
1
0
1
1
1
1
@ -162,7 +162,7 @@
1
0
1
1
0
0
0
0
@ -204,7 +204,7 @@
0
0
0
0
1
0
0
0
@ -255,7 +255,7 @@
1
0
0
0
1
0
0
0
@ -334,7 +334,7 @@
0
1
0
0
1
1
0
0
@ -371,7 +371,7 @@
0
0
0
0
1
0
0
1
@ -410,7 +410,7 @@
1
0
0
0
1
0
1
1
@ -447,11 +447,11 @@
0
0
0
0
1
1
0
0
1
0
0
0
0
@ -534,7 +534,7 @@
0
0
0
0
1
1
0
1
@ -568,7 +568,7 @@
0
0
0
0
1
0
1
0
@ -687,12 +687,12 @@
0
0
0
0
1
0
0
1
0
0
1
0
0
0
@ -737,7 +737,7 @@
1
0
0
0
1
0
1
0
@ -804,7 +804,7 @@
0
0
1
0
1
0
0
1
@ -812,7 +812,7 @@
0
0
1
0
1
0
0
1
@ -870,7 +870,7 @@
1
0
1
0
1
0
0
0
@ -959,7 +959,7 @@
0
0
0
1
0
1
1
0
@ -969,7 +969,7 @@
1
0
0
1
0
0
0
0
@ -1119,7 +1119,7 @@
1
0
0
1
0
1
0
0
@ -1241,7 +1241,7 @@
0
0
0
0
1
0
0
0
@ -1261,14 +1261,14 @@
0
1
0
0
1
0
0
0
0
0
1
0
1
0
0
0
@ -1485,14 +1485,14 @@
0
0
0
1
0
0
0
0
0
0
1
0
0
1
0
1
@ -1504,7 +1504,7 @@
0
0
0
1
0
0
1
0
@ -1570,13 +1570,13 @@
0
1
0
0
1
0
1
0
1
0
1
1
0
0
0
@ -1741,7 +1741,7 @@
0
1
0
0
1
0
0
0
@ -1780,7 +1780,7 @@
0
0
0
1
0
0
0
0
@ -1832,7 +1832,7 @@
0
0
0
0
1
0
0
0
@ -1855,7 +1855,7 @@
0
0
0
0
1
0
1
0
@ -1899,7 +1899,7 @@
1
0
0
1
0
0
0
1
@ -1932,7 +1932,7 @@
0
0
1
1
0
0
0
0
@ -1970,12 +1970,12 @@
1
0
0
1
0
0
0
0
0
1
0
1
0
@ -2005,7 +2005,6 @@
0
0
0
0
1
0
0
@ -2014,6 +2013,7 @@
0
0
0
0
1
0
0
@ -2054,7 +2054,7 @@
0
0
0
1
0
0
0
0
@ -2201,7 +2201,7 @@
0
0
1
1
0
0
0
0
@ -2254,7 +2254,7 @@
0
0
0
0
1
0
0
0
@ -2305,8 +2305,8 @@
1
1
0
1
0
1
0
0
0
@ -2345,7 +2345,7 @@
1
1
0
0
1
1
0
0
@ -2373,7 +2373,7 @@
0
0
0
0
1
0
0
0
@ -2416,20 +2416,20 @@
1
0
0
0
1
1
0
0
0
0
0
0
1
0
1
1
0
0
0
1
0
0
0
@ -2443,7 +2443,7 @@
0
0
1
0
1
0
0
0
@ -2452,7 +2452,7 @@
1
0
1
0
1
1
0
0
@ -2462,7 +2462,7 @@
0
0
0
1
0
0
0
0
@ -2494,7 +2494,7 @@
0
0
0
0
1
0
0
0
@ -2576,7 +2576,7 @@
0
0
0
1
0
1
1
0
@ -2591,7 +2591,7 @@
1
0
0
1
0
0
0
1
@ -2623,7 +2623,7 @@
1
1
0
1
0
0
1
0
@ -2651,7 +2651,7 @@
0
0
1
0
1
1
1
0
@ -2743,7 +2743,7 @@
0
1
0
1
0
1
0
0
@ -2771,7 +2771,7 @@
0
0
0
0
1
0
0
0
@ -2786,7 +2786,7 @@
0
0
0
0
1
0
0
0
@ -2839,7 +2839,7 @@
0
0
0
1
0
1
1
0
@ -2902,7 +2902,7 @@
0
1
0
0
1
0
0
0
@ -2931,7 +2931,7 @@
0
0
0
0
1
0
0
0
@ -2972,13 +2972,13 @@
0
0
0
1
0
0
0
0
0
1
0
0
0
0
0
@ -3050,7 +3050,7 @@
1
0
0
0
1
0
1
0
@ -3074,9 +3074,9 @@
0
0
0
1
0
0
1
0
0
0
@ -3088,7 +3088,7 @@
0
1
0
0
1
0
0
0
@ -3120,7 +3120,7 @@
1
0
0
0
1
0
0
0
@ -3139,7 +3139,7 @@
0
0
0
0
1
0
0
0
@ -3218,7 +3218,7 @@
0
0
0
0
1
0
0
0
@ -3286,7 +3286,7 @@
0
0
0
0
1
0
0
1
@ -3400,7 +3400,7 @@
0
0
1
1
0
0
0
0
@ -3452,7 +3452,7 @@
0
0
0
1
0
0
0
0
@ -3537,7 +3537,7 @@
0
0
0
1
0
1
0
0
@ -3556,7 +3556,7 @@
0
0
0
1
0
0
0
0
@ -3570,7 +3570,7 @@
1
1
1
1
0
1
0
0
@ -3585,14 +3585,14 @@
0
0
0
1
0
0
0
1
0
1
0
1
0
1
0
0
@ -3643,7 +3643,7 @@
0
0
0
1
0
0
1
1
@ -3658,7 +3658,7 @@
1
0
1
0
1
0
0
0
@ -3669,7 +3669,7 @@
0
1
0
1
0
0
0
0
@ -3690,10 +3690,10 @@
0
0
0
0
1
1
0
1
1
1
0
0
@ -3704,7 +3704,7 @@
0
0
1
1
0
0
1
0
@ -3738,12 +3738,12 @@
0
0
1
0
1
0
1
0
0
1
0
1
0
0
@ -3808,7 +3808,7 @@
0
1
0
1
0
0
0
0
@ -3839,7 +3839,7 @@
0
0
0
0
1
0
0
1
@ -3860,7 +3860,7 @@
0
0
0
0
1
0
0
0
@ -3959,7 +3959,7 @@
0
1
0
0
1
1
0
0
@ -3972,16 +3972,16 @@
0
0
0
0
0
1
0
1
0
0
0
0
0
0
1
0
1
0
@ -4046,7 +4046,7 @@
1
0
0
0
1
1
0
0
@ -4063,7 +4063,7 @@
0
1
0
0
1
0
0
0
@ -4088,7 +4088,7 @@
0
0
1
1
0
0
1
0
@ -4122,7 +4122,7 @@
0
0
0
0
1
0
0
0
@ -4139,7 +4139,7 @@
0
1
1
0
1
0
0
0
@ -4269,7 +4269,7 @@
0
1
1
0
1
1
0
0
@ -4354,7 +4354,7 @@
0
0
0
1
0
0
0
0
@ -4366,7 +4366,7 @@
1
1
1
0
1
0
1
0
@ -4374,7 +4374,7 @@
1
0
0
0
1
0
0
0
@ -4412,7 +4412,7 @@
0
0
1
0
1
0
1
0
@ -4503,7 +4503,7 @@
0
0
0
0
1
1
0
0
@ -4515,7 +4515,7 @@
0
0
0
0
1
1
1
0
@ -4524,14 +4524,14 @@
1
0
0
0
1
1
1
0
0
0
0
0
1
0
0
0
@ -4625,7 +4625,7 @@
0
0
0
0
1
0
0
0
@ -4789,7 +4789,7 @@
0
0
0
0
1
0
0
0
@ -4854,7 +4854,7 @@
0
0
1
1
0
0
0
0
@ -4871,7 +4871,7 @@
1
1
0
0
1
0
0
0
@ -4967,7 +4967,7 @@
0
1
0
0
1
1
1
0
@ -4982,7 +4982,7 @@
0
0
0
0
1
1
0
0
@ -4990,7 +4990,7 @@
0
1
0
0
1
0
1
0
@ -5028,7 +5028,7 @@
0
0
0
1
0
0
1
0
@ -5072,7 +5072,7 @@
0
0
1
1
0
1
0
1
@ -5087,7 +5087,7 @@
0
0
0
1
0
1
0
0
@ -5118,7 +5118,7 @@
1
0
0
0
1
1
0
0
@ -5145,7 +5145,7 @@
0
1
1
1
0
0
0
0
@ -5244,7 +5244,7 @@
1
0
0
0
1
0
0
0
@ -5260,7 +5260,7 @@
0
0
0
1
0
0
0
0

run.py (new file, 97 lines)

@ -0,0 +1,97 @@
import gensim.downloader
import torch.optim as optim
import torch.nn as nn
import torch
import numpy as np

from Net import NNet
# from timeit import default_timer as timer


def read_data(folder_name):
    with open(f'{folder_name}/in.tsv', encoding='utf-8') as file:
        x = [line.lower().split()[:-2] for line in file.readlines()]
    with open(f'{folder_name}/expected.tsv', encoding='utf-8') as file:
        y = [int(line.split()[0]) for line in file.readlines()]
    return x, y


def process_data(data, word2vec):
    processed_data = []
    for reddit in data:
        words_sim = [word2vec[word] for word in reddit if word in word2vec]
        processed_data.append(np.mean(words_sim or [np.zeros(100)], axis=0))
    return processed_data


def predict(folder_name, model, word_vec):
    with open(f'{folder_name}/in.tsv', encoding='utf-8') as file:
        x_data = [line.lower().split()[:-2] for line in file.readlines()]
    x_train = process_data(x_data, word_vec)
    y_predictions = []
    with torch.no_grad():
        for i, inputs in enumerate(x_train):
            inputs = torch.tensor(inputs.astype(np.float32)).to(device)
            y_predicted = model(inputs)
            y_predictions.append(y_predicted > 0.5)
    return y_predictions


def save_predictions(folder_name, predicted_labels):
    predictions = []
    for pred in predicted_labels:
        predictions.append(pred.int()[0].item())
    with open(f"{folder_name}/out.tsv", "w", encoding="UTF-8") as file_out:
        for pred in predictions:
            file_out.writelines(f"{str(pred)}\n")


device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)  # gpu is a bit faster here

word_vectors = gensim.downloader.load("glove-wiki-gigaword-100")
x_data, y_train = read_data('train')
x_train = process_data(x_data, word_vectors)

model = NNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.005)  # , momentum=0.9)

for epoch in range(2):
    running_loss = 0.0
    correct = 0.
    total = 0.
    for i, (inputs, label) in enumerate(zip(x_train, y_train)):
        inputs = torch.tensor(inputs.astype(np.float32)).to(device)
        label = torch.tensor(np.array(label).astype(np.float32)).reshape(1).to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_predicted = model(inputs)
        loss = criterion(y_predicted, label)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        correct += ((y_predicted > 0.5) == label).type(torch.float).sum().item()
        total += label.size(0)
        if i % 10000 == 9999:  # print every 10000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10000:.3f}')
            print(f'Accuracy score: {100 * correct / total} %')
            running_loss = 0.0

predicted = predict('dev-0', model, word_vectors)
save_predictions('dev-0', predicted)
predicted = predict('test-A', model, word_vectors)
save_predictions('test-A', predicted)
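
To score the regenerated dev-0 predictions, the out.tsv written above can be compared line by line against the reference labels. A sketch, assuming dev-0 ships an expected.tsv with the same layout read_data expects for train:

# sketch: accuracy of dev-0/out.tsv against dev-0/expected.tsv
# (assumes dev-0 has an expected.tsv; not part of the commit)
with open('dev-0/out.tsv', encoding='utf-8') as f:
    predicted = [int(line.strip()) for line in f]
with open('dev-0/expected.tsv', encoding='utf-8') as f:
    expected = [int(line.split()[0]) for line in f]

hits = sum(p == e for p, e in zip(predicted, expected))
print(f'dev-0 accuracy: {100 * hits / len(expected):.2f} %')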

File diff suppressed because it is too large.