This commit is contained in:
s434766 2021-05-11 22:16:03 +02:00
parent 34e2139417
commit b3d4c74ecb
2 changed files with 46 additions and 47 deletions

View File

@ -2,7 +2,6 @@ FROM ubuntu:latest
RUN apt-get update && apt-get install -y python3-pip && pip3 install setuptools && pip3 install numpy && pip3 install pandas && pip3 install wget && pip3 install scikit-learn && pip3 install matplotlib && rm -rf /var/lib/apt/lists/* RUN apt-get update && apt-get install -y python3-pip && pip3 install setuptools && pip3 install numpy && pip3 install pandas && pip3 install wget && pip3 install scikit-learn && pip3 install matplotlib && rm -rf /var/lib/apt/lists/*
RUN pip3 install torch torchvision torchaudio RUN pip3 install torch torchvision torchaudio
RUN pip3 install sacred
WORKDIR /app WORKDIR /app
COPY ./../create.py ./ COPY ./../create.py ./

View File

@ -13,10 +13,10 @@ from sacred import Experiment
from sacred.observers import FileStorageObserver from sacred.observers import FileStorageObserver
np.set_printoptions(suppress=False) np.set_printoptions(suppress=False)
ex = Experiment("stroke-pytorch", interactive=True) # ex = Experiment("stroke-pytorch", interactive=True)
ex.observers.append(FileStorageObserver('ium_s434766O_files')) # ex.observers.append(FileStorageObserver('ium_s434766O_files'))
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017', # ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
db_name='sacred')) # db_name='sacred'))
class LogisticRegressionModel(nn.Module): class LogisticRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim): def __init__(self, input_dim, output_dim):
super(LogisticRegressionModel, self).__init__() super(LogisticRegressionModel, self).__init__()
@ -26,54 +26,54 @@ class LogisticRegressionModel(nn.Module):
out = self.linear(x) out = self.linear(x)
return self.sigmoid(out) return self.sigmoid(out)
@ex.main # @ex.main
def my_main(_log): # def my_main(_log):
data_train = pd.read_csv("data_train.csv") data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv") data_test = pd.read_csv("data_test.csv")
data_val = pd.read_csv("data_val.csv") data_val = pd.read_csv("data_val.csv")
FEATURES = ['age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi'] FEATURES = ['age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']
x_train = data_train[FEATURES].astype(np.float32) x_train = data_train[FEATURES].astype(np.float32)
y_train = data_train['stroke'].astype(np.float32) y_train = data_train['stroke'].astype(np.float32)
x_test = data_test[FEATURES].astype(np.float32) x_test = data_test[FEATURES].astype(np.float32)
y_test = data_test['stroke'].astype(np.float32) y_test = data_test['stroke'].astype(np.float32)
fTrain = torch.from_numpy(x_train.values) fTrain = torch.from_numpy(x_train.values)
tTrain = torch.from_numpy(y_train.values.reshape(2945,1)) tTrain = torch.from_numpy(y_train.values.reshape(2945,1))
fTest= torch.from_numpy(x_test.values) fTest= torch.from_numpy(x_test.values)
tTest = torch.from_numpy(y_test.values) tTest = torch.from_numpy(y_test.values)
batch_size = int(sys.argv[1]) if len(sys.argv) > 1 else 16 batch_size = int(sys.argv[1]) if len(sys.argv) > 1 else 16
num_epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 5 num_epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 5
learning_rate = 0.001 learning_rate = 0.001
input_dim = 6 input_dim = 6
output_dim = 1 output_dim = 1
info_params = "Batch size = " + str(batch_size) + " Epochs = " + str(num_epochs) info_params = "Batch size = " + str(batch_size) + " Epochs = " + str(num_epochs)
_log.info(info_params) # _log.info(info_params)
model = LogisticRegressionModel(input_dim, output_dim) model = LogisticRegressionModel(input_dim, output_dim)
criterion = torch.nn.BCELoss(reduction='mean') criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate) optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
for epoch in range(num_epochs): for epoch in range(num_epochs):
# print ("Epoch #",epoch) # print ("Epoch #",epoch)
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
# Forward pass # Forward pass
y_pred = model(fTrain) y_pred = model(fTrain)
# Compute Loss # Compute Loss
loss = criterion(y_pred, tTrain) loss = criterion(y_pred, tTrain)
# print(loss.item()) # print(loss.item())
# Backward pass # Backward pass
loss.backward() loss.backward()
optimizer.step() optimizer.step()
info_loss = "Last loss = " + str(loss.item()) info_loss = "Last loss = " + str(loss.item())
_log.info(info_loss) # _log.info(info_loss)
y_pred = model(fTest) y_pred = model(fTest)
# print("predicted Y value: ", y_pred.data) print("predicted Y value: ", y_pred.data)
torch.save(model.state_dict(), 'stroke.pth') torch.save(model.state_dict(), 'stroke.pth')
ex.run() # ex.run()