Scared
This commit is contained in:
parent
61043e2273
commit
8dbf0e8699
@ -28,7 +28,7 @@ class LogisticRegressionModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@ex.capture
|
@ex.capture
|
||||||
def readAndtrain(epchos, batch_size, _run):
|
def readAndtrain(epochs, batch_size, _run):
|
||||||
train = pd.read_csv("train.csv")
|
train = pd.read_csv("train.csv")
|
||||||
test = pd.read_csv("test.csv")
|
test = pd.read_csv("test.csv")
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ def readAndtrain(epchos, batch_size, _run):
|
|||||||
input_dim = 11
|
input_dim = 11
|
||||||
output_dim = 1
|
output_dim = 1
|
||||||
|
|
||||||
_run.info("Batch: " + str(batch_size) + " epoch: " + epchos)
|
_run.info("Batch: " + str(batch_size) + " epoch: " + epochs)
|
||||||
|
|
||||||
model = LogisticRegressionModel(input_dim, output_dim)
|
model = LogisticRegressionModel(input_dim, output_dim)
|
||||||
model.load_state_dict(torch.load('DEATH_EVENT.pth'))
|
model.load_state_dict(torch.load('DEATH_EVENT.pth'))
|
||||||
@ -53,7 +53,7 @@ def readAndtrain(epchos, batch_size, _run):
|
|||||||
criterion = torch.nn.BCELoss(reduction='mean')
|
criterion = torch.nn.BCELoss(reduction='mean')
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
for epoch in range(epchos):
|
for epoch in range(epochs):
|
||||||
# print ("Epoch #",epoch)
|
# print ("Epoch #",epoch)
|
||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
Loading…
Reference in New Issue
Block a user