sacred5
Some checks failed
s434766-training/pipeline/head There was a failure building this commit

This commit is contained in:
s434766 2021-05-13 23:20:15 +02:00
parent 2fc7824953
commit c0668627a8
2 changed files with 6 additions and 6 deletions

View File

@ -29,7 +29,7 @@ class LogisticRegressionModel(nn.Module):
return self.sigmoid(out)
@ex.capture
def train(num_epochs, batch_size, learning_rate, _run):
def train(num_epochs, batch_size, learning_rate, _log):
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
FEATURES = ['age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']
@ -50,7 +50,7 @@ def train(num_epochs, batch_size, learning_rate, _run):
input_dim = 6
output_dim = 1
info_params = "Batch size = " + str(batch_size) + " Epochs = " + str(num_epochs)
_run.info(info_params)
_log.info(info_params)
model = LogisticRegressionModel(input_dim, output_dim)
criterion = torch.nn.BCELoss(reduction='mean')
@ -70,7 +70,7 @@ def train(num_epochs, batch_size, learning_rate, _run):
optimizer.step()
info_loss = "Last loss = " + str(loss.item())
_run.info(info_loss)
_log.info(info_loss)
y_pred = model(fTest)
print("predicted Y value: ", y_pred.data)

View File

@ -30,7 +30,7 @@ class LogisticRegressionModel(nn.Module):
return self.sigmoid(out)
@ex.capture
def train(num_epochs, batch_size, learning_rate, _run):
def train(num_epochs, batch_size, learning_rate, _log):
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
FEATURES = ['age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']
@ -51,7 +51,7 @@ def train(num_epochs, batch_size, learning_rate, _run):
input_dim = 6
output_dim = 1
info_params = "Batch size = " + str(batch_size) + " Epochs = " + str(num_epochs)
_run.info(info_params)
_log.info(info_params)
model = LogisticRegressionModel(input_dim, output_dim)
criterion = torch.nn.BCELoss(reduction='mean')
@ -71,7 +71,7 @@ def train(num_epochs, batch_size, learning_rate, _run):
optimizer.step()
info_loss = "Last loss = " + str(loss.item())
_run.info(info_loss)
_log.info(info_loss)
y_pred = model(fTest)
print("predicted Y value: ", y_pred.data)