params update 2
This commit is contained in:
parent
866dd37a22
commit
d1376a046a
@ -76,7 +76,7 @@ def prepare_labels_features(dataset):
|
||||
|
||||
return lab, feat
|
||||
|
||||
@ex.main
|
||||
@ex.automain
|
||||
def my_main(epochs, _run):
|
||||
# Prepare dataset
|
||||
print("Loading dataset...")
|
||||
@ -96,17 +96,17 @@ def my_main(epochs, _run):
|
||||
|
||||
# number of epochs is parametrized
|
||||
try:
|
||||
epochs = int(epochs)
|
||||
epochs_n = int(epochs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Setting default epochs value to 1000.")
|
||||
epochs = 10
|
||||
epochs_n = 10
|
||||
|
||||
print(f"Number of epochs: {epochs}")
|
||||
print(f"Number of epochs: {epochs_n}")
|
||||
|
||||
print("Starting model training...")
|
||||
x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()
|
||||
for epoch in range(1, epochs + 1):
|
||||
for epoch in range(1, epochs_n + 1):
|
||||
print("Epoch #", epoch)
|
||||
y_pred = model(x_train)
|
||||
loss = loss_fn(y_pred, y_train)
|
||||
@ -138,5 +138,5 @@ def my_main(epochs, _run):
|
||||
pd_predictions = pd.DataFrame(pred)
|
||||
pd_predictions.to_csv("./prediction_results.csv")
|
||||
|
||||
ex.run()
|
||||
# ex.run()
|
||||
ex.add_artifact("CarPrices_pytorch_model.pkl")
|
Loading…
Reference in New Issue
Block a user