params update 2
This commit is contained in:
parent
866dd37a22
commit
d1376a046a
@ -76,7 +76,7 @@ def prepare_labels_features(dataset):
|
|||||||
|
|
||||||
return lab, feat
|
return lab, feat
|
||||||
|
|
||||||
@ex.main
|
@ex.automain
|
||||||
def my_main(epochs, _run):
|
def my_main(epochs, _run):
|
||||||
# Prepare dataset
|
# Prepare dataset
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
@ -96,17 +96,17 @@ def my_main(epochs, _run):
|
|||||||
|
|
||||||
# number of epochs is parametrized
|
# number of epochs is parametrized
|
||||||
try:
|
try:
|
||||||
epochs = int(epochs)
|
epochs_n = int(epochs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
print("Setting default epochs value to 1000.")
|
print("Setting default epochs value to 1000.")
|
||||||
epochs = 10
|
epochs_n = 10
|
||||||
|
|
||||||
print(f"Number of epochs: {epochs}")
|
print(f"Number of epochs: {epochs_n}")
|
||||||
|
|
||||||
print("Starting model training...")
|
print("Starting model training...")
|
||||||
x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()
|
x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()
|
||||||
for epoch in range(1, epochs + 1):
|
for epoch in range(1, epochs_n + 1):
|
||||||
print("Epoch #", epoch)
|
print("Epoch #", epoch)
|
||||||
y_pred = model(x_train)
|
y_pred = model(x_train)
|
||||||
loss = loss_fn(y_pred, y_train)
|
loss = loss_fn(y_pred, y_train)
|
||||||
@ -138,5 +138,5 @@ def my_main(epochs, _run):
|
|||||||
pd_predictions = pd.DataFrame(pred)
|
pd_predictions = pd.DataFrame(pred)
|
||||||
pd_predictions.to_csv("./prediction_results.csv")
|
pd_predictions.to_csv("./prediction_results.csv")
|
||||||
|
|
||||||
ex.run()
|
# ex.run()
|
||||||
ex.add_artifact("CarPrices_pytorch_model.pkl")
|
ex.add_artifact("CarPrices_pytorch_model.pkl")
|
Loading…
Reference in New Issue
Block a user