Scared
All checks were successful
s434732-evaluation/pipeline/head This commit looks good
s434732-training/pipeline/head This commit looks good

This commit is contained in:
s434732 2021-05-15 17:24:42 +02:00
parent fc517c0b15
commit 5d8988cc11

View File

@ -45,8 +45,8 @@ def readAndtrain(epochs, batch_size, _run):
input_dim = 11
output_dim = 1
_run.info("Batch", str(batch_size))
_run.info("epoch", str(epochs))
_run.log_scalar("Batch", str(batch_size))
_run.log_scalar("epoch", str(epochs))
model = LogisticRegressionModel(input_dim, output_dim)
model.load_state_dict(torch.load('DEATH_EVENT.pth'))
@ -67,16 +67,16 @@ def readAndtrain(epochs, batch_size, _run):
loss.backward()
optimizer.step()
_run.info("Lost", str(loss.item()))
_run.log_scalar("Lost", str(loss.item()))
torch.save(model.state_dict(), 'DEATH_EVENT.pth')
prediction= model(xTest)
_run.info("accuracy_score", accuracy_score(yTest, np.argmax(prediction.detach().numpy(), axis=1)))
_run.info("F1", f1_score(yTest, np.argmax(prediction.detach().numpy(), axis=1), average=None))
_run.log_scalar("accuracy_score", accuracy_score(yTest, np.argmax(prediction.detach().numpy(), axis=1)))
# _run.log_scalar("F1", str(f1_score(yTest, np.argmax(prediction.detach().numpy(), axis=1), average=None)))
print("accuracy_score", accuracy_score(yTest, np.argmax(prediction.detach().numpy(), axis=1)))
print("F1", f1_score(yTest, np.argmax(prediction.detach().numpy(), axis=1), average=None))
# print("F1", f1_score(yTest, np.argmax(prediction.detach().numpy(), axis=1), average=None))
@ex.automain
def my_main(epochs, batch_size, _run):