Zaktualizuj 'ml_pytroch_sacred.py'

This commit is contained in:
Sebastian Wałęsa 2022-05-08 19:58:27 +02:00
parent 6ec74ae14b
commit 8d11e06d9f

View File

@ -100,21 +100,22 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
history.append(result) history.append(result)
return history return history
input_size = len(input_cols)
output_size = len(output_cols)
model=Model_xPosition()
epochs = 1000
lr = 1e-5
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
def predict_single(input, target, model):
inputs = input.unsqueeze(0)
predictions = model(inputs)
prediction = predictions[0].detach()
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
@ex.automain @ex.automain
def my_main(epochs): def my_main(epochs):
input_size = len(input_cols)
output_size = len(output_cols)
model=Model_xPosition()
epochs = 2000
lr = 1e-5
learning_proccess = fit(epochs, lr, model, train_loader, val_loader)
def predict_single(input, target, model):
inputs = input.unsqueeze(0)
predictions = model(inputs)
prediction = predictions[0].detach()
return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n"
for i in random.sample(range(0, len(val_ds)), 10): for i in random.sample(range(0, len(val_ds)), 10):
input_, target = val_ds[i] input_, target = val_ds[i]
print(predict_single(input_, target, model),end="") print(predict_single(input_, target, model),end="")