Zaktualizuj 'ml_pytroch_sacred.py'

This commit is contained in:
Sebastian Wałęsa 2022-05-08 20:38:37 +02:00
parent 7445be7014
commit 023b0756a0

View File

@ -100,21 +100,23 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
history.append(result)
return history
input_size = len(input_cols)
output_size = len(output_cols)
model=Model_xPosition()
lr = 1e-5
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
def predict_single(input, target, model):
inputs = input.unsqueeze(0)
predictions = model(inputs)
prediction = predictions[0].detach()
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
input_size = len(input_cols)
output_size = len(output_cols)
model=Model_xPosition()
lr = 1e-5
@ex.automain
def my_main(epochs):
def my_main(epochs):
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
for i in random.sample(range(0, len(val_ds)), 10):
input_, target = val_ds[i]
print(predict_single(input_, target, model),end="")