diff --git a/predict_registry.py b/predict_registry.py index 3561d57..adc2467 100644 --- a/predict_registry.py +++ b/predict_registry.py @@ -9,7 +9,7 @@ with open(f'{registry_path}/input_example.json') as f: input_example_data = json.load(f) print(input_example_data) -input_example = np.array(input_example_data['inputs']) # +input_example = np.array(input_example_data['inputs']).reshape(-1, 8) # print(f'Input example: {input_example}') print(f'Model prediction: {model.predict(input_example)}')