This commit is contained in:
kubapok 2021-05-10 10:55:55 +02:00
parent 3d70d8a7ec
commit 54ab26b5f9
2 changed files with 60 additions and 58 deletions

View File

@ -777,6 +777,7 @@
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x):\n",
" x = self.fc1(x)\n", " x = self.fc1(x)\n",
" x = torch.relu(x)\n",
" x = self.fc2(x)\n", " x = self.fc2(x)\n",
" x = torch.sigmoid(x)\n", " x = torch.sigmoid(x)\n",
" return x" " return x"

View File

@ -402,11 +402,11 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor([[0.4978],\n", "tensor([[0.4989],\n",
" [0.5009],\n", " [0.4985],\n",
" [0.4998],\n", " [0.4970],\n",
" [0.4990],\n", " [0.4968],\n",
" [0.5018]], grad_fn=<SigmoidBackward>)" " [0.5007]], grad_fn=<SigmoidBackward>)"
] ]
}, },
"execution_count": 20, "execution_count": 20,
@ -449,10 +449,10 @@
"data": { "data": {
"text/plain": [ "text/plain": [
"[Parameter containing:\n", "[Parameter containing:\n",
" tensor([[-0.0059, 0.0035, 0.0021, ..., -0.0042, -0.0057, -0.0049]],\n", " tensor([[ 0.0006, -0.0076, 0.0002, ..., 0.0051, 0.0034, -0.0004]],\n",
" requires_grad=True),\n", " requires_grad=True),\n",
" Parameter containing:\n", " Parameter containing:\n",
" tensor([-0.0023], requires_grad=True)]" " tensor([-0.0099], requires_grad=True)]"
] ]
}, },
"execution_count": 22, "execution_count": 22,
@ -556,10 +556,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor([[0.5667],\n", "tensor([[0.5657],\n",
" [0.5802],\n", " [0.5827],\n",
" [0.5757],\n", " [0.5727],\n",
" [0.5670]], grad_fn=<SigmoidBackward>)" " [0.5672]], grad_fn=<SigmoidBackward>)"
] ]
}, },
"execution_count": 28, "execution_count": 28,
@ -604,7 +604,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"452" "453"
] ]
}, },
"execution_count": 30, "execution_count": 30,
@ -645,7 +645,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"accuracy: 0.5587144622991347\n" "accuracy: 0.5599505562422744\n"
] ]
} }
], ],
@ -664,7 +664,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"BCE loss: 0.6745463597170355\n" "BCE loss: 0.6745760098965412\n"
] ]
} }
], ],
@ -717,7 +717,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.6443227143826974, 0.622991347342398)" "(0.6443268107837445, 0.6254635352286774)"
] ]
}, },
"execution_count": 35, "execution_count": 35,
@ -737,7 +737,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.6369243131743537, 0.6037037037037037)" "(0.6371536641209213, 0.6074074074074074)"
] ]
}, },
"execution_count": 36, "execution_count": 36,
@ -757,7 +757,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.6323775731785694, 0.6499302649930265)" "(0.6322633745447529, 0.6485355648535565)"
] ]
}, },
"execution_count": 37, "execution_count": 37,
@ -785,10 +785,10 @@
"data": { "data": {
"text/plain": [ "text/plain": [
"[Parameter containing:\n", "[Parameter containing:\n",
" tensor([[ 0.0314, -0.0375, 0.0131, ..., -0.0057, -0.0008, -0.0089]],\n", " tensor([[ 0.0379, -0.0485, 0.0113, ..., 0.0035, 0.0083, -0.0044]],\n",
" requires_grad=True),\n", " requires_grad=True),\n",
" Parameter containing:\n", " Parameter containing:\n",
" tensor([0.0563], requires_grad=True)]" " tensor([0.0556], requires_grad=True)]"
] ]
}, },
"execution_count": 38, "execution_count": 38,
@ -808,7 +808,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor([ 0.0314, -0.0375, 0.0131, ..., -0.0057, -0.0008, -0.0089],\n", "tensor([ 0.0379, -0.0485, 0.0113, ..., 0.0035, 0.0083, -0.0044],\n",
" grad_fn=<SelectBackward>)" " grad_fn=<SelectBackward>)"
] ]
}, },
@ -830,11 +830,11 @@
"data": { "data": {
"text/plain": [ "text/plain": [
"torch.return_types.topk(\n", "torch.return_types.topk(\n",
"values=tensor([0.3753, 0.2305, 0.2007, 0.2006, 0.1993, 0.1952, 0.1930, 0.1898, 0.1831,\n", "values=tensor([0.3804, 0.2315, 0.2033, 0.2026, 0.2014, 0.1993, 0.1942, 0.1890, 0.1868,\n",
" 0.1731, 0.1649, 0.1647, 0.1543, 0.1320, 0.1314, 0.1303, 0.1296, 0.1261,\n", " 0.1818, 0.1727, 0.1542, 0.1474, 0.1458, 0.1360, 0.1359, 0.1260, 0.1204,\n",
" 0.1245, 0.1243], grad_fn=<TopkBackward>),\n", " 0.1184, 0.1174], grad_fn=<TopkBackward>),\n",
"indices=tensor([8942, 6336, 1852, 9056, 1865, 4039, 7820, 5002, 8208, 1857, 9709, 803,\n", "indices=tensor([8942, 6336, 1865, 1852, 8208, 9056, 7820, 4039, 5002, 1857, 9709, 803,\n",
" 1046, 130, 4306, 6481, 4370, 4259, 4285, 1855]))" " 130, 1046, 4370, 4259, 4306, 1855, 4285, 6481]))"
] ]
}, },
"execution_count": 40, "execution_count": 40,
@ -857,24 +857,24 @@
"text": [ "text": [
"the\n", "the\n",
"of\n", "of\n",
"christ\n",
"to\n",
"church\n", "church\n",
"god\n", "christ\n",
"rutgers\n",
"jesus\n",
"sin\n", "sin\n",
"to\n",
"rutgers\n",
"god\n",
"jesus\n",
"christians\n", "christians\n",
"we\n", "we\n",
"and\n", "and\n",
"athos\n",
"1993\n", "1993\n",
"hell\n", "athos\n",
"our\n",
"his\n", "his\n",
"he\n", "he\n",
"hell\n",
"christian\n",
"heaven\n", "heaven\n",
"christian\n" "our\n"
] ]
} }
], ],
@ -892,11 +892,11 @@
"data": { "data": {
"text/plain": [ "text/plain": [
"torch.return_types.topk(\n", "torch.return_types.topk(\n",
"values=tensor([-0.3478, -0.2578, -0.2455, -0.2347, -0.2330, -0.2265, -0.2205, -0.2050,\n", "values=tensor([-0.3464, -0.2578, -0.2372, -0.2307, -0.2300, -0.2259, -0.2227, -0.2107,\n",
" -0.2044, -0.1979, -0.1876, -0.1790, -0.1747, -0.1745, -0.1734, -0.1647,\n", " -0.2054, -0.1949, -0.1919, -0.1767, -0.1767, -0.1749, -0.1747, -0.1739,\n",
" -0.1639, -0.1617, -0.1601, -0.1592], grad_fn=<TopkBackward>),\n", " -0.1715, -0.1633, -0.1567, -0.1562], grad_fn=<TopkBackward>),\n",
"indices=tensor([5119, 8096, 5420, 4436, 6194, 1627, 6901, 5946, 9970, 3116, 1036, 9906,\n", "indices=tensor([5119, 8096, 5420, 1627, 6194, 6901, 4436, 9970, 5946, 3116, 1036, 9906,\n",
" 5654, 8329, 7869, 1039, 1991, 4926, 5035, 4925]))" " 7869, 5654, 1991, 8329, 4925, 4926, 6373, 1039]))"
] ]
}, },
"execution_count": 42, "execution_count": 42,
@ -922,23 +922,23 @@
"keith\n", "keith\n",
"sgi\n", "sgi\n",
"livesey\n", "livesey\n",
"host\n",
"nntp\n",
"caltech\n", "caltech\n",
"nntp\n",
"posting\n", "posting\n",
"morality\n", "host\n",
"you\n", "you\n",
"morality\n",
"edu\n", "edu\n",
"atheism\n", "atheism\n",
"wpd\n", "wpd\n",
"mathew\n",
"solntze\n",
"sandvik\n", "sandvik\n",
"atheists\n", "mathew\n",
"com\n", "com\n",
"solntze\n",
"islam\n",
"islamic\n", "islamic\n",
"jon\n", "okcforum\n",
"islam\n" "atheists\n"
] ]
} }
], ],
@ -969,6 +969,7 @@
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x):\n",
" x = self.fc1(x)\n", " x = self.fc1(x)\n",
" x = torch.relu(x)\n",
" x = self.fc2(x)\n", " x = self.fc2(x)\n",
" x = torch.sigmoid(x)\n", " x = torch.sigmoid(x)\n",
" return x" " return x"
@ -1029,7 +1030,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.6605833534551934, 0.5908529048207664)" "(0.6734723948651398, 0.5636588380716935)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1038,7 +1039,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.6379233609747004, 0.6481481481481481)" "(0.6606645694485417, 0.5777777777777777)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1056,7 +1057,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.4341224195120214, 0.896168108776267)" "(0.5035873688342987, 0.8677379480840544)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1065,7 +1066,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.3649017943276299, 0.9074074074074074)" "(0.43131878033832266, 0.8851851851851852)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1083,7 +1084,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.18619558424660096, 0.9765142150803461)" "(0.22238253315332793, 0.9678615574783683)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1092,7 +1093,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.16293201995668588, 0.9888888888888889)" "(0.18925935278336206, 0.9814814814814815)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1110,7 +1111,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.09108264647580784, 0.9962917181705809)" "(0.10367853983509158, 0.9913473423980222)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1119,7 +1120,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.08985773311858927, 0.9962962962962963)" "(0.09969225936327819, 0.9962962962962963)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1137,7 +1138,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.053487053708540566, 0.9987639060568603)" "(0.0588170926504491, 0.9987639060568603)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1146,7 +1147,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.05794332528279887, 1.0)" "(0.06267384567332489, 1.0)"
] ]
}, },
"metadata": {}, "metadata": {},
@ -1189,7 +1190,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(0.16834938257537793, 0.9428172942817294)" "(0.17201613383874234, 0.9414225941422594)"
] ]
}, },
"execution_count": 50, "execution_count": 50,