forked from filipg/aitech-eks-pub
fix nn
This commit is contained in:
parent
3d70d8a7ec
commit
54ab26b5f9
@ -777,6 +777,7 @@
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc1(x)\n",
|
||||
" x = torch.relu(x)\n",
|
||||
" x = self.fc2(x)\n",
|
||||
" x = torch.sigmoid(x)\n",
|
||||
" return x"
|
||||
|
@ -402,11 +402,11 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[0.4978],\n",
|
||||
" [0.5009],\n",
|
||||
" [0.4998],\n",
|
||||
" [0.4990],\n",
|
||||
" [0.5018]], grad_fn=<SigmoidBackward>)"
|
||||
"tensor([[0.4989],\n",
|
||||
" [0.4985],\n",
|
||||
" [0.4970],\n",
|
||||
" [0.4968],\n",
|
||||
" [0.5007]], grad_fn=<SigmoidBackward>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
@ -449,10 +449,10 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Parameter containing:\n",
|
||||
" tensor([[-0.0059, 0.0035, 0.0021, ..., -0.0042, -0.0057, -0.0049]],\n",
|
||||
" tensor([[ 0.0006, -0.0076, 0.0002, ..., 0.0051, 0.0034, -0.0004]],\n",
|
||||
" requires_grad=True),\n",
|
||||
" Parameter containing:\n",
|
||||
" tensor([-0.0023], requires_grad=True)]"
|
||||
" tensor([-0.0099], requires_grad=True)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
@ -556,10 +556,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[0.5667],\n",
|
||||
" [0.5802],\n",
|
||||
" [0.5757],\n",
|
||||
" [0.5670]], grad_fn=<SigmoidBackward>)"
|
||||
"tensor([[0.5657],\n",
|
||||
" [0.5827],\n",
|
||||
" [0.5727],\n",
|
||||
" [0.5672]], grad_fn=<SigmoidBackward>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 28,
|
||||
@ -604,7 +604,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"452"
|
||||
"453"
|
||||
]
|
||||
},
|
||||
"execution_count": 30,
|
||||
@ -645,7 +645,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"accuracy: 0.5587144622991347\n"
|
||||
"accuracy: 0.5599505562422744\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -664,7 +664,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"BCE loss: 0.6745463597170355\n"
|
||||
"BCE loss: 0.6745760098965412\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -717,7 +717,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.6443227143826974, 0.622991347342398)"
|
||||
"(0.6443268107837445, 0.6254635352286774)"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
@ -737,7 +737,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.6369243131743537, 0.6037037037037037)"
|
||||
"(0.6371536641209213, 0.6074074074074074)"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
@ -757,7 +757,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.6323775731785694, 0.6499302649930265)"
|
||||
"(0.6322633745447529, 0.6485355648535565)"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
@ -785,10 +785,10 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Parameter containing:\n",
|
||||
" tensor([[ 0.0314, -0.0375, 0.0131, ..., -0.0057, -0.0008, -0.0089]],\n",
|
||||
" tensor([[ 0.0379, -0.0485, 0.0113, ..., 0.0035, 0.0083, -0.0044]],\n",
|
||||
" requires_grad=True),\n",
|
||||
" Parameter containing:\n",
|
||||
" tensor([0.0563], requires_grad=True)]"
|
||||
" tensor([0.0556], requires_grad=True)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
@ -808,7 +808,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([ 0.0314, -0.0375, 0.0131, ..., -0.0057, -0.0008, -0.0089],\n",
|
||||
"tensor([ 0.0379, -0.0485, 0.0113, ..., 0.0035, 0.0083, -0.0044],\n",
|
||||
" grad_fn=<SelectBackward>)"
|
||||
]
|
||||
},
|
||||
@ -830,11 +830,11 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.return_types.topk(\n",
|
||||
"values=tensor([0.3753, 0.2305, 0.2007, 0.2006, 0.1993, 0.1952, 0.1930, 0.1898, 0.1831,\n",
|
||||
" 0.1731, 0.1649, 0.1647, 0.1543, 0.1320, 0.1314, 0.1303, 0.1296, 0.1261,\n",
|
||||
" 0.1245, 0.1243], grad_fn=<TopkBackward>),\n",
|
||||
"indices=tensor([8942, 6336, 1852, 9056, 1865, 4039, 7820, 5002, 8208, 1857, 9709, 803,\n",
|
||||
" 1046, 130, 4306, 6481, 4370, 4259, 4285, 1855]))"
|
||||
"values=tensor([0.3804, 0.2315, 0.2033, 0.2026, 0.2014, 0.1993, 0.1942, 0.1890, 0.1868,\n",
|
||||
" 0.1818, 0.1727, 0.1542, 0.1474, 0.1458, 0.1360, 0.1359, 0.1260, 0.1204,\n",
|
||||
" 0.1184, 0.1174], grad_fn=<TopkBackward>),\n",
|
||||
"indices=tensor([8942, 6336, 1865, 1852, 8208, 9056, 7820, 4039, 5002, 1857, 9709, 803,\n",
|
||||
" 130, 1046, 4370, 4259, 4306, 1855, 4285, 6481]))"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
@ -857,24 +857,24 @@
|
||||
"text": [
|
||||
"the\n",
|
||||
"of\n",
|
||||
"christ\n",
|
||||
"to\n",
|
||||
"church\n",
|
||||
"god\n",
|
||||
"rutgers\n",
|
||||
"jesus\n",
|
||||
"christ\n",
|
||||
"sin\n",
|
||||
"to\n",
|
||||
"rutgers\n",
|
||||
"god\n",
|
||||
"jesus\n",
|
||||
"christians\n",
|
||||
"we\n",
|
||||
"and\n",
|
||||
"athos\n",
|
||||
"1993\n",
|
||||
"hell\n",
|
||||
"our\n",
|
||||
"athos\n",
|
||||
"his\n",
|
||||
"he\n",
|
||||
"hell\n",
|
||||
"christian\n",
|
||||
"heaven\n",
|
||||
"christian\n"
|
||||
"our\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -892,11 +892,11 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.return_types.topk(\n",
|
||||
"values=tensor([-0.3478, -0.2578, -0.2455, -0.2347, -0.2330, -0.2265, -0.2205, -0.2050,\n",
|
||||
" -0.2044, -0.1979, -0.1876, -0.1790, -0.1747, -0.1745, -0.1734, -0.1647,\n",
|
||||
" -0.1639, -0.1617, -0.1601, -0.1592], grad_fn=<TopkBackward>),\n",
|
||||
"indices=tensor([5119, 8096, 5420, 4436, 6194, 1627, 6901, 5946, 9970, 3116, 1036, 9906,\n",
|
||||
" 5654, 8329, 7869, 1039, 1991, 4926, 5035, 4925]))"
|
||||
"values=tensor([-0.3464, -0.2578, -0.2372, -0.2307, -0.2300, -0.2259, -0.2227, -0.2107,\n",
|
||||
" -0.2054, -0.1949, -0.1919, -0.1767, -0.1767, -0.1749, -0.1747, -0.1739,\n",
|
||||
" -0.1715, -0.1633, -0.1567, -0.1562], grad_fn=<TopkBackward>),\n",
|
||||
"indices=tensor([5119, 8096, 5420, 1627, 6194, 6901, 4436, 9970, 5946, 3116, 1036, 9906,\n",
|
||||
" 7869, 5654, 1991, 8329, 4925, 4926, 6373, 1039]))"
|
||||
]
|
||||
},
|
||||
"execution_count": 42,
|
||||
@ -922,23 +922,23 @@
|
||||
"keith\n",
|
||||
"sgi\n",
|
||||
"livesey\n",
|
||||
"host\n",
|
||||
"nntp\n",
|
||||
"caltech\n",
|
||||
"nntp\n",
|
||||
"posting\n",
|
||||
"morality\n",
|
||||
"host\n",
|
||||
"you\n",
|
||||
"morality\n",
|
||||
"edu\n",
|
||||
"atheism\n",
|
||||
"wpd\n",
|
||||
"mathew\n",
|
||||
"solntze\n",
|
||||
"sandvik\n",
|
||||
"atheists\n",
|
||||
"mathew\n",
|
||||
"com\n",
|
||||
"solntze\n",
|
||||
"islam\n",
|
||||
"islamic\n",
|
||||
"jon\n",
|
||||
"islam\n"
|
||||
"okcforum\n",
|
||||
"atheists\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -969,6 +969,7 @@
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc1(x)\n",
|
||||
" x = torch.relu(x)\n",
|
||||
" x = self.fc2(x)\n",
|
||||
" x = torch.sigmoid(x)\n",
|
||||
" return x"
|
||||
@ -1029,7 +1030,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.6605833534551934, 0.5908529048207664)"
|
||||
"(0.6734723948651398, 0.5636588380716935)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1038,7 +1039,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.6379233609747004, 0.6481481481481481)"
|
||||
"(0.6606645694485417, 0.5777777777777777)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1056,7 +1057,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.4341224195120214, 0.896168108776267)"
|
||||
"(0.5035873688342987, 0.8677379480840544)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1065,7 +1066,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.3649017943276299, 0.9074074074074074)"
|
||||
"(0.43131878033832266, 0.8851851851851852)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1083,7 +1084,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.18619558424660096, 0.9765142150803461)"
|
||||
"(0.22238253315332793, 0.9678615574783683)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1092,7 +1093,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.16293201995668588, 0.9888888888888889)"
|
||||
"(0.18925935278336206, 0.9814814814814815)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1110,7 +1111,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.09108264647580784, 0.9962917181705809)"
|
||||
"(0.10367853983509158, 0.9913473423980222)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1119,7 +1120,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.08985773311858927, 0.9962962962962963)"
|
||||
"(0.09969225936327819, 0.9962962962962963)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1137,7 +1138,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.053487053708540566, 0.9987639060568603)"
|
||||
"(0.0588170926504491, 0.9987639060568603)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1146,7 +1147,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.05794332528279887, 1.0)"
|
||||
"(0.06267384567332489, 1.0)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -1189,7 +1190,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0.16834938257537793, 0.9428172942817294)"
|
||||
"(0.17201613383874234, 0.9414225941422594)"
|
||||
]
|
||||
},
|
||||
"execution_count": 50,
|
||||
|
Loading…
Reference in New Issue
Block a user