Zaktualizuj 'ml_pytroch_sacred.py'

This commit is contained in:
Sebastian Wałęsa 2022-05-08 19:58:27 +02:00
parent 6ec74ae14b
commit 8d11e06d9f

View File

@ -100,21 +100,22 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
history.append(result) history.append(result)
return history return history
@ex.automain
def my_main(epochs): input_size = len(input_cols)
input_size = len(input_cols) output_size = len(output_cols)
output_size = len(output_cols) model=Model_xPosition()
model=Model_xPosition() epochs = 1000
epochs = 2000 lr = 1e-5
lr = 1e-5 learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
learning_proccess = fit(epochs, lr, model, train_loader, val_loader) def predict_single(input, target, model):
def predict_single(input, target, model):
inputs = input.unsqueeze(0) inputs = input.unsqueeze(0)
predictions = model(inputs) predictions = model(inputs)
prediction = predictions[0].detach() prediction = predictions[0].detach()
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n" return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
@ex.automain
def my_main(epochs):
for i in random.sample(range(0, len(val_ds)), 10): for i in random.sample(range(0, len(val_ds)), 10):
input_, target = val_ds[i] input_, target = val_ds[i]
print(predict_single(input_, target, model),end="") print(predict_single(input_, target, model),end="")