minor fix prev
This commit is contained in:
parent
54ab26b5f9
commit
c9825d8d60
@ -507,11 +507,7 @@
|
||||
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
||||
" items_total += Y.shape[0] \n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss = criterion(Y_predictions, Y)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" loss_score += loss.item() * Y.shape[0] \n",
|
||||
" return (loss_score / items_total), (acc_score / items_total)"
|
||||
|
@ -699,11 +699,7 @@
|
||||
" acc_score += torch.sum((Y_predictions > 0.5) == Y).item()\n",
|
||||
" items_total += Y.shape[0] \n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss = criterion(Y_predictions, Y)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" loss_score += loss.item() * Y.shape[0] \n",
|
||||
" return (loss_score / items_total), (acc_score / items_total)"
|
||||
|
Loading…
Reference in New Issue
Block a user