revert
s444354-training/pipeline/head There was a failure building this commit Details

This commit is contained in:
Adrian Charkiewicz 2022-05-07 21:21:03 +02:00
parent 54d71d631e
commit 6d0ecd01c4
2 changed files with 4 additions and 21 deletions

View File

@ -1 +1 @@
evaluation.py pytorch — kopia.py

View File

@ -18,21 +18,10 @@ from torch.utils.data import DataLoader, TensorDataset, random_split
import random import random
import os import os
import sys import sys
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
# In[2]: # In[2]:
ex = Experiment(save_git_info=False)
ex.observers.append(FileStorageObserver('runs'))
@ex.config
def config():
epochs = 1500
dataframe_raw = pd.read_csv("winequality-red.csv") dataframe_raw = pd.read_csv("winequality-red.csv")
dataframe_raw.head() dataframe_raw.head()
@ -142,8 +131,7 @@ def evaluate(model, val_loader):
outputs = [model.validation_step(batch) for batch in val_loader] outputs = [model.validation_step(batch) for batch in val_loader]
return model.validation_epoch_end(outputs) return model.validation_epoch_end(outputs)
@ex.capture def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
def fit(lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, epochs, _run):
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs): for epoch in range(epochs):
@ -162,8 +150,8 @@ def fit(lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, epochs, _
#epochs = int(sys.argv[1]) #epochs = int(sys.argv[1])
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
# In[27]: # In[27]:
@ -195,10 +183,5 @@ with open("result.txt", "w+") as file:
input_, target = val_ds[i] input_, target = val_ds[i]
file.write(str(predict_single(input_, target, model))) file.write(str(predict_single(input_, target, model)))
@ex.main
def main():
lr = 1e-6
history5 = fit(lr, model, train_loader, val_loader, epochs)
ex.run()