This commit is contained in:
parent
54d71d631e
commit
6d0ecd01c4
@ -1 +1 @@
|
||||
evaluation.py
|
||||
pytorch — kopia.py
|
@ -18,21 +18,10 @@ from torch.utils.data import DataLoader, TensorDataset, random_split
|
||||
import random
|
||||
import os
|
||||
import sys
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
from sacred.observers import MongoObserver
|
||||
|
||||
|
||||
# In[2]:
|
||||
|
||||
ex = Experiment(save_git_info=False)
|
||||
ex.observers.append(FileStorageObserver('runs'))
|
||||
@ex.config
|
||||
def config():
|
||||
epochs = 1500
|
||||
|
||||
|
||||
|
||||
|
||||
dataframe_raw = pd.read_csv("winequality-red.csv")
|
||||
dataframe_raw.head()
|
||||
@ -142,8 +131,7 @@ def evaluate(model, val_loader):
|
||||
outputs = [model.validation_step(batch) for batch in val_loader]
|
||||
return model.validation_epoch_end(outputs)
|
||||
|
||||
@ex.capture
|
||||
def fit(lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, epochs, _run):
|
||||
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
|
||||
history = []
|
||||
optimizer = opt_func(model.parameters(), lr)
|
||||
for epoch in range(epochs):
|
||||
@ -162,8 +150,8 @@ def fit(lr, model, train_loader, val_loader, opt_func=torch.optim.SGD, epochs, _
|
||||
|
||||
|
||||
#epochs = int(sys.argv[1])
|
||||
|
||||
|
||||
lr = 1e-6
|
||||
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
||||
|
||||
|
||||
# In[27]:
|
||||
@ -195,10 +183,5 @@ with open("result.txt", "w+") as file:
|
||||
input_, target = val_ds[i]
|
||||
file.write(str(predict_single(input_, target, model)))
|
||||
|
||||
@ex.main
|
||||
def main():
|
||||
lr = 1e-6
|
||||
history5 = fit(lr, model, train_loader, val_loader, epochs)
|
||||
|
||||
ex.run()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user