forked from filipg/aitech-eks-pub
minor fix prev
This commit is contained in:
parent
54ab26b5f9
commit
c9825d8d60
@ -507,11 +507,7 @@
|
|||||||
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
||||||
" items_total += Y.shape[0] \n",
|
" items_total += Y.shape[0] \n",
|
||||||
"\n",
|
"\n",
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" loss = criterion(Y_predictions, Y)\n",
|
" loss = criterion(Y_predictions, Y)\n",
|
||||||
" loss.backward()\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" loss_score += loss.item() * Y.shape[0] \n",
|
" loss_score += loss.item() * Y.shape[0] \n",
|
||||||
" return (loss_score / items_total), (acc_score / items_total)"
|
" return (loss_score / items_total), (acc_score / items_total)"
|
||||||
|
@ -699,11 +699,7 @@
|
|||||||
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
||||||
" items_total += Y.shape[0] \n",
|
" items_total += Y.shape[0] \n",
|
||||||
"\n",
|
"\n",
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" loss = criterion(Y_predictions, Y)\n",
|
" loss = criterion(Y_predictions, Y)\n",
|
||||||
" loss.backward()\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" loss_score += loss.item() * Y.shape[0] \n",
|
" loss_score += loss.item() * Y.shape[0] \n",
|
||||||
" return (loss_score / items_total), (acc_score / items_total)"
|
" return (loss_score / items_total), (acc_score / items_total)"
|
||||||
|
Loading…
Reference in New Issue
Block a user