Scared
This commit is contained in:
parent
61043e2273
commit
8dbf0e8699
@ -28,7 +28,7 @@ class LogisticRegressionModel(nn.Module):
|
||||
|
||||
|
||||
@ex.capture
|
||||
def readAndtrain(epchos, batch_size, _run):
|
||||
def readAndtrain(epochs, batch_size, _run):
|
||||
train = pd.read_csv("train.csv")
|
||||
test = pd.read_csv("test.csv")
|
||||
|
||||
@ -45,7 +45,7 @@ def readAndtrain(epchos, batch_size, _run):
|
||||
input_dim = 11
|
||||
output_dim = 1
|
||||
|
||||
_run.info("Batch: " + str(batch_size) + " epoch: " + epchos)
|
||||
_run.info("Batch: " + str(batch_size) + " epoch: " + epochs)
|
||||
|
||||
model = LogisticRegressionModel(input_dim, output_dim)
|
||||
model.load_state_dict(torch.load('DEATH_EVENT.pth'))
|
||||
@ -53,7 +53,7 @@ def readAndtrain(epchos, batch_size, _run):
|
||||
criterion = torch.nn.BCELoss(reduction='mean')
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||
|
||||
for epoch in range(epchos):
|
||||
for epoch in range(epochs):
|
||||
# print ("Epoch #",epoch)
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
|
Loading…
Reference in New Issue
Block a user