diff --git a/ml_pytroch_sacred.py b/ml_pytroch_sacred.py index a61dcef..2373faa 100644 --- a/ml_pytroch_sacred.py +++ b/ml_pytroch_sacred.py @@ -100,21 +100,22 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): history.append(result) return history -@ex.automain -def my_main(epochs): - input_size = len(input_cols) - output_size = len(output_cols) - model=Model_xPosition() - epochs = 2000 - lr = 1e-5 - learning_proccess = fit(epochs, lr, model, train_loader, val_loader) - def predict_single(input, target, model): - inputs = input.unsqueeze(0) - predictions = model(inputs) - prediction = predictions[0].detach() - return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n" +input_size = len(input_cols) +output_size = len(output_cols) +model=Model_xPosition() +epochs = 1000 +lr = 1e-5 +learning_proccess = fit(epochs, lr, model, train_loader, val_loader) +def predict_single(input, target, model): + inputs = input.unsqueeze(0) + predictions = model(inputs) + prediction = predictions[0].detach() + + return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n" +@ex.automain +def my_main(epochs): for i in random.sample(range(0, len(val_ds)), 10): input_, target = val_ds[i] print(predict_single(input_, target, model),end="")