sacred
Some checks failed
s444354-training/pipeline/head There was a failure building this commit

This commit is contained in:
Adrian Charkiewicz 2022-05-07 21:07:01 +02:00
parent 87fe38d3d1
commit 0e7a4840c9
2 changed files with 25 additions and 4 deletions

View File

@ -13,6 +13,10 @@ RUN pip3 install torch
RUN pip3 install seaborn RUN pip3 install seaborn
RUN pip3 install torchvision RUN pip3 install torchvision
RUN pip3 install sklearn RUN pip3 install sklearn
RUN pip3 install sacred
RUN pip3 install pymongo
RUN pip3 install GitPython
# RUN python3 -m pip install kaggle # RUN python3 -m pip install kaggle
RUN python3 -m pip install pandas RUN python3 -m pip install pandas
RUN pip3 install matplotlib RUN pip3 install matplotlib

View File

@ -18,10 +18,21 @@ from torch.utils.data import DataLoader, TensorDataset, random_split
import random import random
import os import os
import sys import sys
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
# In[2]: # In[2]:
ex = Experiment(save_git_info=False)
ex.observers.append(FileStorageObserver('runs'))
@ex.config
def config():
epochs = 1500
dataframe_raw = pd.read_csv("winequality-red.csv") dataframe_raw = pd.read_csv("winequality-red.csv")
dataframe_raw.head() dataframe_raw.head()
@ -131,7 +142,8 @@ def evaluate(model, val_loader):
outputs = [model.validation_step(batch) for batch in val_loader] outputs = [model.validation_step(batch) for batch in val_loader]
return model.validation_epoch_end(outputs) return model.validation_epoch_end(outputs)
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): @ex.capture
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, _run):
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs): for epoch in range(epochs):
@ -149,9 +161,9 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
# In[12]: # In[12]:
epochs = int(sys.argv[1]) #epochs = int(sys.argv[1])
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
# In[27]: # In[27]:
@ -183,5 +195,10 @@ with open("result.txt", "w+") as file:
input_, target = val_ds[i] input_, target = val_ds[i]
file.write(str(predict_single(input_, target, model))) file.write(str(predict_single(input_, target, model)))
@ex.main
def main():
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
ex.run()