This commit is contained in:
parent
87fe38d3d1
commit
0e7a4840c9
@ -13,6 +13,10 @@ RUN pip3 install torch
|
|||||||
RUN pip3 install seaborn
|
RUN pip3 install seaborn
|
||||||
RUN pip3 install torchvision
|
RUN pip3 install torchvision
|
||||||
RUN pip3 install sklearn
|
RUN pip3 install sklearn
|
||||||
|
RUN pip3 install sacred
|
||||||
|
RUN pip3 install pymongo
|
||||||
|
RUN pip3 install GitPython
|
||||||
|
|
||||||
# RUN python3 -m pip install kaggle
|
# RUN python3 -m pip install kaggle
|
||||||
RUN python3 -m pip install pandas
|
RUN python3 -m pip install pandas
|
||||||
RUN pip3 install matplotlib
|
RUN pip3 install matplotlib
|
||||||
|
@ -18,10 +18,21 @@ from torch.utils.data import DataLoader, TensorDataset, random_split
|
|||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from sacred import Experiment
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
from sacred.observers import MongoObserver
|
||||||
|
|
||||||
|
|
||||||
# In[2]:
|
# In[2]:
|
||||||
|
|
||||||
|
ex = Experiment(save_git_info=False)
|
||||||
|
ex.observers.append(FileStorageObserver('runs'))
|
||||||
|
@ex.config
|
||||||
|
def config():
|
||||||
|
epochs = 1500
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
dataframe_raw = pd.read_csv("winequality-red.csv")
|
dataframe_raw = pd.read_csv("winequality-red.csv")
|
||||||
dataframe_raw.head()
|
dataframe_raw.head()
|
||||||
@ -131,7 +142,8 @@ def evaluate(model, val_loader):
|
|||||||
outputs = [model.validation_step(batch) for batch in val_loader]
|
outputs = [model.validation_step(batch) for batch in val_loader]
|
||||||
return model.validation_epoch_end(outputs)
|
return model.validation_epoch_end(outputs)
|
||||||
|
|
||||||
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
@ex.capture
|
||||||
|
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, _run):
|
||||||
history = []
|
history = []
|
||||||
optimizer = opt_func(model.parameters(), lr)
|
optimizer = opt_func(model.parameters(), lr)
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
@ -149,9 +161,9 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
|||||||
# In[12]:
|
# In[12]:
|
||||||
|
|
||||||
|
|
||||||
epochs = int(sys.argv[1])
|
#epochs = int(sys.argv[1])
|
||||||
lr = 1e-6
|
|
||||||
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
|
||||||
|
|
||||||
|
|
||||||
# In[27]:
|
# In[27]:
|
||||||
@ -183,5 +195,10 @@ with open("result.txt", "w+") as file:
|
|||||||
input_, target = val_ds[i]
|
input_, target = val_ds[i]
|
||||||
file.write(str(predict_single(input_, target, model)))
|
file.write(str(predict_single(input_, target, model)))
|
||||||
|
|
||||||
|
@ex.main
|
||||||
|
def main():
|
||||||
|
lr = 1e-6
|
||||||
|
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
||||||
|
|
||||||
|
ex.run()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user